Sina1138 commited on
Commit
e5ab9ff
·
1 Parent(s): 4f76bec

Update model loading to prioritize local trained models for improved accuracy; adjust configuration for model paths and fallback options

Browse files
dependencies/scoring_utils.py CHANGED
@@ -184,31 +184,44 @@ def validate_input_file(input_path: Path, required_columns: list) -> pd.DataFram
184
  def load_polarity_model(model_variant: str, base_dir: Path, device: str = "cuda"):
185
  """
186
  Factory function to load polarity model by variant name.
187
-
188
  Supported variants:
189
- - "scibert": scibert/scibert_polarity/final_model
190
- - "deberta": alternative_polarity/deberta/final_model
191
- - "scideberta": alternative_polarity/scideberta/final_model
192
-
 
 
193
  Args:
194
  model_variant: Name of model variant
195
  base_dir: Base directory of project
196
  device: Device to load onto
197
-
198
  Returns:
199
  Tuple of (tokenizer, model, device_obj)
200
-
201
  Raises:
202
  ValueError: If model_variant not supported
203
  FileNotFoundError: If model directory doesn't exist
204
  """
 
205
  variant_map = {
206
- "scibert": base_dir / "scibert" / "scibert_polarity" / "final_model",
207
- "deberta": base_dir / "alternative_polarity" / "deberta" / "deberta_v3_base_polarity_final_model",
208
- "scideberta": base_dir / "alternative_polarity" / "scideberta" / "scideberta_full_polarity_final_model",
 
 
 
 
 
 
209
  }
210
  hub_fallback_map = {
211
  "scibert": "Sina1138/Scibert_polarity_Review",
 
 
 
 
212
  }
213
 
214
  if model_variant not in variant_map:
