Spaces:
Runtime error
Runtime error
# Ensure the HF cache env vars are set BEFORE importing transformers:
# transformers reads HF_HOME when it is imported, so setting it afterwards
# (as the original code did) has no effect on where models are cached.
import os

os.environ["HF_HOME"] = "/tmp/hf_cache"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.makedirs("/tmp/hf_cache", exist_ok=True)

import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.svm import SVC
import joblib
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the tokenizer and DistilBERT encoder used to embed documents.
try:
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    bert_model = AutoModel.from_pretrained("distilbert-base-uncased").to(device)
    bert_model.eval()  # inference only — disable dropout/batch-norm updates
except Exception as e:
    # Chain the original exception so the root cause stays in the traceback.
    raise RuntimeError(f"Failed to load BERT model: {e}") from e
# Load one binary SVM per author; each classifier outputs 1 (accept)
# for texts it attributes to that author.
MODEL_DIR = "models"
MODEL_FILES = ["Dean of students_svm.pkl", "Registra_svm.pkl"]

author_svms = {}
for file in MODEL_FILES:
    path = os.path.join(MODEL_DIR, file)
    if not os.path.exists(path):
        raise FileNotFoundError(f"Model file not found: {path}")
    author = file.replace("_svm.pkl", "")
    try:
        author_svms[author] = joblib.load(path)
    except Exception as e:
        # Chain the cause so joblib/pickle errors remain diagnosable.
        raise RuntimeError(f"Failed to load SVM model {file}: {e}") from e

# Original string contained a mis-encoded check-mark glyph; use plain ASCII.
print(f"Loaded {len(author_svms)} author models from {MODEL_DIR}")
def embed_text(text: str):
    """Return the DistilBERT [CLS] embedding of *text* as a (1, hidden) numpy array."""
    batch = tokenizer(
        [text],
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,
    )
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    with torch.no_grad():
        hidden = bert_model(**batch).last_hidden_state
    # The first token position ([CLS]) is used as the sequence summary.
    return hidden[:, 0, :].cpu().numpy()
def predict_author(text: str):
    """Predict which author produced *text*.

    Embeds the text once, runs every per-author SVM on the embedding, and
    returns the first author whose classifier outputs 1 — or "Unknown" when
    no classifier accepts the text.
    """
    emb = embed_text(text)
    predictions = {}
    for author, clf in author_svms.items():
        try:
            predictions[author] = clf.predict(emb)[0]
        except Exception as e:
            # Best-effort: a failing model counts as a rejection (-1), the
            # other authors' models are still consulted.
            predictions[author] = -1
            print(f"WARNING: Prediction failed for {author}: {e}")
    accepted = [author for author, pred in predictions.items() if pred == 1]
    # The original `len == 1` and `len > 1` branches both returned
    # accepted[0]; collapsed into a single expression. When several models
    # accept, the first wins (dicts preserve insertion order).
    return accepted[0] if accepted else "Unknown"
# FastAPI application that serves the verification endpoints.
app = FastAPI(title="Document Verification API")


class TextInput(BaseModel):
    """Request body: a batch of document texts to classify."""

    texts: List[str]
@app.post("/predict")
def predict(input_data: TextInput):
    """Classify each submitted text and report its predicted author.

    Returns {"results": [{"text": ..., "predicted_author": ...}, ...]}.
    NOTE(review): the original function was never registered on the app —
    without a route decorator FastAPI exposed no /predict endpoint.
    The exact route path is an assumption to confirm against clients.
    """
    results = [
        {"text": txt, "predicted_author": predict_author(txt)}
        for txt in input_data.texts
    ]
    return {"results": results}
@app.get("/health")
def health_check():
    """Liveness probe: returns {"status": "ok"} when the service is up.

    NOTE(review): originally missing the route decorator, so the endpoint
    was never exposed; /health is an assumed path — confirm with deployers.
    """
    return {"status": "ok"}