"""Document Verification API.

Embeds input text with DistilBERT (CLS-token pooling) and runs one
pre-trained binary SVM per author; an author whose SVM predicts class 1
"accepts" the text. Exposes a FastAPI app with /predict and /health.
"""

import os

# 🔹 Ensure HF cache is writable. These environment variables must be set
# BEFORE `transformers` is imported, otherwise HF_HOME is ignored — the
# original file imported transformers first, defeating its own comment.
os.environ["HF_HOME"] = "/tmp/hf_cache"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.makedirs("/tmp/hf_cache", exist_ok=True)

from typing import List

import joblib
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from sklearn.svm import SVC  # noqa: F401 — documents the expected pickled model type
from transformers import AutoModel, AutoTokenizer

# 🔹 Device setup: prefer GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 🔹 Load tokenizer & BERT model once at startup (inference only).
try:
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    bert_model = AutoModel.from_pretrained("distilbert-base-uncased").to(device)
    bert_model.eval()  # disable dropout for deterministic embeddings
except Exception as e:
    raise RuntimeError(f"Failed to load BERT model: {e}")

# 🔹 Load one binary SVM classifier per author from disk; fail fast on any
# missing or unreadable model file so the service never starts half-loaded.
MODEL_DIR = "models"
MODEL_FILES = ["Dean of students_svm.pkl", "Registra_svm.pkl"]

author_svms = {}
for file in MODEL_FILES:
    path = os.path.join(MODEL_DIR, file)
    if not os.path.exists(path):
        raise FileNotFoundError(f"Model file not found: {path}")
    author = file.replace("_svm.pkl", "")
    try:
        clf = joblib.load(path)
        author_svms[author] = clf
    except Exception as e:
        raise RuntimeError(f"Failed to load SVM model {file}: {e}")

print(f"✅ Loaded {len(author_svms)} author models from {MODEL_DIR}")


def embed_text(text: str):
    """Return the DistilBERT CLS-token embedding of *text*.

    The text is tokenized (truncated/padded to max_length=256), run through
    the frozen encoder, and the first position of the last hidden state is
    taken as the pooled representation. Returns a numpy array of shape
    (1, hidden_size) on CPU, ready for scikit-learn classifiers.
    """
    enc = tokenizer(
        [text],
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,
    )
    enc = {k: v.to(device) for k, v in enc.items()}
    with torch.no_grad():
        outputs = bert_model(**enc)
    # CLS token — first position of the last hidden state.
    pooled = outputs.last_hidden_state[:, 0, :].cpu().numpy()
    return pooled


def predict_author(text: str) -> str:
    """Run every author SVM on *text* and return the accepting author.

    Returns the author whose SVM predicts class 1, or "Unknown" when none
    accept. When several accept, the first in model-load order wins. A
    classifier that raises is logged and treated as a rejection (-1).
    """
    emb = embed_text(text)
    predictions = {}
    for author, clf in author_svms.items():
        try:
            # int() keeps numpy scalar types out of the response payload.
            predictions[author] = int(clf.predict(emb)[0])
        except Exception as e:
            predictions[author] = -1
            print(f"⚠️ Prediction failed for {author}: {e}")

    accepted = [author for author, pred in predictions.items() if pred == 1]
    # The original had two identical branches (len == 1 and len > 1 both
    # returned accepted[0]); collapsed into one truthiness check.
    if accepted:
        return accepted[0]  # pick first if multiple
    return "Unknown"


# 🔹 FastAPI app
app = FastAPI(title="Document Verification API")


class TextInput(BaseModel):
    # Batch of raw document texts to classify.
    texts: List[str]


@app.post("/predict")
def predict(input_data: TextInput):
    """Predict an author for each input text; one result object per text."""
    results = []
    for txt in input_data.texts:
        author = predict_author(txt)
        results.append({"text": txt, "predicted_author": author})
    return {"results": results}


@app.get("/health")
def health_check():
    """Liveness probe."""
    return {"status": "ok"}