"""FastAPI service exposing a LoRA-adapted DistilBERT text classifier.

The tokenizer and model are loaded once at import time; ``POST /predict``
classifies a single piece of text and ``GET /`` is a health check.
"""

from fastapi import FastAPI
from pydantic import BaseModel
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftModel

# -----------------------------
# Config
# -----------------------------
BASE_MODEL = "distilbert-base-uncased"
LORA_MODEL_PATH = "mjpsm/coca-cola-contact-classifier"
# Inputs longer than this many tokens are truncated by the tokenizer call.
MAX_LENGTH = 128

# Maps the predicted class index to the label string returned to clients.
id2label = {0: "not_relevant", 1: "relevant"}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------------
# Load model + tokenizer
# -----------------------------
# The tokenizer is loaded from the adapter path, the weights from the base
# model; the adapter is then applied on top of the base model.
tokenizer = AutoTokenizer.from_pretrained(LORA_MODEL_PATH)

base_model = AutoModelForSequenceClassification.from_pretrained(
    BASE_MODEL,
    num_labels=2
)

model = PeftModel.from_pretrained(base_model, LORA_MODEL_PATH)
model.to(device)
# Switch to evaluation mode before serving predictions.
model.eval()

# -----------------------------
# FastAPI app
# -----------------------------
app = FastAPI(
    title="Coca-Cola Contact Form Classifier",
    description="LoRA-based text classification API",
    version="1.0.0"
)


# -----------------------------
# Request schema
# -----------------------------
class PredictionRequest(BaseModel):
    """Request body for ``POST /predict``."""

    # Raw text to classify.
    text: str


# -----------------------------
# Prediction endpoint
# -----------------------------
@app.post("/predict")
def predict(request: PredictionRequest):
    """Classify ``request.text`` and report the winning label.

    Args:
        request: Parsed request body carrying the text to classify.

    Returns:
        A dict with ``"prediction"`` (a label string from ``id2label``) and
        ``"confidence"`` (the softmax probability of that label, rounded to
        four decimal places).
    """
    inputs = tokenizer(
        request.text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_LENGTH
    ).to(device)

    # Gradients are not needed for a forward-only prediction.
    with torch.no_grad():
        outputs = model(**inputs)
        probs = F.softmax(outputs.logits, dim=1)
        # torch.max over dim=1 yields both the top probability and its index.
        confidence, pred_id = torch.max(probs, dim=1)

    return {
        "prediction": id2label[pred_id.item()],
        "confidence": round(confidence.item(), 4)
    }


# -----------------------------
# Health check
# -----------------------------
@app.get("/")
def health():
    """Return a fixed status payload."""
    return {"status": "ok"}