ST-THOMAS-OF-AQUINAS committed on
Commit
47eba64
·
verified ·
1 Parent(s): b6bfd9e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -12
app.py CHANGED
@@ -7,18 +7,24 @@ from fastapi import FastAPI
7
  from pydantic import BaseModel
8
  from typing import List
9
 
10
- # ๐Ÿ”น Ensure Transformers cache is writable (optional)
11
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
12
  os.environ["HF_HOME"] = "/tmp/hf_cache"
 
13
 
14
- # ๐Ÿ”น Load tokenizer & BERT model
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
- tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
17
- bert_model = AutoModel.from_pretrained("distilbert-base-uncased").to(device)
18
- bert_model.eval()
 
 
 
 
 
19
 
20
  # ๐Ÿ”น Load SVM models from `models/` folder
21
- MODEL_DIR = "models"
22
  MODEL_FILES = ["Dean of students_svm.pkl", "Registra_svm.pkl"]
23
  author_svms = {}
24
 
@@ -27,14 +33,19 @@ for file in MODEL_FILES:
27
  if not os.path.exists(path):
28
  raise FileNotFoundError(f"Model file not found: {path}")
29
  author = file.replace("_svm.pkl", "")
30
- clf = joblib.load(path)
31
- author_svms[author] = clf
 
 
 
32
 
33
  print(f"โœ… Loaded {len(author_svms)} author models from {MODEL_DIR}")
34
 
35
- # ๐Ÿ”น Text embedding
36
  def embed_text(text: str):
37
- enc = tokenizer([text], return_tensors="pt", truncation=True, padding=True, max_length=256)
 
 
38
  enc = {k: v.to(device) for k, v in enc.items()}
39
  with torch.no_grad():
40
  outputs = bert_model(**enc)
@@ -44,13 +55,19 @@ def embed_text(text: str):
44
  # ๐Ÿ”น Prediction function
45
  def predict_author(text: str):
46
  emb = embed_text(text)
47
- predictions = {author: clf.predict(emb)[0] for author, clf in author_svms.items()}
 
 
 
 
 
 
48
 
49
  accepted = [author for author, pred in predictions.items() if pred == 1]
50
  if len(accepted) == 1:
51
  return accepted[0]
52
  elif len(accepted) > 1:
53
- return accepted[0]
54
  else:
55
  return "Unknown"
56
 
 
7
  from pydantic import BaseModel
8
  from typing import List
9
 
10
# Ensure the Transformers cache lives in a writable location (e.g. on Spaces
# the default home directory is read-only).
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"  # legacy var, kept for older transformers
os.environ["HF_HOME"] = "/tmp/hf_cache"
os.makedirs("/tmp/hf_cache", exist_ok=True)

# Device setup: prefer GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load tokenizer & BERT model, failing fast with a clear error on startup.
try:
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    bert_model = AutoModel.from_pretrained("distilbert-base-uncased").to(device)
    bert_model.eval()  # inference only: disables dropout
except Exception as e:
    # Chain the cause (`from e`) so the original traceback is preserved
    # instead of being flattened into the message string alone.
    raise RuntimeError(f"Failed to load BERT model: {e}") from e
25
 
26
  # ๐Ÿ”น Load SVM models from `models/` folder
27
+ MODEL_DIR = "models/"
28
  MODEL_FILES = ["Dean of students_svm.pkl", "Registra_svm.pkl"]
29
  author_svms = {}
30
 
 
33
  if not os.path.exists(path):
34
  raise FileNotFoundError(f"Model file not found: {path}")
35
  author = file.replace("_svm.pkl", "")
36
+ try:
37
+ clf = joblib.load(path)
38
+ author_svms[author] = clf
39
+ except Exception as e:
40
+ raise RuntimeError(f"Failed to load SVM model {file}: {e}")
41
 
42
  print(f"โœ… Loaded {len(author_svms)} author models from {MODEL_DIR}")
43
 
44
+ # ๐Ÿ”น Text embedding function
45
  def embed_text(text: str):
46
+ enc = tokenizer(
47
+ [text], return_tensors="pt", truncation=True, padding=True, max_length=256
48
+ )
49
  enc = {k: v.to(device) for k, v in enc.items()}
50
  with torch.no_grad():
51
  outputs = bert_model(**enc)
 
55
  # ๐Ÿ”น Prediction function
56
  def predict_author(text: str):
57
  emb = embed_text(text)
58
+ predictions = {}
59
+ for author, clf in author_svms.items():
60
+ try:
61
+ predictions[author] = clf.predict(emb)[0]
62
+ except Exception as e:
63
+ predictions[author] = -1 # mark as failed
64
+ print(f"โš ๏ธ Prediction failed for {author}: {e}")
65
 
66
  accepted = [author for author, pred in predictions.items() if pred == 1]
67
  if len(accepted) == 1:
68
  return accepted[0]
69
  elif len(accepted) > 1:
70
+ return accepted[0] # pick first if multiple
71
  else:
72
  return "Unknown"
73