Sina1138 commited on
Commit
269e282
·
1 Parent(s): fefa128

Add @spaces.GPU decorated functions for ZeroGPU compatibility

Browse files
Files changed (1) hide show
  1. interface/Demo.py +19 -2
interface/Demo.py CHANGED
@@ -1098,6 +1098,24 @@ def get_interactive_processor():
1098
  return _interactive_processor
1099
 
1100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1101
  MAX_INTERACTIVE_REVIEWS = 6
1102
 
1103
 
@@ -1352,8 +1370,7 @@ def process_interactive_reviews_fast(text1: str, text2: str, text3: str, text4:
1352
 
1353
  progress(0.30, desc="Predicting polarity and topics...")
1354
  t0 = _time.time()
1355
- polarity_map = processor.predict_polarity(all_sentences)
1356
- topic_map = processor.predict_topic(all_sentences)
1357
  print(f"[TIMING] Polarity+Topic (sequential): {_time.time() - t0:.1f}s")
1358
 
1359
  print(f"[TIMING] Fast processing total: {_time.time() - t_start:.1f}s")
 
1098
  return _interactive_processor
1099
 
1100
 
1101
+ @_gpu
1102
+ def _gpu_predict_polarity_topic(sentences: List[str]) -> Tuple[Dict, Dict]:
1103
+ """Run polarity + topic inference on GPU. Decorated with @spaces.GPU for ZeroGPU."""
1104
+ processor = get_interactive_processor()
1105
+ processor.ensure_device()
1106
+ polarity_map = processor.predict_polarity(sentences)
1107
+ topic_map = processor.predict_topic(sentences)
1108
+ return polarity_map, topic_map
1109
+
1110
+
1111
+ @_gpu
1112
+ def _gpu_predict_rsa(active_texts: List[str], progress_callback=None) -> Dict:
1113
+ """Run RSA inference on GPU. Decorated with @spaces.GPU for ZeroGPU."""
1114
+ processor = get_interactive_processor()
1115
+ processor.ensure_device()
1116
+ return processor.predict_rsa_full(*active_texts, progress_callback=progress_callback)
1117
+
1118
+
1119
  MAX_INTERACTIVE_REVIEWS = 6
1120
 
1121
 
 
1370
 
1371
  progress(0.30, desc="Predicting polarity and topics...")
1372
  t0 = _time.time()
1373
+ polarity_map, topic_map = _gpu_predict_polarity_topic(all_sentences)
 
1374
  print(f"[TIMING] Polarity+Topic (sequential): {_time.time() - t0:.1f}s")
1375
 
1376
  print(f"[TIMING] Fast processing total: {_time.time() - t_start:.1f}s")