ganeshkonapalli committed on
Commit
ccfc946
·
verified ·
1 Parent(s): 9bd7a84

Delete train.py

Browse files
Files changed (1) hide show
  1. train.py +0 -81
train.py DELETED
@@ -1,81 +0,0 @@
1
- # app/train_api.py
2
-
3
- from fastapi import APIRouter, UploadFile, File
4
- import pandas as pd
5
- import torch
6
- from io import StringIO
7
- import os
8
- import joblib
9
-
10
- from app.config import (
11
- DEVICE, LABEL_COLUMNS, MODEL_SAVE_DIR,
12
- NUM_EPOCHS, LEARNING_RATE
13
- )
14
- from app.models import BertMultiOutputModel # Your model class
15
- from app.dataset import MultiLabelDataset # Your dataset class
16
- from app.train_utils import (
17
- initialize_criterions, train_model,
18
- evaluate_model, summarize_metrics,
19
- save_model
20
- )
21
- from transformers import BertTokenizer
22
- from torch.utils.data import DataLoader
23
- from sklearn.preprocessing import LabelEncoder
24
- from sklearn.model_selection import train_test_split
25
-
26
router = APIRouter()


def _encode_label_columns(train_df, val_df):
    """Fit one LabelEncoder per label column on the training split.

    Each column listed in LABEL_COLUMNS is cast to ``str`` and replaced by its
    integer codes in both frames. Encoders are fitted on the training split
    only; validation rows whose label never occurs in training are dropped,
    because ``LabelEncoder.transform`` raises ``ValueError`` on unseen labels.

    Returns:
        A ``(label_encoders, val_df)`` tuple: the dict mapping column name to
        its fitted encoder, and the (possibly filtered) validation frame.
    """
    label_encoders = {}
    for col in LABEL_COLUMNS:
        encoder = LabelEncoder()
        train_df[col] = encoder.fit_transform(train_df[col].astype(str))

        val_labels = val_df[col].astype(str)
        known = val_labels.isin(encoder.classes_)
        if not known.all():
            print(
                f"Dropping {int((~known).sum())} validation rows with labels "
                f"unseen in training for column '{col}'"
            )
            val_df = val_df[known].copy()
            val_labels = val_labels[known]
        val_df[col] = encoder.transform(val_labels)  # same encoder as training

        label_encoders[col] = encoder
    return label_encoders, val_df


@router.post("/train")
async def train_model_api(file: UploadFile = File(...)):
    """Train the multi-output BERT model from an uploaded CSV file.

    The upload is parsed as UTF-8 CSV, split 80/20 into training and
    validation sets, label-encoded per column in LABEL_COLUMNS, trained for
    NUM_EPOCHS epochs, evaluated on the validation split, and saved.

    Args:
        file: Uploaded CSV containing every column listed in LABEL_COLUMNS
            (plus whatever columns MultiLabelDataset reads).

    Returns:
        A dict with a ``"message"`` string and ``"metrics"``, the summary
        frame converted to a list of row records.

    Raises:
        HTTPException: 400 if the upload cannot be decoded or parsed as CSV,
            lacks a label column, or has too few rows to split.
    """
    # Imported locally so this block does not depend on edits to the
    # module-level import list.
    from fastapi import HTTPException

    # NOTE(review): the training loop below runs synchronously inside this
    # coroutine, so the event loop is occupied until training finishes.

    # Load CSV data
    contents = await file.read()
    try:
        df = pd.read_csv(StringIO(contents.decode("utf-8")))
    except (
        UnicodeDecodeError,
        pd.errors.EmptyDataError,
        pd.errors.ParserError,
    ) as exc:
        raise HTTPException(
            status_code=400,
            detail=f"Could not parse uploaded file as UTF-8 CSV: {exc}",
        ) from exc

    missing = [col for col in LABEL_COLUMNS if col not in df.columns]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"CSV is missing required label columns: {missing}",
        )
    if len(df) < 2:
        # train_test_split cannot produce two non-empty splits from < 2 rows.
        raise HTTPException(
            status_code=400,
            detail="CSV must contain at least 2 rows to split train/val",
        )

    # Split train/val; copy so the column assignments below act on
    # independent frames rather than on slices of ``df``.
    train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)
    train_df = train_df.copy()
    val_df = val_df.copy()

    # Label encode each label column
    label_encoders, val_df = _encode_label_columns(train_df, val_df)

    # Save encoders
    os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
    joblib.dump(label_encoders, os.path.join(MODEL_SAVE_DIR, "label_encoders.pkl"))

    # Tokenizer
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # Datasets and Loaders
    train_dataset = MultiLabelDataset(train_df, tokenizer)
    val_dataset = MultiLabelDataset(val_df, tokenizer)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=16)

    # Model initialization: one output size per label column
    num_labels = [len(le.classes_) for le in label_encoders.values()]
    model = BertMultiOutputModel(num_labels).to(DEVICE)

    # Optimizer and Loss
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    criterions = initialize_criterions(train_df, label_encoders)

    # Training Loop
    for epoch in range(NUM_EPOCHS):
        train_loss = train_model(model, train_loader, optimizer, criterions, epoch)
        print(f"Epoch {epoch+1} Loss: {train_loss:.4f}")

    # Evaluation (ground truths and predictions are not used here)
    metrics, _truths, _preds = evaluate_model(model, val_loader)
    summary_df = summarize_metrics(metrics)

    # Save model
    save_model(model, model_name="bert_multi_output", save_format="pth")

    # Return summary report
    return {
        "message": "Training complete",
        "metrics": summary_df.to_dict(orient="records"),
    }