ST-THOMAS-OF-AQUINAS committed on
Commit
fa70caf
·
verified ·
1 Parent(s): 8e32a8a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -46
app.py CHANGED
@@ -1,58 +1,54 @@
1
import os
import joblib
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import DistilBertTokenizerFast, DistilBertModel

# 🔹 Config: directory holding one pickled SVM per author,
# each file named "<author>_svm.pkl".
MODEL_DIR = "svm_models"

# 🔹 FastAPI app
app = FastAPI(title="Author Identification API")

# 🔹 Load tokenizer & BERT encoder used to embed input text.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
bert_model = DistilBertModel.from_pretrained("distilbert-base-uncased").to(device)
bert_model.eval()  # inference only — disables dropout

# 🔹 Load SVM models: one binary classifier per author, keyed by author name.
author_svms = {}
for file in os.listdir(MODEL_DIR):
    if file.endswith("_svm.pkl"):
        author = file.replace("_svm.pkl", "")
        clf = joblib.load(os.path.join(MODEL_DIR, file))
        author_svms[author] = clf
print(f"✅ Loaded {len(author_svms)} author models")
29
# 🔹 Embedding function
def embed_text(text):
    """Return the [CLS]-token DistilBERT embedding of *text* as a (1, hidden) numpy array."""
    batch = tokenizer([text], return_tensors="pt", truncation=True, padding=True, max_length=256)
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    with torch.no_grad():
        encoded = bert_model(**batch)
    # First token position holds the [CLS] representation.
    return encoded.last_hidden_state[:, 0, :].cpu().numpy()
37
-
38
# 🔹 Request schema
class InputText(BaseModel):
    """Request body for /predict: the raw text whose author should be identified."""
    # text: the document text to classify
    text: str
41
-
42
# 🔹 API route
@app.post("/predict")
def predict_author(input: InputText):
    """Score the text against every per-author SVM and return the best accepted match."""
    emb = embed_text(input.text)

    scores = {}
    for name, model in author_svms.items():
        verdict = model.predict(emb)[0]
        margin = model.decision_function(emb)[0]
        # Rejected authors get a large negative sentinel so max() never picks them.
        scores[name] = float(margin) if verdict == 1 else -9999

    if all(value == -9999 for value in scores.values()):
        return {"author": "Unknown", "score": None}

    winner = max(scores, key=scores.get)
    return {"author": winner, "score": round(scores[winner], 4)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModel
import torch
import joblib

app = Flask(__name__)

# 🔹 Load model + tokenizer from the Hugging Face Hub.
MODEL_NAME = "ST-THOMAS-OF-AQUINAS/Document_verification"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
bert_model = AutoModel.from_pretrained(MODEL_NAME)

# 🔹 Load saved SVM classifiers (from your training step).
author_svms = joblib.load("author_svms.pkl")  # saved dict of {author: svm_model}
# NOTE(review): label_map is loaded but never referenced below — confirm whether it is still needed.
label_map = joblib.load("label_map.pkl")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bert_model.to(device)
19
+
20
def predict_author(text):
    """Embed *text* with the BERT encoder and score it against every per-author SVM.

    Returns:
        (author, score): the accepted author with the highest decision-function
        margin as a plain float, or ("Unknown", None) when no SVM accepts the text.
    """
    bert_model.eval()  # make sure dropout is disabled for inference

    enc = tokenizer([text], return_tensors="pt", truncation=True, padding=True, max_length=256)
    enc = {k: v.to(device) for k, v in enc.items()}
    with torch.no_grad():
        outputs = bert_model(**enc)
    # [CLS]-token embedding as a (1, hidden_dim) numpy array.
    emb = outputs.last_hidden_state[:, 0, :].cpu().numpy()

    predictions = {}
    for author, clf in author_svms.items():
        pred = clf.predict(emb)[0]
        score = clf.decision_function(emb)[0]
        predictions[author] = (pred, score)

    # Keep only authors whose binary SVM accepted the text (pred == 1).
    accepted = {a: s for a, (p, s) in predictions.items() if p == 1}
    if not accepted:
        return "Unknown", None
    best_author = max(accepted, key=accepted.get)
    # Cast the numpy scalar to a plain float so the caller's jsonify() can
    # serialize it (np.float64 is not JSON-serializable by Flask).
    return best_author, float(accepted[best_author])
39
+
40
@app.route("/predict", methods=["POST"])
def predict():
    """POST /predict with JSON body {"text": ...}; respond with predicted author and score."""
    # get_json(silent=True) returns None instead of raising when the body is
    # missing, malformed, or has the wrong Content-Type, so bad requests get
    # the intended clean 400 instead of an unhandled 415/500.
    data = request.get_json(silent=True) or {}
    text = data.get("text", "")
    if not text:
        return jsonify({"error": "No text provided"}), 400

    author, score = predict_author(text)
    return jsonify({
        "author": author,
        # A rejected text yields score=None; report 0 to keep the field numeric.
        "score": score if score is not None else 0
    })
52
+
53
# Script entry point: serve on all interfaces, port 5000 (Flask dev server).
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000)