ST-THOMAS-OF-AQUINAS committed on
Commit
97f71f2
·
verified ·
1 Parent(s): dfd0bbc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -43
app.py CHANGED
@@ -1,54 +1,78 @@
1
- from flask import Flask, request, jsonify
2
- from transformers import AutoTokenizer, AutoModel
3
  import torch
 
 
4
  import joblib
 
 
 
 
 
 
5
 
6
- app = Flask(__name__)
 
 
7
 
8
- # 🔹 Load model + tokenizer from Hugging Face Hub
9
- MODEL_NAME = "ST-THOMAS-OF-AQUINAS/Document_verification"
10
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
11
- bert_model = AutoModel.from_pretrained(MODEL_NAME)
 
12
 
13
- # 🔹 Load saved SVM classifiers (from your training step)
14
- author_svms = joblib.load("author_svms.pkl") # saved dict of {author: svm_model}
15
- label_map = joblib.load("label_map.pkl")
 
 
 
 
 
16
 
17
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
- bert_model.to(device)
 
 
 
 
 
 
19
 
20
- def predict_author(text):
21
- bert_model.eval()
22
  enc = tokenizer([text], return_tensors="pt", truncation=True, padding=True, max_length=256)
23
  enc = {k: v.to(device) for k, v in enc.items()}
24
  with torch.no_grad():
25
  outputs = bert_model(**enc)
26
- emb = outputs.last_hidden_state[:, 0, :].cpu().numpy()
27
-
28
- predictions = {}
29
- for author, clf in author_svms.items():
30
- pred = clf.predict(emb)[0]
31
- score = clf.decision_function(emb)[0]
32
- predictions[author] = (pred, score)
33
-
34
- accepted = {a: s for a, (p, s) in predictions.items() if p == 1}
35
- if not accepted:
36
- return "Unknown", None
37
- best_author = max(accepted, key=accepted.get)
38
- return best_author, accepted[best_author]
39
-
40
- @app.route("/predict", methods=["POST"])
41
- def predict():
42
- data = request.json
43
- text = data.get("text", "")
44
- if not text:
45
- return jsonify({"error": "No text provided"}), 400
46
-
47
- author, score = predict_author(text)
48
- return jsonify({
49
- "author": author,
50
- "score": score if score is not None else 0
51
- })
52
-
53
- if __name__ == "__main__":
54
- app.run(host="0.0.0.0", port=5000)
 
 
 
 
 
 
 
1
  import torch
2
+ from transformers import AutoTokenizer, AutoModel
3
+ from sklearn.svm import SVC
4
  import joblib
5
+ import requests
6
+ import os
7
+ from fastapi import FastAPI
8
+ from pydantic import BaseModel
9
+ from typing import List
10
+ import tempfile
11
 
12
# 🔹 Hugging Face Hub repo that hosts the per-author SVM pickles.
HF_REPO = "ST-THOMAS-OF-AQUINAS/Document_verification"
MODEL_FILES = ["author1_svm.pkl", "author2_svm.pkl"]  # replace with actual filenames

# 🔹 Shared embedding backbone: DistilBERT tokenizer + encoder, placed on the
# GPU when one is available and switched to inference mode once at import time.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
bert_model = AutoModel.from_pretrained("distilbert-base-uncased").to(device)
bert_model.eval()
21
 
22
# 🔹 Function to download a file from the Hugging Face Hub
def download_file(repo, filename):
    """Download *filename* from the Hub repo *repo* into the system temp dir.

    Args:
        repo: Hub repo id, e.g. "user/Document_verification".
        filename: file name inside the repo's main branch.

    Returns:
        Local filesystem path of the downloaded copy.

    Raises:
        requests.HTTPError: if the Hub responds with an error status.
    """
    # BUG FIX: the URL previously ended in a literal "(unknown)" placeholder
    # instead of the requested filename, so every download hit a 404 path.
    url = f"https://huggingface.co/{repo}/resolve/main/{filename}"
    response = requests.get(url, timeout=60)
    # Fail loudly instead of silently writing an HTML error page to disk,
    # which joblib.load would later choke on with a confusing error.
    response.raise_for_status()
    tmp_path = os.path.join(tempfile.gettempdir(), filename)
    with open(tmp_path, "wb") as f:
        f.write(response.content)
    return tmp_path
30
 
31
# 🔹 Fetch each per-author SVM from the Hub and index it by author name
# (derived from the filename by stripping the "_svm.pkl" suffix).
author_svms = {}
for fname in MODEL_FILES:
    author_name = fname.replace("_svm.pkl", "")
    local_path = download_file(HF_REPO, fname)
    author_svms[author_name] = joblib.load(local_path)
print(f"✅ Loaded {len(author_svms)} author models")
39
 
40
# 🔹 Text embedding
def embed_text(text):
    """Encode *text* with the shared DistilBERT model.

    Returns the [CLS]-token embedding of the last hidden layer as a
    (1, hidden_dim) numpy array on the CPU.
    """
    batch = tokenizer(
        [text], return_tensors="pt", truncation=True, padding=True, max_length=256
    )
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    with torch.no_grad():
        out = bert_model(**batch)
    # Position 0 is the [CLS] token — used as the whole-document embedding.
    return out.last_hidden_state[:, 0, :].cpu().numpy()
48
+
49
# 🔹 Prediction function
def predict_author(text):
    """Return the author whose SVM accepts *text*, or "Unknown".

    Each per-author SVM is a one-vs-rest verifier: a prediction of 1 means
    "this author wrote it". Fixes two issues in the previous version: the
    len==1 and len>1 branches were duplicated (both returned accepted[0]),
    and a multi-way acceptance was resolved by dict insertion order — an
    arbitrary pick. Ties are now broken by the classifier's decision score
    (as the pre-refactor code did), falling back to the first acceptance if
    a classifier lacks decision_function.
    """
    emb = embed_text(text)
    predictions = {author: clf.predict(emb)[0] for author, clf in author_svms.items()}

    accepted = [author for author, pred in predictions.items() if pred == 1]
    if not accepted:
        return "Unknown"
    if len(accepted) == 1:
        return accepted[0]
    # Multiple SVMs accepted: prefer the highest decision score.
    try:
        return max(accepted, key=lambda a: author_svms[a].decision_function(emb)[0])
    except AttributeError:
        return accepted[0]
61
+
62
# 🔹 FastAPI app
app = FastAPI(title="Document Verification API")


class TextInput(BaseModel):
    """Request body for /predict: a batch of documents to attribute."""

    texts: List[str]
67
+
68
+ @app.post("/predict")
69
+ def predict(input_data: TextInput):
70
+ results = []
71
+ for txt in input_data.texts:
72
+ author = predict_author(txt)
73
+ results.append({"text": txt, "predicted_author": author})
74
+ return {"results": results}
75
+
76
@app.get("/health")
def health_check():
    """Liveness probe: always reports that the service is up."""
    return dict(status="ok")