ST-THOMAS-OF-AQUINAS committed on
Commit
d89415b
·
verified ·
1 Parent(s): 1741dfb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -14
app.py CHANGED
@@ -7,29 +7,33 @@ from fastapi import FastAPI
7
  from pydantic import BaseModel
8
  from typing import List
9
 
10
- # 🔹 Set device
11
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
12
 
13
  # 🔹 Load tokenizer & BERT model
 
14
  tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
15
  bert_model = AutoModel.from_pretrained("distilbert-base-uncased").to(device)
16
  bert_model.eval()
17
 
18
- # 🔹 Path where SVM models are stored (same directory or "models/" subfolder)
19
- MODEL_DIR = "./models" # or "./models" if you put models in a subfolder
20
-
21
- # 🔹 Load SVM models dynamically from local directory
22
  author_svms = {}
23
- for file in os.listdir(MODEL_DIR):
24
- if file.endswith("_svm.pkl"):
25
- author = file.replace("_svm.pkl", "")
26
- clf = joblib.load(os.path.join(MODEL_DIR, file))
27
- author_svms[author] = clf
28
 
29
- print(f"✅ Loaded {len(author_svms)} author models")
 
 
 
 
 
 
 
 
30
 
31
  # 🔹 Text embedding
32
- def embed_text(text):
33
  enc = tokenizer([text], return_tensors="pt", truncation=True, padding=True, max_length=256)
34
  enc = {k: v.to(device) for k, v in enc.items()}
35
  with torch.no_grad():
@@ -38,7 +42,7 @@ def embed_text(text):
38
  return pooled
39
 
40
  # 🔹 Prediction function
41
- def predict_author(text):
42
  emb = embed_text(text)
43
  predictions = {author: clf.predict(emb)[0] for author, clf in author_svms.items()}
44
 
 
7
  from pydantic import BaseModel
8
  from typing import List
9
 
10
+ # 🔹 Ensure Transformers cache is writable (optional)
11
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
12
+ os.environ["HF_HOME"] = "/tmp/hf_cache"
13
 
14
  # 🔹 Load tokenizer & BERT model
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
17
  bert_model = AutoModel.from_pretrained("distilbert-base-uncased").to(device)
18
  bert_model.eval()
19
 
20
+ # 🔹 Load SVM models from `models/` folder
21
+ MODEL_DIR = "models"
22
+ MODEL_FILES = ["Dean of students_svm.pkl", "Registra_svm.pkl"]
 
23
  author_svms = {}
 
 
 
 
 
24
 
25
+ for file in MODEL_FILES:
26
+ path = os.path.join(MODEL_DIR, file)
27
+ if not os.path.exists(path):
28
+ raise FileNotFoundError(f"Model file not found: {path}")
29
+ author = file.replace("_svm.pkl", "")
30
+ clf = joblib.load(path)
31
+ author_svms[author] = clf
32
+
33
+ print(f"✅ Loaded {len(author_svms)} author models from {MODEL_DIR}")
34
 
35
  # 🔹 Text embedding
36
+ def embed_text(text: str):
37
  enc = tokenizer([text], return_tensors="pt", truncation=True, padding=True, max_length=256)
38
  enc = {k: v.to(device) for k, v in enc.items()}
39
  with torch.no_grad():
 
42
  return pooled
43
 
44
  # 🔹 Prediction function
45
+ def predict_author(text: str):
46
  emb = embed_text(text)
47
  predictions = {author: clf.predict(emb)[0] for author, clf in author_svms.items()}
48