Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from peft import PeftModel | |
| # ----------------------------- | |
| # Config | |
| # ----------------------------- | |
# Identifiers handed to the ``from_pretrained`` loaders: the base checkpoint
# and the LoRA adapter that is applied on top of it.
BASE_MODEL = "distilbert-base-uncased"
LORA_MODEL_PATH = "mjpsm/coca-cola-contact-classifier"

# Upper bound given to the tokenizer as ``max_length`` (truncation is enabled).
MAX_LENGTH = 128

# Class index -> label name reported in the API response.
id2label = dict(enumerate(("not_relevant", "relevant")))

# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
| # ----------------------------- | |
| # Load model + tokenizer | |
| # ----------------------------- | |
# The tokenizer is loaded from LORA_MODEL_PATH, not from BASE_MODEL.
tokenizer = AutoTokenizer.from_pretrained(LORA_MODEL_PATH)

# Base sequence classifier with a two-class head.
base_model = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL, num_labels=2)

# Attach the adapter to the base model and place the result on the selected
# device; ``Module.to`` returns the module itself, so rebinding is safe.
model = PeftModel.from_pretrained(base_model, LORA_MODEL_PATH).to(device)

# Evaluation mode for serving.
model.eval()
| # ----------------------------- | |
| # FastAPI app | |
| # ----------------------------- | |
# Application object; the keyword arguments below are the metadata passed to
# FastAPI as title, description and version.
app = FastAPI(
    version="1.0.0",
    title="Coca-Cola Contact Form Classifier",
    description="LoRA-based text classification API",
)
| # ----------------------------- | |
| # Request schema | |
| # ----------------------------- | |
class PredictionRequest(BaseModel):
    # Request body accepted by the prediction function.
    #
    # text: the raw string that is handed to the tokenizer. No length or
    # content constraints are declared on the field.
    text: str
| # ----------------------------- | |
| # Prediction endpoint | |
| # ----------------------------- | |
# NOTE(review): the route path is chosen to match the function name; confirm
# it against whatever clients call this service.
@app.post("/predict")
def predict(request: PredictionRequest):
    """Classify one piece of text and report the winning label.

    The function is registered on ``app`` as ``POST /predict``. Without a
    route decorator it is only a plain function and cannot be reached over
    HTTP.

    Args:
        request: Parsed request body; only ``request.text`` is read.

    Returns:
        A dict with two keys: ``"prediction"``, the ``id2label`` name of the
        highest-probability class, and ``"confidence"``, that class's softmax
        probability rounded to 4 decimal places.
    """
    # Tokenize the single input string, truncating to MAX_LENGTH tokens, and
    # move the resulting tensors to the same device as the model.
    encoded = tokenizer(
        request.text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_LENGTH,
    ).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**encoded)
        probs = F.softmax(outputs.logits, dim=1)

    # Highest class probability and its index along the class dimension.
    confidence, pred_id = torch.max(probs, dim=1)

    return {
        "prediction": id2label[pred_id.item()],
        "confidence": round(confidence.item(), 4),
    }
| # ----------------------------- | |
| # Health check | |
| # ----------------------------- | |
# NOTE(review): the route path is chosen to match the function name; confirm
# it against whatever monitor or platform probes this service.
@app.get("/health")
def health():
    """Health check: return a constant status payload.

    The function is registered on ``app`` as ``GET /health``. Without a route
    decorator it cannot be reached over HTTP. It does not inspect the model
    or tokenizer; it only returns a fixed dict.

    Returns:
        ``{"status": "ok"}``.
    """
    return {"status": "ok"}