File size: 2,767 Bytes
662dfff
8e32a8a
97f71f2
 
fa70caf
97f71f2
 
 
662dfff
a05692f
 
 
 
 
47eba64
d89415b
47eba64
8af630e
a05692f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os

# 🔹 Ensure the Hugging Face cache is writable.
# IMPORTANT: HF_HOME must be exported *before* `transformers` is imported —
# the library resolves its cache directory at import time, so setting it
# afterwards (as the original code did) silently left the default cache in
# use. The env setup therefore precedes the transformers import below.
os.environ["HF_HOME"] = "/tmp/hf_cache"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.makedirs("/tmp/hf_cache", exist_ok=True)

import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.svm import SVC
import joblib
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List

# 🔹 Device setup: prefer GPU when available, otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ๐Ÿ”น Load tokenizer & BERT model
try:
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    bert_model = AutoModel.from_pretrained("distilbert-base-uncased").to(device)
    bert_model.eval()
except Exception as e:
    raise RuntimeError(f"Failed to load BERT model: {e}")

# ๐Ÿ”น Load SVM models
MODEL_DIR = "models"
MODEL_FILES = ["Dean of students_svm.pkl", "Registra_svm.pkl"]
author_svms = {}

for file in MODEL_FILES:
    path = os.path.join(MODEL_DIR, file)
    if not os.path.exists(path):
        raise FileNotFoundError(f"Model file not found: {path}")
    author = file.replace("_svm.pkl", "")
    try:
        clf = joblib.load(path)
        author_svms[author] = clf
    except Exception as e:
        raise RuntimeError(f"Failed to load SVM model {file}: {e}")

print(f"โœ… Loaded {len(author_svms)} author models from {MODEL_DIR}")

# ๐Ÿ”น Text embedding function
def embed_text(text: str):
    enc = tokenizer(
        [text], return_tensors="pt", truncation=True, padding=True, max_length=256
    )
    enc = {k: v.to(device) for k, v in enc.items()}
    with torch.no_grad():
        outputs = bert_model(**enc)
    pooled = outputs.last_hidden_state[:, 0, :].cpu().numpy()  # CLS token
    return pooled

# ๐Ÿ”น Prediction function
def predict_author(text: str):
    emb = embed_text(text)
    predictions = {}
    for author, clf in author_svms.items():
        try:
            predictions[author] = clf.predict(emb)[0]
        except Exception as e:
            predictions[author] = -1
            print(f"โš ๏ธ Prediction failed for {author}: {e}")

    accepted = [author for author, pred in predictions.items() if pred == 1]
    if len(accepted) == 1:
        return accepted[0]
    elif len(accepted) > 1:
        return accepted[0]  # pick first if multiple
    else:
        return "Unknown"

# ๐Ÿ”น FastAPI app
app = FastAPI(title="Document Verification API")

class TextInput(BaseModel):
    texts: List[str]

@app.post("/predict")
def predict(input_data: TextInput):
    """Classify each submitted text and report its predicted author.

    Returns {"results": [{"text": ..., "predicted_author": ...}, ...]} in
    the same order as the input texts.
    """
    return {
        "results": [
            {"text": txt, "predicted_author": predict_author(txt)}
            for txt in input_data.texts
        ]
    }

# Liveness endpoint: returns a static OK payload once the process is serving.
@app.get("/health")
def health_check():
    """Return a constant status payload for health probes."""
    return {"status": "ok"}