Sina1138 commited on
Commit
97c36d1
·
1 Parent(s): 4c02957

Enhance interactive review processing by adding thread management for polarity and topic predictions; set optimal thread count for SLURM environments; update loss function to ignore padding tokens.

Browse files
dependencies/rsa_reranker.py CHANGED
@@ -85,7 +85,7 @@ class RSAReranking:
85
  y = [str(item) for item in list(y)]
86
  assert len(x) == len(y), "x and y must have the same length"
87
 
88
- loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
89
  batch_size = len(x)
90
 
91
  # Try to use cached tokenized sources for efficiency
 
85
  y = [str(item) for item in list(y)]
86
  assert len(x) == len(y), "x and y must have the same length"
87
 
88
+ loss_fn = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=self.tokenizer.pad_token_id)
89
  batch_size = len(x)
90
 
91
  # Try to use cached tokenized sources for efficiency
interface/Demo.py CHANGED
@@ -1278,52 +1278,71 @@ def format_general_rebuttals(rebuttal: str) -> str:
1278
  )
1279
 
1280
 
1281
- def process_interactive_reviews_fast(text1: str, text2: str, text3: str, text4: str, text5: str, text6: str, focus: str, rebuttal_str: str = "", progress=gr.Progress()) -> Tuple:
1282
  """
1283
  Fast processing: Polarity + Topic only (~3-5 sec on CPU).
1284
  RSA (agreement) runs in background.
1285
- Returns immediately with placeholder agreement sections that update when ready.
 
1286
  """
1287
  import time as _time
1288
  from dependencies.Glimpse_tokenizer import glimpse_tokenizer
1289
 
1290
  t_start = _time.time()
1291
- all_texts = [text1, text2, text3, text4, text5, text6]
1292
- active_texts = [t for t in all_texts if t and t.strip()]
1293
 
1294
- if len(active_texts) < 2:
1295
- raise ValueError("Please enter at least two reviews")
 
 
 
 
 
1296
 
1297
- # Step 1: Load models
1298
- progress(0.0, desc="Loading models...")
1299
- t0 = _time.time()
1300
- processor = get_interactive_processor()
1301
- print(f"[TIMING] get_interactive_processor: {_time.time() - t0:.1f}s")
1302
 
1303
- # Step 2: Tokenize
1304
- progress(0.10, desc="Tokenizing reviews...")
1305
- t0 = _time.time()
1306
- sentence_lists = [[s for s in glimpse_tokenizer(t) if s.strip()] for t in active_texts]
1307
- sentence_lists = [sl for sl in sentence_lists if sl]
1308
- print(f"[TIMING] Tokenization: {_time.time() - t0:.1f}s ({sum(len(sl) for sl in sentence_lists)} total sentences)")
1309
 
1310
- if len(sentence_lists) < 2:
1311
- raise ValueError("At least two reviews must have valid sentences")
1312
 
