Spaces:
Sleeping
Sleeping
Delete train.py
Browse files
train.py
DELETED
|
@@ -1,81 +0,0 @@
|
|
| 1 |
-
# app/train_api.py
|
| 2 |
-
|
| 3 |
-
from fastapi import APIRouter, UploadFile, File
|
| 4 |
-
import pandas as pd
|
| 5 |
-
import torch
|
| 6 |
-
from io import StringIO
|
| 7 |
-
import os
|
| 8 |
-
import joblib
|
| 9 |
-
|
| 10 |
-
from app.config import (
|
| 11 |
-
DEVICE, LABEL_COLUMNS, MODEL_SAVE_DIR,
|
| 12 |
-
NUM_EPOCHS, LEARNING_RATE
|
| 13 |
-
)
|
| 14 |
-
from app.models import BertMultiOutputModel # Your model class
|
| 15 |
-
from app.dataset import MultiLabelDataset # Your dataset class
|
| 16 |
-
from app.train_utils import (
|
| 17 |
-
initialize_criterions, train_model,
|
| 18 |
-
evaluate_model, summarize_metrics,
|
| 19 |
-
save_model
|
| 20 |
-
)
|
| 21 |
-
from transformers import BertTokenizer
|
| 22 |
-
from torch.utils.data import DataLoader
|
| 23 |
-
from sklearn.preprocessing import LabelEncoder
|
| 24 |
-
from sklearn.model_selection import train_test_split
|
| 25 |
-
|
| 26 |
-
# Router collecting the training endpoint(s); mounted by the main FastAPI app.
router = APIRouter()
|
| 27 |
-
|
| 28 |
-
@router.post("/train")
async def train_model_api(file: UploadFile = File(...)):
    """Train the multi-output BERT classifier from an uploaded CSV.

    The CSV must contain a text column expected by ``MultiLabelDataset`` plus
    every column named in ``LABEL_COLUMNS``. The data is split 80/20 into
    train/validation, each label column is label-encoded, the model is trained
    for ``NUM_EPOCHS`` epochs, evaluated, and saved to ``MODEL_SAVE_DIR``.

    Args:
        file: Uploaded CSV file (UTF-8 encoded).

    Returns:
        dict with a completion message and per-label evaluation metrics
        (one record per metric row of the summary DataFrame).
    """
    # Load CSV data from the uploaded file.
    contents = await file.read()
    df = pd.read_csv(StringIO(contents.decode("utf-8")))

    # Split train/val. Explicit copies avoid pandas SettingWithCopyWarning
    # when the encoded label columns are assigned below.
    train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)
    train_df = train_df.copy()
    val_df = val_df.copy()

    # Label-encode each label column. Fit on the FULL dataset so labels that
    # appear only in the validation split do not make le.transform() raise
    # ValueError ("y contains previously unseen labels"), which the original
    # train-only fit was vulnerable to.
    label_encoders = {}
    for col in LABEL_COLUMNS:
        le = LabelEncoder()
        le.fit(df[col].astype(str))
        train_df[col] = le.transform(train_df[col].astype(str))
        val_df[col] = le.transform(val_df[col].astype(str))
        label_encoders[col] = le

    # Persist the encoders so inference can map predicted ids back to labels.
    os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
    joblib.dump(label_encoders, os.path.join(MODEL_SAVE_DIR, "label_encoders.pkl"))

    # Tokenizer
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # Datasets and loaders
    train_dataset = MultiLabelDataset(train_df, tokenizer)
    val_dataset = MultiLabelDataset(val_df, tokenizer)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=16)

    # Model: one output head per label column, sized by its class count.
    # Indexed via LABEL_COLUMNS explicitly so head order matches the columns.
    num_labels = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
    model = BertMultiOutputModel(num_labels).to(DEVICE)

    # Optimizer and per-head loss functions.
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    criterions = initialize_criterions(train_df, label_encoders)

    # Training loop.
    # NOTE(review): this synchronous training blocks the event loop for the
    # whole run; consider a background task or thread pool for production.
    for epoch in range(NUM_EPOCHS):
        train_loss = train_model(model, train_loader, optimizer, criterions, epoch)
        print(f"Epoch {epoch+1} Loss: {train_loss:.4f}")

    # Evaluation (only the metric summary is returned; raw truths/preds unused).
    metrics, _truths, _preds = evaluate_model(model, val_loader)
    summary_df = summarize_metrics(metrics)

    # Save model weights.
    save_model(model, model_name="bert_multi_output", save_format="pth")

    # Return summary report.
    return {
        "message": "Training complete",
        "metrics": summary_df.to_dict(orient="records"),
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|