# Source: Hugging Face Space by ST-THOMAS-OF-AQUINAS — app.py (commit a05692f, verified)
import os

# 🔹 Ensure the HF cache is writable BEFORE transformers is imported:
# transformers resolves its cache directory from HF_HOME at import time,
# so setting it after the import (as the original code did) has no effect.
os.environ["HF_HOME"] = "/tmp/hf_cache"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.makedirs("/tmp/hf_cache", exist_ok=True)

import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.svm import SVC
import joblib
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List

# 🔹 Device setup: use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 🔹 Load tokenizer & BERT encoder (downloads into the HF cache on first run).
# Any failure here is fatal: the service cannot embed text without the model.
try:
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    # .eval() returns the module itself, so the chain yields the model
    # already moved to `device` and switched to inference mode.
    bert_model = AutoModel.from_pretrained("distilbert-base-uncased").to(device).eval()
except Exception as e:
    raise RuntimeError(f"Failed to load BERT model: {e}")
# 🔹 Load the per-author one-vs-rest SVM classifiers from disk.
MODEL_DIR = "models"
MODEL_FILES = ["Dean of students_svm.pkl", "Registra_svm.pkl"]

author_svms = {}
for fname in MODEL_FILES:
    model_path = os.path.join(MODEL_DIR, fname)
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")
    try:
        # The author name is the filename minus the "_svm.pkl" suffix.
        author_svms[fname.replace("_svm.pkl", "")] = joblib.load(model_path)
    except Exception as e:
        raise RuntimeError(f"Failed to load SVM model {fname}: {e}")
print(f"โœ… Loaded {len(author_svms)} author models from {MODEL_DIR}")
# 🔹 Text embedding function
def embed_text(text: str):
    """Embed `text` with DistilBERT and return a (1, hidden_size) numpy array.

    The text is truncated/padded to at most 256 tokens; the first-token
    ([CLS]-position) vector of the final hidden layer is used as the
    sentence embedding.
    """
    encoded = tokenizer(
        [text],
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        hidden = bert_model(**encoded).last_hidden_state
    # Move back to CPU so downstream scikit-learn models can consume it.
    return hidden[:, 0, :].cpu().numpy()
# 🔹 Prediction function
def predict_author(text: str):
    """Run every per-author SVM on the embedding of `text`.

    Each binary SVM votes 1 (accept) or 0 (reject); a classifier that
    raises is recorded as -1 and logged. Returns the first accepting
    author (dict insertion order breaks ties) or "Unknown" when no
    classifier accepts the text.
    """
    embedding = embed_text(text)
    votes = {}
    for name, model in author_svms.items():
        try:
            votes[name] = model.predict(embedding)[0]
        except Exception as e:
            votes[name] = -1
            print(f"โš ๏ธ Prediction failed for {name}: {e}")
    accepted = [name for name, vote in votes.items() if vote == 1]
    return accepted[0] if accepted else "Unknown"
# ๐Ÿ”น FastAPI app
app = FastAPI(title="Document Verification API")


class TextInput(BaseModel):
    # Request body for POST /predict: a batch of raw document texts.
    texts: List[str]
@app.post("/predict")
def predict(input_data: TextInput):
    """Predict an author for each submitted text.

    Returns {"results": [{"text": ..., "predicted_author": ...}, ...]},
    one entry per input text, in order.
    """
    return {
        "results": [
            {"text": txt, "predicted_author": predict_author(txt)}
            for txt in input_data.texts
        ]
    }
@app.get("/health")
def health_check():
    """Liveness probe: always reports the service as healthy."""
    return dict(status="ok")