# ganeshko / main.py
# Uploaded by ganeshkonapalli ("Upload 5 files", commit b900e95, verified).
from fastapi import FastAPI
from pydantic import BaseModel
import torch
from app.model_utils import train_and_save_model, load_model, LABEL_COLUMNS
app = FastAPI()

# Run on the GPU when torch can see one, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Maximum number of tokens the tokenizer keeps for a /predict request.
MAX_LEN = 128

# Loaded once at import time and shared by every request.
model, tokenizer, label_encoders = load_model()
# /predict moves its input tensors to DEVICE, so the weights must live there too;
# eval() puts layers such as dropout/batch-norm into inference mode. Both calls
# are no-ops if load_model() already did this.
# NOTE(review): assumes load_model() returns a torch.nn.Module — confirm.
model.to(DEVICE)
model.eval()
class PredictRequest(BaseModel):
    # Request body for POST /predict.
    # text: the raw input string to classify.
    text: str
class TrainRequest(BaseModel):
    # Request body for POST /train.
    # csv_path: location of the training CSV, handed unchanged to
    # train_and_save_model().
    csv_path: str
@app.post("/predict")
def predict(req: PredictRequest):
    """Classify ``req.text`` and return one decoded label per label column.

    The text is tokenized (truncated to ``MAX_LEN`` tokens), moved to
    ``DEVICE`` and run through the model with gradient tracking disabled.
    For each model output, the arg-max class index is mapped back to its
    original label with the encoder stored under the matching
    ``LABEL_COLUMNS`` name.

    Returns:
        ``{"text": <input text>, "predictions": {<label column>: <label>, ...}}``
    """
    inputs = tokenizer(
        req.text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_LEN,
    ).to(DEVICE)

    with torch.no_grad():
        outputs = model(inputs["input_ids"], inputs["attention_mask"])

    predictions = {}
    # Output i is decoded with LABEL_COLUMNS[i]; an extra output raises
    # IndexError rather than being silently dropped.
    # NOTE(review): assumes each output is (1, num_classes) logits — confirm.
    for i, logits in enumerate(outputs):
        column = LABEL_COLUMNS[i]
        class_index = torch.argmax(logits, dim=1).item()
        label = label_encoders[column].inverse_transform([class_index])[0]
        # inverse_transform returns an array, so `label` may be a numpy scalar.
        # Non-str ones (e.g. numpy.int64) are not JSON-serializable by FastAPI,
        # so convert them to plain Python values. Native Python values have no
        # .item() and pass through unchanged.
        if hasattr(label, "item"):
            label = label.item()
        predictions[column] = label

    return {"text": req.text, "predictions": predictions}
@app.post("/train")
def train_model(req: TrainRequest):
    """Retrain from ``req.csv_path``, save the artifacts, and start serving them.

    Without the reload below, /predict would keep using the model loaded at
    startup until the process restarted.

    Returns:
        ``{"message": "Model trained and saved successfully"}``
    """
    # NOTE(review): csv_path comes straight from the request body, so any
    # caller can make the server read any path (or URL) the training code
    # accepts. Restrict it to an allow-listed data location before exposing
    # this endpoint beyond trusted users.
    global model, tokenizer, label_encoders

    train_and_save_model(req.csv_path)

    # Assumes load_model() reads the artifacts train_and_save_model() just
    # wrote — TODO confirm. Prepare the new model fully before swapping it in,
    # so /predict never sees weights that are off-device or still in train mode.
    new_model, new_tokenizer, new_label_encoders = load_model()
    new_model.to(DEVICE)
    new_model.eval()
    model, tokenizer, label_encoders = new_model, new_tokenizer, new_label_encoders

    return {"message": "Model trained and saved successfully"}