Spaces:
Build error
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| import torch | |
| from app.model_utils import train_and_save_model, load_model, LABEL_COLUMNS | |
app = FastAPI()

# Load the persisted model artifacts once at import time so every request
# reuses the same objects instead of reloading them per call.
model, tokenizer, label_encoders = load_model()

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Passed to the tokenizer as max_length; longer texts are truncated to this.
MAX_LEN = 128

# Request inputs are moved to DEVICE before inference, so the model must live
# on the same device, and inference should run in evaluation mode (dropout /
# batch-norm disabled) to give deterministic predictions. It is not visible
# here whether load_model() already does either; both calls are no-ops if it
# does. The isinstance guard keeps this safe if load_model() returns something
# other than a torch module.
if isinstance(model, torch.nn.Module):
    model.to(DEVICE)
    model.eval()
class PredictRequest(BaseModel):
    """Input payload for ``predict``: the raw text to classify."""

    text: str
class TrainRequest(BaseModel):
    """Input payload for ``train_model``.

    Attributes:
        csv_path: Location of the training CSV, handed unchanged to
            ``train_and_save_model`` (presumably a server-side filesystem
            path — confirm against app.model_utils).
    """

    csv_path: str
@app.post("/predict")
def predict(req: PredictRequest):
    """Classify ``req.text`` for every column in ``LABEL_COLUMNS``.

    The text is tokenized (truncated to ``MAX_LEN`` tokens), moved to
    ``DEVICE`` and run through ``model`` without gradient tracking. The i-th
    element of the model output is treated as the logits for
    ``LABEL_COLUMNS[i]``; its argmax class id is decoded back to a label with
    that column's encoder.

    Args:
        req: Request body carrying the raw ``text``.

    Returns:
        ``{"text": <input text>, "predictions": {column: decoded_label}}``.

    Raises:
        IndexError: if the model returns more outputs than there are entries
            in ``LABEL_COLUMNS``.
    """
    inputs = tokenizer(
        req.text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_LEN,
    ).to(DEVICE)
    with torch.no_grad():
        outputs = model(inputs["input_ids"], inputs["attention_mask"])

    predictions = {}
    for i, logits in enumerate(outputs):
        column = LABEL_COLUMNS[i]
        # One text per request, so argmax over dim=1 yields a single class id.
        class_index = torch.argmax(logits, dim=1).item()
        decoded = label_encoders[column].inverse_transform([class_index])[0]
        # inverse_transform presumably returns a NumPy array (e.g. sklearn's
        # LabelEncoder); a NumPy scalar such as numpy.int64 is not
        # serializable by FastAPI's default JSON encoder, so convert it to the
        # equivalent native Python value when possible.
        if hasattr(decoded, "item"):
            decoded = decoded.item()
        predictions[column] = decoded
    return {"text": req.text, "predictions": predictions}
@app.post("/train")
def train_model(req: TrainRequest):
    """Retrain the model from ``req.csv_path`` and start serving the result.

    Calls ``train_and_save_model`` with the requested CSV location, then
    reloads the module-level ``model``, ``tokenizer`` and ``label_encoders``
    via ``load_model`` so later predictions use the newly trained artifacts
    instead of the ones loaded at import time.

    Args:
        req: Request body carrying ``csv_path``.

    Returns:
        ``{"message": "Model trained and saved successfully"}``.
    """
    global model, tokenizer, label_encoders

    # NOTE(review): csv_path comes straight from the request body, so any
    # client can make the server read an arbitrary path. Restrict it to an
    # allow-listed data directory before exposing this endpoint publicly.
    train_and_save_model(req.csv_path)

    # Without this reload the in-memory model stays stale until a restart.
    # Assumes load_model() reads what train_and_save_model() just wrote —
    # TODO confirm against app.model_utils.
    model, tokenizer, label_encoders = load_model()
    # Keep the reloaded model on the inference device and in eval mode; both
    # calls are no-ops if load_model() already handles them.
    if isinstance(model, torch.nn.Module):
        model.to(DEVICE)
        model.eval()
    return {"message": "Model trained and saved successfully"}