# NOTE(review): the lines below are Hugging Face Spaces page residue (status
# text, file size, commit hashes) accidentally captured with the source; they
# are not Python, so they are commented out to keep the file importable.
# Spaces:
# Runtime error
# Runtime error
# File size: 2,767 Bytes
# 662dfff 8e32a8a 97f71f2 fa70caf 97f71f2 662dfff a05692f 47eba64 d89415b 47eba64 8af630e a05692f
import os

# 🔹 HF cache env vars must be set BEFORE transformers/huggingface_hub are
# imported, otherwise they are read too late and the default cache location
# is used. The original file set them after the import, defeating the purpose.
os.environ["HF_HOME"] = "/tmp/hf_cache"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.makedirs("/tmp/hf_cache", exist_ok=True)

import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.svm import SVC
import joblib
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List

# 🔹 Device setup: prefer GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 🔹 Load tokenizer & DistilBERT encoder; fail fast with a clear error if
# the download or deserialization goes wrong.
_BERT_NAME = "distilbert-base-uncased"
try:
    tokenizer = AutoTokenizer.from_pretrained(_BERT_NAME)
    bert_model = AutoModel.from_pretrained(_BERT_NAME).to(device)
    bert_model.eval()  # inference only — disable dropout etc.
except Exception as e:
    raise RuntimeError(f"Failed to load BERT model: {e}")
# 🔹 Load one per-author binary SVM classifier from disk.
# Each file is named "<Author>_svm.pkl"; the author name is derived from the
# filename. Missing files and unloadable pickles abort startup immediately.
MODEL_DIR = "models"
MODEL_FILES = ["Dean of students_svm.pkl", "Registra_svm.pkl"]
author_svms = {}
for file in MODEL_FILES:
    path = os.path.join(MODEL_DIR, file)
    if not os.path.exists(path):
        raise FileNotFoundError(f"Model file not found: {path}")
    author = file.replace("_svm.pkl", "")
    try:
        clf = joblib.load(path)
        author_svms[author] = clf
    except Exception as e:
        raise RuntimeError(f"Failed to load SVM model {file}: {e}")
# Fix: the original success message contained a mojibake emoji with an
# accidental newline embedded inside the f-string literal.
print(f"✅ Loaded {len(author_svms)} author models from {MODEL_DIR}")
# ๐น Text embedding function
def embed_text(text: str):
    """Embed one text as its DistilBERT first-token ([CLS]) vector.

    Returns a numpy array of shape (1, hidden_size) on the CPU.
    """
    encoded = tokenizer(
        [text],
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,
    )
    # Move every input tensor to the model's device.
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        hidden = bert_model(**encoded).last_hidden_state
    # First token position == [CLS]; detach to CPU numpy for sklearn.
    return hidden[:, 0, :].cpu().numpy()
# ๐น Prediction function
def predict_author(text: str):
emb = embed_text(text)
predictions = {}
for author, clf in author_svms.items():
try:
predictions[author] = clf.predict(emb)[0]
except Exception as e:
predictions[author] = -1
print(f"โ ๏ธ Prediction failed for {author}: {e}")
accepted = [author for author, pred in predictions.items() if pred == 1]
if len(accepted) == 1:
return accepted[0]
elif len(accepted) > 1:
return accepted[0] # pick first if multiple
else:
return "Unknown"
# 🔹 FastAPI app exposing the document-verification endpoints below.
app = FastAPI(title="Document Verification API")
class TextInput(BaseModel):
    # Batch of raw document texts to classify, one prediction per entry.
    texts: List[str]
@app.post("/predict")
def predict(input_data: TextInput):
    """Predict an author for each submitted text and echo the text back."""
    results = [
        {"text": txt, "predicted_author": predict_author(txt)}
        for txt in input_data.texts
    ]
    return {"results": results}
@app.get("/health")
def health_check():
    """Liveness probe: always reports the service as up."""
    return {"status": "ok"}