"""FastAPI service that labels a piece of text as "YES AI" or "NO AI".

A BERT tokenizer and the project's ``MyBERTClassifier`` are loaded once, at
import time, from the ``jmt-r/predict-ai-abstract`` repository. A single
``POST /predict`` endpoint is exposed.
"""

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import BertTokenizer

from custom_model import MyBERTClassifier

app = FastAPI()

# The model is loaded from the repository
repo_id = "jmt-r/predict-ai-abstract"

# Probability of class index 1 must exceed this to be labelled "YES AI".
# NOTE(review): presumably tuned on validation data for this checkpoint —
# confirm before changing the model.
AI_THRESHOLD = 0.783

# Inputs longer than this many tokens are truncated before inference.
MAX_LENGTH = 512

tokenizer = BertTokenizer.from_pretrained(repo_id)
model = MyBERTClassifier.from_pretrained(repo_id)
model.eval()  # inference mode for layers such as dropout


class PatentRequest(BaseModel):
    """Request body for ``POST /predict``."""

    text: str  # raw text to classify


@app.post("/predict")
def predict(request: PatentRequest):
    """Classify ``request.text`` as YES AI (1) or NO AI (0).

    The text is tokenized (truncated to ``MAX_LENGTH`` tokens) and run through
    the classifier. The softmax probability of class index 1 is then compared
    against ``AI_THRESHOLD``.

    Returns:
        A dict with ``prediction`` (1 or 0), ``label`` ("YES AI" or "NO AI")
        and ``probability`` (the class-1 softmax probability as a float).
    """
    inputs = tokenizer(
        request.text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH,
    )

    # No gradients are needed for inference.
    with torch.no_grad():
        # assumes MyBERTClassifier.forward returns raw logits shaped
        # (batch, num_classes) rather than a ModelOutput — TODO confirm
        logits = model(**inputs)

    probs = torch.softmax(logits, dim=1)
    ai_prob = probs[0, 1].item()  # single request -> batch row 0, class 1

    # Logic for YES AI (1) or NO AI (0) prediction
    prediction = 1 if ai_prob > AI_THRESHOLD else 0
    return {
        "prediction": prediction,
        "label": "YES AI" if prediction == 1 else "NO AI",
        "probability": ai_prob,
    }