@@ -224,31 +237,44 @@ def load_polarity_model(model_variant: str, base_dir: Path, device: str = "cuda"
224
  def load_topic_model(model_variant: str, base_dir: Path, device: str = "cuda"):
225
  """
226
  Factory function to load topic model by variant name.
227
-
228
  Supported variants:
229
- - "scibert": scibert/scibert_topic/final_model
230
- - "deberta": alternative_topic/deberta/final_model
231
- - "scideberta": alternative_topic/scideberta/final_model
232
-
 
 
233
  Args:
234
  model_variant: Name of model variant
235
  base_dir: Base directory of project
236
  device: Device to load onto
237
-
238
  Returns:
239
  Tuple of (tokenizer, model, device_obj)
240
-
241
  Raises:
242
  ValueError: If model_variant not supported
243
  FileNotFoundError: If model directory doesn't exist
244
  """
 
245
  variant_map = {
246
- "scibert": base_dir / "scibert" / "scibert_topic" / "final_model",
247
- "deberta": base_dir / "alternative_topic" / "deberta" / "final_model",
248
- "scideberta": base_dir / "alternative_topic" / "scideberta" / "final_model",
 
 
 
 
 
 
249
  }
250
  hub_fallback_map = {
251
- "scibert": "Sina1138/SciDeberta_Review",
 
 
 
 
252
  }
253
 
254
  if model_variant not in variant_map:
 
184
  def load_polarity_model(model_variant: str, base_dir: Path, device: str = "cuda"):
185
  """
186
  Factory function to load polarity model by variant name.
187
+
188
  Supported variants:
189
+ - "scibert": scibert/scibert_polarity/final_model (F1=0.724 baseline)
190
+ - "deberta": training/outputs/deberta_polarity/final_model (F1=0.764, +5.5% - RECOMMENDED)
191
+ - "deberta_v3_small": training/outputs/deberta_v3_small_polarity/final_model (F1=0.754)
192
+ - "modernbert": training/outputs/modernbert_polarity/final_model (F1=0.741)
193
+ - "scideberta": training/outputs/scideberta_polarity/final_model (F1=0.737)
194
+
195
  Args:
196
  model_variant: Name of model variant
197
  base_dir: Base directory of project
198
  device: Device to load onto
199
+
200
  Returns:
201
  Tuple of (tokenizer, model, device_obj)
202
+
203
  Raises:
204
  ValueError: If model_variant not supported
205
  FileNotFoundError: If model directory doesn't exist
206
  """
207
+ # Feb 2026: New trained models from training/outputs/ (standardized comparison)
208
  variant_map = {
209
+ "scibert": base_dir / "training" / "outputs" / "scibert_polarity" / "final_model",
210
+ "deberta": base_dir / "training" / "outputs" / "deberta_polarity" / "final_model", # BEST: F1=0.764
211
+ "deberta_v3_small": base_dir / "training" / "outputs" / "deberta_v3_small_polarity" / "final_model",
212
+ "modernbert": base_dir / "training" / "outputs" / "modernbert_polarity" / "final_model",
213
+ "scideberta": base_dir / "training" / "outputs" / "scideberta_polarity" / "final_model",
214
+ # Legacy models (pre-Feb 2026, kept for backwards compatibility)
215
+ "scibert_legacy": base_dir / "scibert" / "scibert_polarity" / "final_model",
216
+ "deberta_legacy": base_dir / "alternative_polarity" / "deberta" / "deberta_v3_base_polarity_final_model",
217
+ "scideberta_legacy": base_dir / "alternative_polarity" / "scideberta" / "scideberta_full_polarity_final_model",
218
  }
219
  hub_fallback_map = {
220
  "scibert": "Sina1138/Scibert_polarity_Review",
221
+ "scideberta": "KISTI-AI/Scideberta-full", # Needs fine-tuning
222
+ "modernbert": "answerdotai/ModernBERT-base", # Needs fine-tuning
223
+ "deberta": "microsoft/deberta-v3-base", # Needs fine-tuning
224
+ "deberta_v3_small": "microsoft/deberta-v3-small", # Needs fine-tuning
225
  }
226
 
227
  if model_variant not in variant_map:
 
237
  def load_topic_model(model_variant: str, base_dir: Path, device: str = "cuda"):
238
  """
239
  Factory function to load topic model by variant name.
240
+
241
  Supported variants:
242
+ - "scideberta": training/outputs/scideberta_topic/final_model (F1=0.478 - BEST, RECOMMENDED)
243
+ - "deberta": training/outputs/deberta_topic/final_model (F1=0.450)
244
+ - "scibert": training/outputs/scibert_topic/final_model (F1=0.442)
245
+ - "deberta_v3_small": training/outputs/deberta_v3_small_topic/final_model (F1=0.381)
246
+ - "modernbert": training/outputs/modernbert_topic/final_model (F1=0.376)
247
+
248
  Args:
249
  model_variant: Name of model variant
250
  base_dir: Base directory of project
251
  device: Device to load onto
252
+
253
  Returns:
254
  Tuple of (tokenizer, model, device_obj)
255
+
256
  Raises:
257
  ValueError: If model_variant not supported
258
  FileNotFoundError: If model directory doesn't exist
259
  """
260
+ # Feb 2026: New trained models from training/outputs/ (standardized comparison)
261
  variant_map = {
262
+ "scideberta": base_dir / "training" / "outputs" / "scideberta_topic" / "final_model", # BEST: F1=0.478
263
+ "deberta": base_dir / "training" / "outputs" / "deberta_topic" / "final_model",
264
+ "scibert": base_dir / "training" / "outputs" / "scibert_topic" / "final_model",
265
+ "deberta_v3_small": base_dir / "training" / "outputs" / "deberta_v3_small_topic" / "final_model",
266
+ "modernbert": base_dir / "training" / "outputs" / "modernbert_topic" / "final_model",
267
+ # Legacy models (pre-Feb 2026, kept for backwards compatibility)
268
+ "scibert_legacy": base_dir / "scibert" / "scibert_topic" / "final_model",
269
+ "deberta_legacy": base_dir / "alternative_topic" / "deberta" / "final_model",
270
+ "scideberta_legacy": base_dir / "alternative_topic" / "scideberta" / "final_model",
271
  }
272
  hub_fallback_map = {
273
+ "scideberta": "Sina1138/SciDeberta_Review", # Production HuggingFace model
274
+ "scibert": "allenai/scibert_scivocab_uncased", # Needs fine-tuning
275
+ "deberta": "microsoft/deberta-v3-base", # Needs fine-tuning
276
+ "deberta_v3_small": "microsoft/deberta-v3-small", # Needs fine-tuning
277
+ "modernbert": "answerdotai/ModernBERT-base", # Needs fine-tuning
278
  }
279
 
280
  if model_variant not in variant_map:
interface/interactive_processor.py CHANGED
@@ -64,14 +64,32 @@ class InteractiveReviewProcessor:
64
  self.rsa_model.eval()
65
 
66
  # Load polarity model
67
- polarity_model_name = "Sina1138/Scibert_polarity_Review"
 
 
 
 
 
 
 
 
 
 
68
  self.polarity_tokenizer = AutoTokenizer.from_pretrained(polarity_model_name)
69
  self.polarity_model = AutoModelForSequenceClassification.from_pretrained(polarity_model_name)
70
  self.polarity_model.to(self.device)
71
  self.polarity_model.eval()
72
 
73
  # Load topic model
74
- topic_model_name = "Sina1138/SciDeberta_Review"
 
 
 
 
 
 
 
 
75
  self.topic_tokenizer = AutoTokenizer.from_pretrained(topic_model_name)
76
  self.topic_model = AutoModelForSequenceClassification.from_pretrained(topic_model_name)
77
  self.topic_model.to(self.device)
 
64
  self.rsa_model.eval()
65
 
66
  # Load polarity model
67
+ # Option A (Feb 2026): DeBERTa-v3-base for +5.5% F1 improvement (0.764 vs 0.724 SciBERT)
68
+ # Try local trained model first, fall back to HuggingFace
69
+ polarity_model_local = BASE_DIR / "training" / "outputs" / "deberta_polarity" / "final_model"
70
+ if polarity_model_local.exists() and (polarity_model_local / "config.json").exists():
71
+ polarity_model_name = str(polarity_model_local)
72
+ print(f"Loading polarity model from local trained model: {polarity_model_name}")
73
+ else:
74
+ # Fallback: will need to upload fine-tuned model or use legacy SciBERT
75
+ polarity_model_name = "Sina1138/Scibert_polarity_Review" # Legacy SciBERT
76
+ print(f"Local model not found, using legacy SciBERT: {polarity_model_name}")
77
+
78
  self.polarity_tokenizer = AutoTokenizer.from_pretrained(polarity_model_name)
79
  self.polarity_model = AutoModelForSequenceClassification.from_pretrained(polarity_model_name)
80
  self.polarity_model.to(self.device)
81
  self.polarity_model.eval()
82
 
83
  # Load topic model
84
+ # SciDeBERTa maintains best performance (F1=0.478)
85
+ topic_model_local = BASE_DIR / "training" / "outputs" / "scideberta_topic" / "final_model"
86
+ if topic_model_local.exists() and (topic_model_local / "config.json").exists():
87
+ topic_model_name = str(topic_model_local)
88
+ print(f"Loading topic model from local trained model: {topic_model_name}")
89
+ else:
90
+ topic_model_name = "Sina1138/SciDeberta_Review" # Production HuggingFace model
91
+ print(f"Using HuggingFace topic model: {topic_model_name}")
92
+
93
  self.topic_tokenizer = AutoTokenizer.from_pretrained(topic_model_name)
94
  self.topic_model = AutoModelForSequenceClassification.from_pretrained(topic_model_name)
95
  self.topic_model.to(self.device)
pipeline/config.py CHANGED
@@ -28,9 +28,26 @@ class Config:
28
  OPENREVIEW_BASE_URL = 'https://api2.openreview.net'
29
  VENUE_TEMPLATE = 'ICLR.cc/{year}/Conference'
30
 
31
- # Model paths (HuggingFace)
32
- POLARITY_MODEL = "Sina1138/Scibert_polarity_Review"
33
- TOPIC_MODEL = "Sina1138/SciDeberta_Review"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  RSA_MODEL = "sshleifer/distilbart-cnn-12-3" # For GLIMPSE
35
 
36
  @classmethod
 
28
  OPENREVIEW_BASE_URL = 'https://api2.openreview.net'
29
  VENUE_TEMPLATE = 'ICLR.cc/{year}/Conference'
30
 
31
+ # Model paths
32
+ # Option A (Maximize Accuracy): DeBERTa polarity + SciDeBERTa topic - Feb 2026 upgrade
33
+ # Polarity: DeBERTa-v3-base (F1=0.764, +5.5% vs SciBERT baseline 0.724)
34
+ # Topic: SciDeBERTa (F1=0.478, maintains lead)
35
+
36
+ # Local trained models (preferred for production after validation)
37
+ POLARITY_MODEL_LOCAL = BASE_DIR / "training" / "outputs" / "deberta_polarity" / "final_model"
38
+ TOPIC_MODEL_LOCAL = BASE_DIR / "training" / "outputs" / "scideberta_topic" / "final_model"
39
+
40
+ # HuggingFace fallbacks (if local models not available)
41
+ POLARITY_MODEL_HUB = "Sina1138/Scibert_polarity_Review" # Legacy SciBERT (until fine-tuned DeBERTa is uploaded to Hub)
42
+ TOPIC_MODEL_HUB = "Sina1138/SciDeberta_Review" # Current production model
43
+
44
+ # Legacy models (SciBERT baseline, kept for reference)
45
+ POLARITY_MODEL_LEGACY = "Sina1138/Scibert_polarity_Review" # F1=0.724
46
+
47
+ # Use local models if available, otherwise fall back to hub
48
+ POLARITY_MODEL = str(POLARITY_MODEL_LOCAL) if POLARITY_MODEL_LOCAL.exists() else POLARITY_MODEL_HUB
49
+ TOPIC_MODEL = str(TOPIC_MODEL_LOCAL) if TOPIC_MODEL_LOCAL.exists() else TOPIC_MODEL_HUB
50
+
51
  RSA_MODEL = "sshleifer/distilbart-cnn-12-3" # For GLIMPSE
52
 
53
  @classmethod
pipeline/scored_reviews_builder.py CHANGED
@@ -6,11 +6,15 @@ from pathlib import Path
6
  from tqdm import tqdm
7
  import json
8
 
9
- BASE_DIR = Path(__file__).resolve().parent.parent
10
- sys.path.insert(0, str(BASE_DIR))
 
11
 
 
12
  from dependencies.Glimpse_tokenizer import glimpse_tokenizer
13
 
 
 
14
  # def tokenize_sentences(text: str) -> list:
15
  # # same tokenization as in the original glimpse code
16
  # text = text.replace('-----', '\n')
 
6
  from tqdm import tqdm
7
  import json
8
 
9
+ _dir = Path(__file__).resolve().parent
10
+ sys.path.insert(0, str(_dir))
11
+ sys.path.insert(0, str(_dir.parent))
12
 
13
+ from config import Config
14
  from dependencies.Glimpse_tokenizer import glimpse_tokenizer
15
 
16
+ BASE_DIR = Config.BASE_DIR
17
+
18
  # def tokenize_sentences(text: str) -> list:
19
  # # same tokenization as in the original glimpse code
20
  # text = text.replace('-----', '\n')