"""Skill Classification API.

Serves a fine-tuned DistilBERT sequence classifier over HTTP: POST a
student check-in to ``/predict`` and receive the predicted skill label
together with the softmax probability assigned to that label.
"""

import logging

import torch
import torch.nn.functional as F
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast

# Status messages go through logging rather than bare print(). basicConfig is
# a no-op if the hosting server has already attached handlers to the root logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# =========================
# INIT
# =========================
app = FastAPI(
    title="Skill Classification API",
    description="Predicts skill from student check-ins",
    version="1.0",
)

MODEL_PATH = "mjpsm/skill-classifier-BERT-v1"
# Maximum number of tokens fed to the model; longer inputs are truncated.
MAX_LENGTH = 128

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Loaded once at import time so every request reuses the same tokenizer/weights.
logger.info("Loading model %s on %s...", MODEL_PATH, device)
tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_PATH)
model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
model.to(device)
model.eval()  # inference mode for layers such as dropout
logger.info("Model loaded.")


# =========================
# INPUT SCHEMA
# =========================
class InputText(BaseModel):
    """Request body for ``POST /predict``."""

    text: str  # free-text student check-in to classify


# =========================
# ROOT
# =========================
@app.get("/")
def home():
    """Liveness endpoint: return a static message confirming the app is up."""
    return {"message": "Skill Classification API is running"}


# =========================
# PREDICT
# =========================
@app.post("/predict")
def predict(input: InputText):  # noqa: A002 - shadows builtin; name kept for compatibility
    """Classify a single check-in and return the top label.

    Declared with plain ``def`` (not ``async def``) so FastAPI runs the
    blocking tokenizer/model call in its threadpool instead of the event loop.

    Returns:
        ``{"prediction": <label from model.config.id2label>,
        "confidence": <softmax probability of that label>}``
    """
    encoded = tokenizer(
        input.text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_LENGTH,
    )
    # Move every tokenizer output tensor onto the same device as the model.
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    # inference_mode is a stricter, cheaper replacement for no_grad when no
    # autograd is ever needed on the results.
    with torch.inference_mode():
        logits = model(**encoded).logits

    probs = F.softmax(logits, dim=-1)
    # Batch holds exactly one text, so row 0 is the only row.
    pred = probs[0].argmax().item()
    return {
        "prediction": model.config.id2label[pred],
        "confidence": probs[0, pred].item(),
    }