Sbhatti33 commited on
Commit
e2fea07
·
verified ·
1 Parent(s): f5cc420

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -11
app.py CHANGED
@@ -8,26 +8,26 @@ from sentence_transformers import SentenceTransformer
8
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
  from torch.nn.functional import softmax
10
 
11
- # ========== Model Paths ==========
12
- MODEL_DIR = "model"
13
- SBERT_PATH = os.path.join(MODEL_DIR, "sbert_model")
14
- SBERT_CLF_PATH = os.path.join(MODEL_DIR, "sbert_mlp.pkl")
15
- DEBERTA_PATH = os.path.join(MODEL_DIR, "deberta_model")
16
 
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
 
 
19
  # ========== Load Models ==========
20
- # Load SBERT sentence embedding model
21
- sbert = SentenceTransformer(SBERT_PATH, device=device)
22
 
23
- # Load trained SBERT classifier (MLP)
24
- clf = joblib.load(SBERT_CLF_PATH)
 
 
 
 
25
 
26
  # Load DeBERTa model and tokenizer
27
- deberta_tokenizer = AutoTokenizer.from_pretrained(DEBERTA_PATH)
28
- deberta_model = AutoModelForSequenceClassification.from_pretrained(DEBERTA_PATH).to(device)
29
  deberta_model.eval()
30
 
 
31
  # ========== Class Names ==========
32
  label_map = {0: "believer", 1: "neutral", 2: "denier"}
33
  class_names = ["believer", "neutral", "denier"]
 
8
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
  from torch.nn.functional import softmax
10
 
11
+ from huggingface_hub import hf_hub_download
 
 
 
 
12
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
+
16
  # ========== Load Models ==========
 
 
17
 
18
+ # Load SBERT encoder
19
+ sbert = SentenceTransformer("Sbhatti33/sbert-stance-model")
20
+
21
+ # Load SBERT classifier
22
+ clf_path = hf_hub_download(repo_id="Sbhatti33/sbert-stance-model", filename="sbert_mlp.pkl")
23
+ clf = joblib.load(clf_path)
24
 
25
  # Load DeBERTa model and tokenizer
26
+ deberta_tokenizer = AutoTokenizer.from_pretrained("Sbhatti33/deberta-stance-model")
27
+ deberta_model = AutoModelForSequenceClassification.from_pretrained("Sbhatti33/deberta-stance-model").to(device)
28
  deberta_model.eval()
29
 
30
+
31
  # ========== Class Names ==========
32
  label_map = {0: "believer", 1: "neutral", 2: "denier"}
33
  class_names = ["believer", "neutral", "denier"]