1313
- t0 = _time.time()
1314
- all_sentences = filter_and_clean_sentences(
1315
- list(set(s for sl in sentence_lists for s in sl))
1316
- )
1317
- print(f"[TIMING] filter_and_clean: {_time.time() - t0:.1f}s ({len(all_sentences)} unique sentences)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1318
 
1319
- # Step 3-4: Polarity + Topic (parallelize both models)
1320
- progress(0.30, desc="Predicting polarity and topics (parallel)...")
1321
- from concurrent.futures import ThreadPoolExecutor
 
 
1322
 
1323
- t0 = _time.time()
1324
- polarity_map = processor.predict_polarity(all_sentences)
1325
- topic_map = processor.predict_topic(all_sentences)
1326
- print(f"[TIMING] Polarity+Topic (sequential): {_time.time() - t0:.1f}s")
1327
  print(f"[TIMING] Fast processing total: {_time.time() - t_start:.1f}s")
1328
 
1329
  # Step 5: Format results as HTML with collapsible review cards
@@ -2212,7 +2231,10 @@ with gr.Blocks(
2212
  # State to hold raw rebuttal string (set by _show_raw_and_switch, consumed by process_interactive_reviews_fast)
2213
  interactive_rebuttal_state = gr.State("")
2214
 
2215
- _interactive_inputs = [review1_textbox, review2_textbox, review3_textbox, review4_textbox, review5_textbox, review6_textbox, focus_radio, interactive_rebuttal_state]
 
 
 
2216
 
2217
  # State to hold RSA computation results for async updates
2218
  rsa_computation_state = gr.State({})
@@ -2249,7 +2271,10 @@ with gr.Blocks(
2249
  no_title_state = gr.State("")
2250
 
2251
  def _show_raw_and_switch(r1, r2, r3, r4, r5, r6, rebuttal, title=""):
2252
- """Immediately switch to results view with raw tokenized reviews. No ML — just glimpse_tokenizer."""
 
 
 
2253
  from dependencies.Glimpse_tokenizer import glimpse_tokenizer
2254
  texts = [r1, r2, r3, r4, r5, r6]
2255
  active_count = sum(1 for t in texts if t and t.strip())
@@ -2289,6 +2314,44 @@ with gr.Blocks(
2289
  )
2290
 
2291
  title_text = title.strip() if title and title.strip() else ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2292
  return (
2293
  *none_out, # none_text1..6
2294
  gr.update(visible=False), # input_section
@@ -2306,6 +2369,7 @@ with gr.Blocks(
2306
  active_count, # interactive_review_count
2307
  rebuttal or "", # interactive_rebuttal_state
2308
  gr.update(visible=False, value=""), # interactive_legend_html (reset on new submission)
 
2309
  )
2310
 
2311
  def _show_results_with_rebuttal(rebuttal, active_count):
@@ -2376,7 +2440,8 @@ with gr.Blocks(
2376
  rebuttal_for_review1, rebuttal_for_review2, rebuttal_for_review3,
2377
  rebuttal_for_review4, rebuttal_for_review5, rebuttal_for_review6,
2378
  interactive_rebuttal_display, interactive_review_count,
2379
- interactive_rebuttal_state, interactive_legend_html]
 
2380
  ).success(
2381
  fn=process_interactive_reviews_fast,
2382
  inputs=_interactive_inputs,
@@ -2414,7 +2479,8 @@ with gr.Blocks(
2414
  rebuttal_for_review1, rebuttal_for_review2, rebuttal_for_review3,
2415
  rebuttal_for_review4, rebuttal_for_review5, rebuttal_for_review6,
2416
  interactive_rebuttal_display, interactive_review_count,
2417
- interactive_rebuttal_state, interactive_legend_html]
 
2418
  ).success(
2419
  fn=process_interactive_reviews_fast,
2420
  inputs=_interactive_inputs,
@@ -2582,4 +2648,7 @@ with gr.Blocks(
2582
  outputs=_review_outputs,
2583
  )
2584
 
 
 
 
2585
  demo.launch(share=False)
 
1278
  )
1279
 
1280
 
1281
+ def process_interactive_reviews_fast(text1: str, text2: str, text3: str, text4: str, text5: str, text6: str, focus: str, rebuttal_str: str = "", thread_state=None, progress=gr.Progress()) -> Tuple:
1282
  """
1283
  Fast processing: Polarity + Topic only (~3-5 sec on CPU).
1284
  RSA (agreement) runs in background.
1285
+ If thread_state is provided, polarity+topic was already started during page transition —
1286
+ just wait for it instead of re-computing.
1287
  """
1288
  import time as _time
1289
  from dependencies.Glimpse_tokenizer import glimpse_tokenizer
1290
 
1291
  t_start = _time.time()
 
 
1292
 
1293
+ # Check if polarity+topic was already started in background by _show_raw_and_switch
1294
+ if thread_state and isinstance(thread_state, dict) and thread_state.get("thread"):
1295
+ bg_thread = thread_state["thread"]
1296
+ _result = thread_state["result"]
1297
+ sentence_lists = thread_state["sentence_lists"]
1298
+ active_texts = thread_state["active_texts"]
1299
+ all_sentences = thread_state["all_sentences"]
1300
 
1301
+ progress(0.30, desc="Predicting polarity and topics...")
 
 
 
 
1302
 
1303
+ # Wait for the background thread (may already be done!)
1304
+ bg_thread.join()
 
 
 
 
1305
 
1306
+ if _result.get("error"):
1307
+ raise _result["error"]
1308
 
1309
+ polarity_map = _result["polarity"]
1310
+ topic_map = _result["topic"]
1311
+ print(f"[TIMING] Polarity+Topic (from early-start thread): {_time.time() - t_start:.1f}s wait")
1312
+ else:
1313
+ # Fallback: compute from scratch (e.g. if thread_state was not passed)
1314
+ all_texts = [text1, text2, text3, text4, text5, text6]
1315
+ active_texts = [t for t in all_texts if t and t.strip()]
1316
+
1317
+ if len(active_texts) < 2:
1318
+ raise ValueError("Please enter at least two reviews")
1319
+
1320
+ progress(0.0, desc="Loading models...")
1321
+ t0 = _time.time()
1322
+ processor = get_interactive_processor()
1323
+ print(f"[TIMING] get_interactive_processor: {_time.time() - t0:.1f}s")
1324
+
1325
+ progress(0.10, desc="Tokenizing reviews...")
1326
+ t0 = _time.time()
1327
+ sentence_lists = [[s for s in glimpse_tokenizer(t) if s.strip()] for t in active_texts]
1328
+ sentence_lists = [sl for sl in sentence_lists if sl]
1329
+ print(f"[TIMING] Tokenization: {_time.time() - t0:.1f}s ({sum(len(sl) for sl in sentence_lists)} total sentences)")
1330
+
1331
+ if len(sentence_lists) < 2:
1332
+ raise ValueError("At least two reviews must have valid sentences")
1333
+
1334
+ t0 = _time.time()
1335
+ all_sentences = filter_and_clean_sentences(
1336
+ list(set(s for sl in sentence_lists for s in sl))
1337
+ )
1338
+ print(f"[TIMING] filter_and_clean: {_time.time() - t0:.1f}s ({len(all_sentences)} unique sentences)")
1339
 
1340
+ progress(0.30, desc="Predicting polarity and topics...")
1341
+ t0 = _time.time()
1342
+ polarity_map = processor.predict_polarity(all_sentences)
1343
+ topic_map = processor.predict_topic(all_sentences)
1344
+ print(f"[TIMING] Polarity+Topic (sequential): {_time.time() - t0:.1f}s")
1345
 
 
 
 
 
1346
  print(f"[TIMING] Fast processing total: {_time.time() - t_start:.1f}s")
1347
 
1348
  # Step 5: Format results as HTML with collapsible review cards
 
2231
  # State to hold raw rebuttal string (set by _show_raw_and_switch, consumed by process_interactive_reviews_fast)
2232
  interactive_rebuttal_state = gr.State("")
2233
 
2234
+ # State to hold background processing thread (polarity+topic starts during page transition)
2235
+ processing_thread_state = gr.State(None)
2236
+
2237
+ _interactive_inputs = [review1_textbox, review2_textbox, review3_textbox, review4_textbox, review5_textbox, review6_textbox, focus_radio, interactive_rebuttal_state, processing_thread_state]
2238
 
2239
  # State to hold RSA computation results for async updates
2240
  rsa_computation_state = gr.State({})
 
2271
  no_title_state = gr.State("")
2272
 
2273
  def _show_raw_and_switch(r1, r2, r3, r4, r5, r6, rebuttal, title=""):
2274
+ """Immediately switch to results view with raw tokenized reviews.
2275
+ Also kicks off polarity+topic in a background thread so processing
2276
+ overlaps with page transition rendering."""
2277
+ import time as _time
2278
  from dependencies.Glimpse_tokenizer import glimpse_tokenizer
2279
  texts = [r1, r2, r3, r4, r5, r6]
2280
  active_count = sum(1 for t in texts if t and t.strip())
 
2314
  )
2315
 
2316
  title_text = title.strip() if title and title.strip() else ""
2317
+
2318
+ # Start polarity+topic in background thread NOW, so processing
2319
+ # overlaps with Gradio rendering the page transition to the user.
2320
+ # By the time process_interactive_reviews_fast runs, work is already underway.
2321
+ active_texts = [t for t in texts if t and t.strip()]
2322
+ sentence_lists = []
2323
+ for t in active_texts:
2324
+ sents = [s for s in glimpse_tokenizer(t) if s.strip()]
2325
+ if sents:
2326
+ sentence_lists.append(sents)
2327
+ all_sentences = filter_and_clean_sentences(
2328
+ list(set(s for sl in sentence_lists for s in sl))
2329
+ )
2330
+
2331
+ processor = get_interactive_processor()
2332
+ _thread_result = {"polarity": None, "topic": None, "error": None}
2333
+
2334
+ def _run_polarity_topic():
2335
+ try:
2336
+ t0 = _time.time()
2337
+ _thread_result["polarity"] = processor.predict_polarity(all_sentences)
2338
+ _thread_result["topic"] = processor.predict_topic(all_sentences)
2339
+ print(f"[TIMING] Early polarity+topic thread done in {_time.time() - t0:.1f}s")
2340
+ except Exception as e:
2341
+ _thread_result["error"] = e
2342
+
2343
+ bg_thread = threading.Thread(target=_run_polarity_topic, daemon=True)
2344
+ bg_thread.start()
2345
+ print(f"[TIMING] Background polarity+topic thread started (page transitioning...)")
2346
+
2347
+ thread_state = {
2348
+ "thread": bg_thread,
2349
+ "result": _thread_result,
2350
+ "sentence_lists": sentence_lists,
2351
+ "active_texts": active_texts,
2352
+ "all_sentences": all_sentences,
2353
+ }
2354
+
2355
  return (
2356
  *none_out, # none_text1..6
2357
  gr.update(visible=False), # input_section
 
2369
  active_count, # interactive_review_count
2370
  rebuttal or "", # interactive_rebuttal_state
2371
  gr.update(visible=False, value=""), # interactive_legend_html (reset on new submission)
2372
+ thread_state, # processing_thread_state
2373
  )
2374
 
2375
  def _show_results_with_rebuttal(rebuttal, active_count):
 
2440
  rebuttal_for_review1, rebuttal_for_review2, rebuttal_for_review3,
2441
  rebuttal_for_review4, rebuttal_for_review5, rebuttal_for_review6,
2442
  interactive_rebuttal_display, interactive_review_count,
2443
+ interactive_rebuttal_state, interactive_legend_html,
2444
+ processing_thread_state]
2445
  ).success(
2446
  fn=process_interactive_reviews_fast,
2447
  inputs=_interactive_inputs,
 
2479
  rebuttal_for_review1, rebuttal_for_review2, rebuttal_for_review3,
2480
  rebuttal_for_review4, rebuttal_for_review5, rebuttal_for_review6,
2481
  interactive_rebuttal_display, interactive_review_count,
2482
+ interactive_rebuttal_state, interactive_legend_html,
2483
+ processing_thread_state]
2484
  ).success(
2485
  fn=process_interactive_reviews_fast,
2486
  inputs=_interactive_inputs,
 
2648
  outputs=_review_outputs,
2649
  )
