Sina1138 committed on
Commit ·
97c36d1
1
Parent(s): 4c02957
Enhance interactive review processing by adding thread management for polarity and topic predictions; set optimal thread count for SLURM environments; update loss function to ignore padding tokens.
Browse files- dependencies/rsa_reranker.py +1 -1
- interface/Demo.py +104 -35
- interface/interactive_processor.py +22 -1
dependencies/rsa_reranker.py
CHANGED
|
@@ -85,7 +85,7 @@ class RSAReranking:
|
|
| 85 |
y = [str(item) for item in list(y)]
|
| 86 |
assert len(x) == len(y), "x and y must have the same length"
|
| 87 |
|
| 88 |
-
loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
|
| 89 |
batch_size = len(x)
|
| 90 |
|
| 91 |
# Try to use cached tokenized sources for efficiency
|
|
|
|
| 85 |
y = [str(item) for item in list(y)]
|
| 86 |
assert len(x) == len(y), "x and y must have the same length"
|
| 87 |
|
| 88 |
+
loss_fn = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=self.tokenizer.pad_token_id)
|
| 89 |
batch_size = len(x)
|
| 90 |
|
| 91 |
# Try to use cached tokenized sources for efficiency
|
interface/Demo.py
CHANGED
|
@@ -1278,52 +1278,71 @@ def format_general_rebuttals(rebuttal: str) -> str:
|
|
| 1278 |
)
|
| 1279 |
|
| 1280 |
|
| 1281 |
-
def process_interactive_reviews_fast(text1: str, text2: str, text3: str, text4: str, text5: str, text6: str, focus: str, rebuttal_str: str = "", progress=gr.Progress()) -> Tuple:
|
| 1282 |
"""
|
| 1283 |
Fast processing: Polarity + Topic only (~3-5 sec on CPU).
|
| 1284 |
RSA (agreement) runs in background.
|
| 1285 |
-
|
|
|
|
| 1286 |
"""
|
| 1287 |
import time as _time
|
| 1288 |
from dependencies.Glimpse_tokenizer import glimpse_tokenizer
|
| 1289 |
|
| 1290 |
t_start = _time.time()
|
| 1291 |
-
all_texts = [text1, text2, text3, text4, text5, text6]
|
| 1292 |
-
active_texts = [t for t in all_texts if t and t.strip()]
|
| 1293 |
|
| 1294 |
-
if
|
| 1295 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1296 |
|
| 1297 |
-
|
| 1298 |
-
progress(0.0, desc="Loading models...")
|
| 1299 |
-
t0 = _time.time()
|
| 1300 |
-
processor = get_interactive_processor()
|
| 1301 |
-
print(f"[TIMING] get_interactive_processor: {_time.time() - t0:.1f}s")
|
| 1302 |
|
| 1303 |
-
|
| 1304 |
-
|
| 1305 |
-
t0 = _time.time()
|
| 1306 |
-
sentence_lists = [[s for s in glimpse_tokenizer(t) if s.strip()] for t in active_texts]
|
| 1307 |
-
sentence_lists = [sl for sl in sentence_lists if sl]
|
| 1308 |
-
print(f"[TIMING] Tokenization: {_time.time() - t0:.1f}s ({sum(len(sl) for sl in sentence_lists)} total sentences)")
|
| 1309 |
|
| 1310 |
-
|
| 1311 |
-
|
| 1312 |
|
| 1313 |
-
|
| 1314 |
-
|
| 1315 |
-
|
| 1316 |
-
|
| 1317 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1318 |
|
| 1319 |
-
|
| 1320 |
-
|
| 1321 |
-
|
|
|
|
|
|
|
| 1322 |
|
| 1323 |
-
t0 = _time.time()
|
| 1324 |
-
polarity_map = processor.predict_polarity(all_sentences)
|
| 1325 |
-
topic_map = processor.predict_topic(all_sentences)
|
| 1326 |
-
print(f"[TIMING] Polarity+Topic (sequential): {_time.time() - t0:.1f}s")
|
| 1327 |
print(f"[TIMING] Fast processing total: {_time.time() - t_start:.1f}s")
|
| 1328 |
|
| 1329 |
# Step 5: Format results as HTML with collapsible review cards
|
|
@@ -2212,7 +2231,10 @@ with gr.Blocks(
|
|
| 2212 |
# State to hold raw rebuttal string (set by _show_raw_and_switch, consumed by process_interactive_reviews_fast)
|
| 2213 |
interactive_rebuttal_state = gr.State("")
|
| 2214 |
|
| 2215 |
-
|
|
|
|
|
|
|
|
|
|
| 2216 |
|
| 2217 |
# State to hold RSA computation results for async updates
|
| 2218 |
rsa_computation_state = gr.State({})
|
|
@@ -2249,7 +2271,10 @@ with gr.Blocks(
|
|
| 2249 |
no_title_state = gr.State("")
|
| 2250 |
|
| 2251 |
def _show_raw_and_switch(r1, r2, r3, r4, r5, r6, rebuttal, title=""):
|
| 2252 |
-
"""Immediately switch to results view with raw tokenized reviews.
|
|
|
|
|
|
|
|
|
|
| 2253 |
from dependencies.Glimpse_tokenizer import glimpse_tokenizer
|
| 2254 |
texts = [r1, r2, r3, r4, r5, r6]
|
| 2255 |
active_count = sum(1 for t in texts if t and t.strip())
|
|
@@ -2289,6 +2314,44 @@ with gr.Blocks(
|
|
| 2289 |
)
|
| 2290 |
|
| 2291 |
title_text = title.strip() if title and title.strip() else ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2292 |
return (
|
| 2293 |
*none_out, # none_text1..6
|
| 2294 |
gr.update(visible=False), # input_section
|
|
@@ -2306,6 +2369,7 @@ with gr.Blocks(
|
|
| 2306 |
active_count, # interactive_review_count
|
| 2307 |
rebuttal or "", # interactive_rebuttal_state
|
| 2308 |
gr.update(visible=False, value=""), # interactive_legend_html (reset on new submission)
|
|
|
|
| 2309 |
)
|
| 2310 |
|
| 2311 |
def _show_results_with_rebuttal(rebuttal, active_count):
|
|
@@ -2376,7 +2440,8 @@ with gr.Blocks(
|
|
| 2376 |
rebuttal_for_review1, rebuttal_for_review2, rebuttal_for_review3,
|
| 2377 |
rebuttal_for_review4, rebuttal_for_review5, rebuttal_for_review6,
|
| 2378 |
interactive_rebuttal_display, interactive_review_count,
|
| 2379 |
-
interactive_rebuttal_state, interactive_legend_html
|
|
|
|
| 2380 |
).success(
|
| 2381 |
fn=process_interactive_reviews_fast,
|
| 2382 |
inputs=_interactive_inputs,
|
|
@@ -2414,7 +2479,8 @@ with gr.Blocks(
|
|
| 2414 |
rebuttal_for_review1, rebuttal_for_review2, rebuttal_for_review3,
|
| 2415 |
rebuttal_for_review4, rebuttal_for_review5, rebuttal_for_review6,
|
| 2416 |
interactive_rebuttal_display, interactive_review_count,
|
| 2417 |
-
interactive_rebuttal_state, interactive_legend_html
|
|
|
|
| 2418 |
).success(
|
| 2419 |
fn=process_interactive_reviews_fast,
|
| 2420 |
inputs=_interactive_inputs,
|
|
@@ -2582,4 +2648,7 @@ with gr.Blocks(
|
|
| 2582 |
outputs=_review_outputs,
|
| 2583 |
)
|
| 2584 |
|
|
|
|
|
|
|
|
|
|
| 2585 |
demo.launch(share=False)
|
|
|
|
| 1278 |
)
|
| 1279 |
|
| 1280 |
|
| 1281 |
+
def process_interactive_reviews_fast(text1: str, text2: str, text3: str, text4: str, text5: str, text6: str, focus: str, rebuttal_str: str = "", thread_state=None, progress=gr.Progress()) -> Tuple:
|
| 1282 |
"""
|
| 1283 |
Fast processing: Polarity + Topic only (~3-5 sec on CPU).
|
| 1284 |
RSA (agreement) runs in background.
|
| 1285 |
+
If thread_state is provided, polarity+topic was already started during page transition —
|
| 1286 |
+
just wait for it instead of re-computing.
|
| 1287 |
"""
|
| 1288 |
import time as _time
|
| 1289 |
from dependencies.Glimpse_tokenizer import glimpse_tokenizer
|
| 1290 |
|
| 1291 |
t_start = _time.time()
|
|
|
|
|
|
|
| 1292 |
|
| 1293 |
+
# Check if polarity+topic was already started in background by _show_raw_and_switch
|
| 1294 |
+
if thread_state and isinstance(thread_state, dict) and thread_state.get("thread"):
|
| 1295 |
+
bg_thread = thread_state["thread"]
|
| 1296 |
+
_result = thread_state["result"]
|
| 1297 |
+
sentence_lists = thread_state["sentence_lists"]
|
| 1298 |
+
active_texts = thread_state["active_texts"]
|
| 1299 |
+
all_sentences = thread_state["all_sentences"]
|
| 1300 |
|
| 1301 |
+
progress(0.30, desc="Predicting polarity and topics...")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1302 |
|
| 1303 |
+
# Wait for the background thread (may already be done!)
|
| 1304 |
+
bg_thread.join()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1305 |
|
| 1306 |
+
if _result.get("error"):
|
| 1307 |
+
raise _result["error"]
|
| 1308 |
|
| 1309 |
+
polarity_map = _result["polarity"]
|
| 1310 |
+
topic_map = _result["topic"]
|
| 1311 |
+
print(f"[TIMING] Polarity+Topic (from early-start thread): {_time.time() - t_start:.1f}s wait")
|
| 1312 |
+
else:
|
| 1313 |
+
# Fallback: compute from scratch (e.g. if thread_state was not passed)
|
| 1314 |
+
all_texts = [text1, text2, text3, text4, text5, text6]
|
| 1315 |
+
active_texts = [t for t in all_texts if t and t.strip()]
|
| 1316 |
+
|
| 1317 |
+
if len(active_texts) < 2:
|
| 1318 |
+
raise ValueError("Please enter at least two reviews")
|
| 1319 |
+
|
| 1320 |
+
progress(0.0, desc="Loading models...")
|
| 1321 |
+
t0 = _time.time()
|
| 1322 |
+
processor = get_interactive_processor()
|
| 1323 |
+
print(f"[TIMING] get_interactive_processor: {_time.time() - t0:.1f}s")
|
| 1324 |
+
|
| 1325 |
+
progress(0.10, desc="Tokenizing reviews...")
|
| 1326 |
+
t0 = _time.time()
|
| 1327 |
+
sentence_lists = [[s for s in glimpse_tokenizer(t) if s.strip()] for t in active_texts]
|
| 1328 |
+
sentence_lists = [sl for sl in sentence_lists if sl]
|
| 1329 |
+
print(f"[TIMING] Tokenization: {_time.time() - t0:.1f}s ({sum(len(sl) for sl in sentence_lists)} total sentences)")
|
| 1330 |
+
|
| 1331 |
+
if len(sentence_lists) < 2:
|
| 1332 |
+
raise ValueError("At least two reviews must have valid sentences")
|
| 1333 |
+
|
| 1334 |
+
t0 = _time.time()
|
| 1335 |
+
all_sentences = filter_and_clean_sentences(
|
| 1336 |
+
list(set(s for sl in sentence_lists for s in sl))
|
| 1337 |
+
)
|
| 1338 |
+
print(f"[TIMING] filter_and_clean: {_time.time() - t0:.1f}s ({len(all_sentences)} unique sentences)")
|
| 1339 |
|
| 1340 |
+
progress(0.30, desc="Predicting polarity and topics...")
|
| 1341 |
+
t0 = _time.time()
|
| 1342 |
+
polarity_map = processor.predict_polarity(all_sentences)
|
| 1343 |
+
topic_map = processor.predict_topic(all_sentences)
|
| 1344 |
+
print(f"[TIMING] Polarity+Topic (sequential): {_time.time() - t0:.1f}s")
|
| 1345 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1346 |
print(f"[TIMING] Fast processing total: {_time.time() - t_start:.1f}s")
|
| 1347 |
|
| 1348 |
# Step 5: Format results as HTML with collapsible review cards
|
|
|
|
| 2231 |
# State to hold raw rebuttal string (set by _show_raw_and_switch, consumed by process_interactive_reviews_fast)
|
| 2232 |
interactive_rebuttal_state = gr.State("")
|
| 2233 |
|
| 2234 |
+
# State to hold background processing thread (polarity+topic starts during page transition)
|
| 2235 |
+
processing_thread_state = gr.State(None)
|
| 2236 |
+
|
| 2237 |
+
_interactive_inputs = [review1_textbox, review2_textbox, review3_textbox, review4_textbox, review5_textbox, review6_textbox, focus_radio, interactive_rebuttal_state, processing_thread_state]
|
| 2238 |
|
| 2239 |
# State to hold RSA computation results for async updates
|
| 2240 |
rsa_computation_state = gr.State({})
|
|
|
|
| 2271 |
no_title_state = gr.State("")
|
| 2272 |
|
| 2273 |
def _show_raw_and_switch(r1, r2, r3, r4, r5, r6, rebuttal, title=""):
|
| 2274 |
+
"""Immediately switch to results view with raw tokenized reviews.
|
| 2275 |
+
Also kicks off polarity+topic in a background thread so processing
|
| 2276 |
+
overlaps with page transition rendering."""
|
| 2277 |
+
import time as _time
|
| 2278 |
from dependencies.Glimpse_tokenizer import glimpse_tokenizer
|
| 2279 |
texts = [r1, r2, r3, r4, r5, r6]
|
| 2280 |
active_count = sum(1 for t in texts if t and t.strip())
|
|
|
|
| 2314 |
)
|
| 2315 |
|
| 2316 |
title_text = title.strip() if title and title.strip() else ""
|
| 2317 |
+
|
| 2318 |
+
# Start polarity+topic in background thread NOW, so processing
|
| 2319 |
+
# overlaps with Gradio rendering the page transition to the user.
|
| 2320 |
+
# By the time process_interactive_reviews_fast runs, work is already underway.
|
| 2321 |
+
active_texts = [t for t in texts if t and t.strip()]
|
| 2322 |
+
sentence_lists = []
|
| 2323 |
+
for t in active_texts:
|
| 2324 |
+
sents = [s for s in glimpse_tokenizer(t) if s.strip()]
|
| 2325 |
+
if sents:
|
| 2326 |
+
sentence_lists.append(sents)
|
| 2327 |
+
all_sentences = filter_and_clean_sentences(
|
| 2328 |
+
list(set(s for sl in sentence_lists for s in sl))
|
| 2329 |
+
)
|
| 2330 |
+
|
| 2331 |
+
processor = get_interactive_processor()
|
| 2332 |
+
_thread_result = {"polarity": None, "topic": None, "error": None}
|
| 2333 |
+
|
| 2334 |
+
def _run_polarity_topic():
|
| 2335 |
+
try:
|
| 2336 |
+
t0 = _time.time()
|
| 2337 |
+
_thread_result["polarity"] = processor.predict_polarity(all_sentences)
|
| 2338 |
+
_thread_result["topic"] = processor.predict_topic(all_sentences)
|
| 2339 |
+
print(f"[TIMING] Early polarity+topic thread done in {_time.time() - t0:.1f}s")
|
| 2340 |
+
except Exception as e:
|
| 2341 |
+
_thread_result["error"] = e
|
| 2342 |
+
|
| 2343 |
+
bg_thread = threading.Thread(target=_run_polarity_topic, daemon=True)
|
| 2344 |
+
bg_thread.start()
|
| 2345 |
+
print(f"[TIMING] Background polarity+topic thread started (page transitioning...)")
|
| 2346 |
+
|
| 2347 |
+
thread_state = {
|
| 2348 |
+
"thread": bg_thread,
|
| 2349 |
+
"result": _thread_result,
|
| 2350 |
+
"sentence_lists": sentence_lists,
|
| 2351 |
+
"active_texts": active_texts,
|
| 2352 |
+
"all_sentences": all_sentences,
|
| 2353 |
+
}
|
| 2354 |
+
|
| 2355 |
return (
|
| 2356 |
*none_out, # none_text1..6
|
| 2357 |
gr.update(visible=False), # input_section
|
|
|
|
| 2369 |
active_count, # interactive_review_count
|
| 2370 |
rebuttal or "", # interactive_rebuttal_state
|
| 2371 |
gr.update(visible=False, value=""), # interactive_legend_html (reset on new submission)
|
| 2372 |
+
thread_state, # processing_thread_state
|
| 2373 |
)
|
| 2374 |
|
| 2375 |
def _show_results_with_rebuttal(rebuttal, active_count):
|
|
|
|
| 2440 |
rebuttal_for_review1, rebuttal_for_review2, rebuttal_for_review3,
|
| 2441 |
rebuttal_for_review4, rebuttal_for_review5, rebuttal_for_review6,
|
| 2442 |
interactive_rebuttal_display, interactive_review_count,
|
| 2443 |
+
interactive_rebuttal_state, interactive_legend_html,
|
| 2444 |
+
processing_thread_state]
|
| 2445 |
).success(
|
| 2446 |
fn=process_interactive_reviews_fast,
|
| 2447 |
inputs=_interactive_inputs,
|
|
|
|
| 2479 |
rebuttal_for_review1, rebuttal_for_review2, rebuttal_for_review3,
|
| 2480 |
rebuttal_for_review4, rebuttal_for_review5, rebuttal_for_review6,
|
| 2481 |
interactive_rebuttal_display, interactive_review_count,
|
| 2482 |
+
interactive_rebuttal_state, interactive_legend_html,
|
| 2483 |
+
processing_thread_state]
|
| 2484 |
).success(
|
| 2485 |
fn=process_interactive_reviews_fast,
|
| 2486 |
inputs=_interactive_inputs,
|
|
|
|
| 2648 |
outputs=_review_outputs,
|
| 2649 |
)
|
| 2650 |
|
| 2651 |
+
# Pre-load interactive processor models at startup so first request isn't slow
|
| 2652 |
+
get_interactive_processor()
|
| 2653 |
+
|
| 2654 |
demo.launch(share=False)
|
interface/interactive_processor.py
CHANGED
|
@@ -46,6 +46,19 @@ def _try_bettertransformer(model):
|
|
| 46 |
return model
|
| 47 |
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
class InteractiveReviewProcessor:
|
| 50 |
"""Process reviews through the same pipeline as preprocessed data."""
|
| 51 |
|
|
@@ -54,6 +67,9 @@ class InteractiveReviewProcessor:
|
|
| 54 |
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
|
| 55 |
t_total = time.time()
|
| 56 |
|
|
|
|
|
|
|
|
|
|
| 57 |
# Load summarization model (for RSA)
|
| 58 |
t0 = time.time()
|
| 59 |
rsa_model_name = "sshleifer/distilbart-cnn-12-3"
|
|
@@ -65,7 +81,7 @@ class InteractiveReviewProcessor:
|
|
| 65 |
)
|
| 66 |
self.rsa_tokenizer = AutoTokenizer.from_pretrained(rsa_model_name)
|
| 67 |
self.rsa_model.to(self.device)
|
| 68 |
-
|
| 69 |
self.rsa_model.eval()
|
| 70 |
print(f"[TIMING] RSA model loaded in {time.time() - t0:.1f}s")
|
| 71 |
|
|
@@ -85,7 +101,9 @@ class InteractiveReviewProcessor:
|
|
| 85 |
self.polarity_tokenizer = AutoTokenizer.from_pretrained(polarity_model_name)
|
| 86 |
self.polarity_model = AutoModelForSequenceClassification.from_pretrained(polarity_model_name)
|
| 87 |
self.polarity_model.to(self.device)
|
|
|
|
| 88 |
self.polarity_model.eval()
|
|
|
|
| 89 |
print(f"[TIMING] Polarity model loaded in {time.time() - t0:.1f}s")
|
| 90 |
|
| 91 |
# Load topic model
|
|
@@ -102,8 +120,11 @@ class InteractiveReviewProcessor:
|
|
| 102 |
self.topic_tokenizer = AutoTokenizer.from_pretrained(topic_model_name)
|
| 103 |
self.topic_model = AutoModelForSequenceClassification.from_pretrained(topic_model_name)
|
| 104 |
self.topic_model.to(self.device)
|
|
|
|
| 105 |
self.topic_model.eval()
|
|
|
|
| 106 |
print(f"[TIMING] Topic model loaded in {time.time() - t0:.1f}s")
|
|
|
|
| 107 |
print(f"[TIMING] All models loaded in {time.time() - t_total:.1f}s")
|
| 108 |
|
| 109 |
# Topic ID to label mapping
|
|
|
|
| 46 |
return model
|
| 47 |
|
| 48 |
|
| 49 |
+
|
| 50 |
+
def _set_optimal_threads():
|
| 51 |
+
"""Set PyTorch thread count from SLURM allocation to avoid over/under-subscription."""
|
| 52 |
+
slurm_cpus = os.environ.get('SLURM_CPUS_PER_TASK') or os.environ.get('SLURM_CPUS_ON_NODE')
|
| 53 |
+
if slurm_cpus:
|
| 54 |
+
num_threads = int(slurm_cpus)
|
| 55 |
+
torch.set_num_threads(num_threads)
|
| 56 |
+
torch.set_num_interop_threads(min(num_threads, 4))
|
| 57 |
+
print(f"[THREADS] Set to {num_threads} (from SLURM)")
|
| 58 |
+
else:
|
| 59 |
+
print(f"[THREADS] Using PyTorch default: {torch.get_num_threads()}")
|
| 60 |
+
|
| 61 |
+
|
| 62 |
class InteractiveReviewProcessor:
|
| 63 |
"""Process reviews through the same pipeline as preprocessed data."""
|
| 64 |
|
|
|
|
| 67 |
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
|
| 68 |
t_total = time.time()
|
| 69 |
|
| 70 |
+
# Set optimal thread count for SLURM environment
|
| 71 |
+
_set_optimal_threads()
|
| 72 |
+
|
| 73 |
# Load summarization model (for RSA)
|
| 74 |
t0 = time.time()
|
| 75 |
rsa_model_name = "sshleifer/distilbart-cnn-12-3"
|
|
|
|
| 81 |
)
|
| 82 |
self.rsa_tokenizer = AutoTokenizer.from_pretrained(rsa_model_name)
|
| 83 |
self.rsa_model.to(self.device)
|
| 84 |
+
# BetterTransformer DISABLED for RSA — causes 2x slowdown on DistilBart CPU
|
| 85 |
self.rsa_model.eval()
|
| 86 |
print(f"[TIMING] RSA model loaded in {time.time() - t0:.1f}s")
|
| 87 |
|
|
|
|
| 101 |
self.polarity_tokenizer = AutoTokenizer.from_pretrained(polarity_model_name)
|
| 102 |
self.polarity_model = AutoModelForSequenceClassification.from_pretrained(polarity_model_name)
|
| 103 |
self.polarity_model.to(self.device)
|
| 104 |
+
self.polarity_model = _try_bettertransformer(self.polarity_model)
|
| 105 |
self.polarity_model.eval()
|
| 106 |
+
|
| 107 |
print(f"[TIMING] Polarity model loaded in {time.time() - t0:.1f}s")
|
| 108 |
|
| 109 |
# Load topic model
|
|
|
|
| 120 |
self.topic_tokenizer = AutoTokenizer.from_pretrained(topic_model_name)
|
| 121 |
self.topic_model = AutoModelForSequenceClassification.from_pretrained(topic_model_name)
|
| 122 |
self.topic_model.to(self.device)
|
| 123 |
+
self.topic_model = _try_bettertransformer(self.topic_model)
|
| 124 |
self.topic_model.eval()
|
| 125 |
+
|
| 126 |
print(f"[TIMING] Topic model loaded in {time.time() - t0:.1f}s")
|
| 127 |
+
|
| 128 |
print(f"[TIMING] All models loaded in {time.time() - t_total:.1f}s")
|
| 129 |
|
| 130 |
# Topic ID to label mapping
|