ST-THOMAS-OF-AQUINAS committed on
Commit
8e32a8a
·
verified ·
1 Parent(s): 613a048

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -0
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import joblib
3
+ import torch
4
+ from fastapi import FastAPI
5
+ from pydantic import BaseModel
6
+ from transformers import DistilBertTokenizerFast, DistilBertModel
7
+
8
# 🔹 Config: directory that holds the per-author "<author>_svm.pkl" files
MODEL_DIR = "svm_models"

# 🔹 FastAPI app
app = FastAPI(title="Author Identification API")

# 🔹 Load tokenizer & BERT
# Use the GPU when torch can see one, otherwise stay on the CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
bert_model = DistilBertModel.from_pretrained("distilbert-base-uncased")
bert_model = bert_model.to(device)
# Put the encoder in evaluation mode; it is only used for inference here.
bert_model.eval()
19
+
20
# 🔹 Load SVM models
# Every "<author>_svm.pkl" file in MODEL_DIR becomes one entry of
# author_svms, keyed by the author name taken from the file name.
# NOTE(security): joblib.load unpickles the file, which can execute arbitrary
# code, so MODEL_DIR must only ever contain trusted model files.
_SVM_SUFFIX = "_svm.pkl"
author_svms = {}
# os.listdir returns entries in arbitrary order; sorting keeps the dict order
# (and with it any order-dependent tie-breaking between authors) stable.
for file in sorted(os.listdir(MODEL_DIR)):
    if file.endswith(_SVM_SUFFIX):
        # Strip only the trailing suffix; str.replace would also remove the
        # marker if it happened to occur earlier in the file name.
        author = file[: -len(_SVM_SUFFIX)]
        clf = joblib.load(os.path.join(MODEL_DIR, file))
        author_svms[author] = clf
print(f"✅ Loaded {len(author_svms)} author models")
28
+
29
+ # 🔹 Embedding function
30
def embed_text(text):
    """Embed a single text with the module-level DistilBERT encoder.

    The text is tokenized as a one-item batch (truncated to at most 256
    tokens), run through ``bert_model`` without gradient tracking, and the
    last-layer hidden state at sequence position 0 is returned.

    Args:
        text: The text to embed.

    Returns:
        A NumPy array with the position-0 hidden state of the one-item
        batch, i.e. its first axis has length 1.
    """
    batch = tokenizer(
        [text],
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,
    )

    # Move every tokenizer output tensor to the device the model was moved to.
    model_inputs = {}
    for name, tensor in batch.items():
        model_inputs[name] = tensor.to(device)

    with torch.no_grad():
        hidden_states = bert_model(**model_inputs).last_hidden_state

    # Keep only the first token's vector for each sequence in the batch.
    first_token = hidden_states[:, 0, :]
    return first_token.cpu().numpy()
37
+
38
+ # 🔹 Request schema
39
# Request body schema for POST /predict.
class InputText(BaseModel):
    # Raw text whose author should be identified.
    text: str
41
+
42
+ # 🔹 API route
43
@app.post("/predict")
def predict_author(input: InputText):
    """Identify the most likely author of the submitted text.

    The text is embedded once and then checked against every per-author
    classifier in ``author_svms``. Only authors whose classifier predicts
    label ``1`` for the embedding are candidates; among them, the one with
    the highest decision-function value wins.

    Args:
        input: Request body carrying the text to classify.

    Returns:
        ``{"author": <name>, "score": <decision value rounded to 4 places>}``
        for the winning author, or ``{"author": "Unknown", "score": None}``
        when no classifier accepts the text (which includes the case where
        no models are loaded).
    """
    emb = embed_text(input.text)

    # Collect only the authors whose classifier accepts the text. Rejected
    # authors are left out instead of being stored with a magic score such
    # as -9999, so a placeholder can never compete with (or be mistaken
    # for) a real decision value, and decision_function is not evaluated
    # for classifiers that already rejected the text.
    accepted = {}
    for author, clf in author_svms.items():
        if clf.predict(emb)[0] == 1:
            accepted[author] = float(clf.decision_function(emb)[0])

    if not accepted:
        return {"author": "Unknown", "score": None}

    best_author = max(accepted, key=accepted.get)
    return {"author": best_author, "score": round(accepted[best_author], 4)}