2650
 
2651
+ # Pre-load interactive processor models at startup so first request isn't slow
2652
+ get_interactive_processor()
2653
+
2654
  demo.launch(share=False)
interface/interactive_processor.py CHANGED
@@ -46,6 +46,19 @@ def _try_bettertransformer(model):
46
  return model
47
 
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  class InteractiveReviewProcessor:
50
  """Process reviews through the same pipeline as preprocessed data."""
51
 
@@ -54,6 +67,9 @@ class InteractiveReviewProcessor:
54
  self.device = torch.device(device if torch.cuda.is_available() else "cpu")
55
  t_total = time.time()
56
 
 
 
 
57
  # Load summarization model (for RSA)
58
  t0 = time.time()
59
  rsa_model_name = "sshleifer/distilbart-cnn-12-3"
@@ -65,7 +81,7 @@ class InteractiveReviewProcessor:
65
  )
66
  self.rsa_tokenizer = AutoTokenizer.from_pretrained(rsa_model_name)
67
  self.rsa_model.to(self.device)
68
- self.rsa_model = _try_bettertransformer(self.rsa_model)
69
  self.rsa_model.eval()
70
  print(f"[TIMING] RSA model loaded in {time.time() - t0:.1f}s")
71
 
@@ -85,7 +101,9 @@ class InteractiveReviewProcessor:
85
  self.polarity_tokenizer = AutoTokenizer.from_pretrained(polarity_model_name)
86
  self.polarity_model = AutoModelForSequenceClassification.from_pretrained(polarity_model_name)
87
  self.polarity_model.to(self.device)
 
88
  self.polarity_model.eval()
 
89
  print(f"[TIMING] Polarity model loaded in {time.time() - t0:.1f}s")
90
 
91
  # Load topic model
@@ -102,8 +120,11 @@ class InteractiveReviewProcessor:
102
  self.topic_tokenizer = AutoTokenizer.from_pretrained(topic_model_name)
103
  self.topic_model = AutoModelForSequenceClassification.from_pretrained(topic_model_name)
104
  self.topic_model.to(self.device)
 
105
  self.topic_model.eval()
 
106
  print(f"[TIMING] Topic model loaded in {time.time() - t0:.1f}s")
 
107
  print(f"[TIMING] All models loaded in {time.time() - t_total:.1f}s")
108
 
109
  # Topic ID to label mapping
 
46
  return model
47
 
48
 
49
+
50
+ def _set_optimal_threads():
51
+ """Set PyTorch thread count from SLURM allocation to avoid over/under-subscription."""
52
+ slurm_cpus = os.environ.get('SLURM_CPUS_PER_TASK') or os.environ.get('SLURM_CPUS_ON_NODE')
53
+ if slurm_cpus:
54
+ num_threads = int(slurm_cpus)
55
+ torch.set_num_threads(num_threads)
56
+ torch.set_num_interop_threads(min(num_threads, 4))
57
+ print(f"[THREADS] Set to {num_threads} (from SLURM)")
58
+ else:
59
+ print(f"[THREADS] Using PyTorch default: {torch.get_num_threads()}")
60
+
61
+
62
  class InteractiveReviewProcessor:
63
  """Process reviews through the same pipeline as preprocessed data."""
64
 
 
67
  self.device = torch.device(device if torch.cuda.is_available() else "cpu")
68
  t_total = time.time()
69
 
70
+ # Set optimal thread count for SLURM environment
71
+ _set_optimal_threads()
72
+
73
  # Load summarization model (for RSA)
74
  t0 = time.time()
75
  rsa_model_name = "sshleifer/distilbart-cnn-12-3"
 
81
  )
82
  self.rsa_tokenizer = AutoTokenizer.from_pretrained(rsa_model_name)
83
  self.rsa_model.to(self.device)
84
+ # BetterTransformer DISABLED for RSA — causes 2x slowdown on DistilBart CPU
85
  self.rsa_model.eval()
86
  print(f"[TIMING] RSA model loaded in {time.time() - t0:.1f}s")
87
 
 
101
  self.polarity_tokenizer = AutoTokenizer.from_pretrained(polarity_model_name)
102
  self.polarity_model = AutoModelForSequenceClassification.from_pretrained(polarity_model_name)
103
  self.polarity_model.to(self.device)
104
+ self.polarity_model = _try_bettertransformer(self.polarity_model)
105
  self.polarity_model.eval()
106
+
107
  print(f"[TIMING] Polarity model loaded in {time.time() - t0:.1f}s")
108
 
109
  # Load topic model
 
120
  self.topic_tokenizer = AutoTokenizer.from_pretrained(topic_model_name)
121
  self.topic_model = AutoModelForSequenceClassification.from_pretrained(topic_model_name)
122
  self.topic_model.to(self.device)
123
+ self.topic_model = _try_bettertransformer(self.topic_model)
124
  self.topic_model.eval()
125
+
126
  print(f"[TIMING] Topic model loaded in {time.time() - t0:.1f}s")
127
+
128
  print(f"[TIMING] All models loaded in {time.time() - t_total:.1f}s")
129
 
130
  # Topic ID to label mapping