Sina1138 committed on
Commit
071dd42
·
1 Parent(s): 1b45a22

Add device-aware RSA optimizations for CPU/GPU

Browse files

- Auto-detect device and apply appropriate optimizations
- CPU: float32 dtype, batch_size=32
- GPU: float16 dtype, batch_size=64
- Add comprehensive validation suite for both environments

.gitignore CHANGED
@@ -375,3 +375,4 @@ data/DISAPERE_test.py
375
  .idea/
376
  *.sublime-project
377
  *.sublime-workspace
 
 
375
  .idea/
376
  *.sublime-project
377
  *.sublime-workspace
378
+ validation/quick_check.py
dependencies/rsa_reranker.py CHANGED
@@ -33,7 +33,7 @@ class RSAReranking:
33
  tokenizer,
34
  candidates: List[str],
35
  source_texts: List[str],
36
- batch_size: int = 32,
37
  rationality: int = 1,
38
  device="cuda",
39
  ):
@@ -42,8 +42,7 @@ class RSAReranking:
42
  :param tokenizer:
43
  :param candidates: list of candidate summaries
44
  :param source_texts: list of source texts
45
- :param batch_size: batch size used to compute the likelihoods (can be high since we don't need gradients and
46
- it's a single forward pass)
47
  :param rationality: rationality parameter of the RSA model
48
  :param device: device used to compute the likelihoods
49
  """
@@ -51,14 +50,22 @@ class RSAReranking:
51
  self.device = device
52
  self.model = model.to(self.device)
53
  self.tokenizer = tokenizer
54
-
55
 
56
  self.candidates = candidates
57
  self.source_texts = source_texts
58
 
 
 
 
 
59
  self.batch_size = batch_size
60
  self.rationality = rationality
61
 
 
 
 
 
62
  def compute_conditionned_likelihood(
63
  self, x: List[str], y: List[str], mean: bool = True
64
  ) -> torch.Tensor:
@@ -79,19 +86,49 @@ class RSAReranking:
79
  loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
80
  batch_size = len(x)
81
 
82
- x = self.tokenizer(
83
- x,
84
- return_tensors="pt",
85
- padding=True,
86
- truncation=True,
87
- max_length=1024,
88
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  y = self.tokenizer(
90
  y,
91
  return_tensors="pt",
92
  padding=True,
93
  truncation=True,
94
- max_length=1024,
95
  )
96
 
97
  # Move all tensors to the correct device
 
33
  tokenizer,
34
  candidates: List[str],
35
  source_texts: List[str],
36
+ batch_size: int = None, # Auto-detect: 64 for GPU, 32 for CPU
37
  rationality: int = 1,
38
  device="cuda",
39
  ):
 
42
  :param tokenizer:
43
  :param candidates: list of candidate summaries
44
  :param source_texts: list of source texts
45
+ :param batch_size: batch size used to compute the likelihoods (None = auto-detect based on device)
 
46
  :param rationality: rationality parameter of the RSA model
47
  :param device: device used to compute the likelihoods
48
  """
 
50
  self.device = device
51
  self.model = model.to(self.device)
52
  self.tokenizer = tokenizer
53
+
54
 
55
  self.candidates = candidates
56
  self.source_texts = source_texts
57
 
58
+ # Auto-detect batch size from CUDA availability if not specified
59
+ # GPU can handle larger batches (64), CPU uses smaller batches (32)
60
+ if batch_size is None:
61
+ batch_size = 64 if torch.cuda.is_available() else 32
62
  self.batch_size = batch_size
63
  self.rationality = rationality
64
 
65
+ # Cache of tokenized source texts, filled lazily on first use, to avoid redundant tokenization
66
+ # This significantly speeds up likelihood_matrix computation
67
+ self._tokenized_sources_cache = {}
68
+
69
  def compute_conditionned_likelihood(
70
  self, x: List[str], y: List[str], mean: bool = True
71
  ) -> torch.Tensor:
 
86
  loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
87
  batch_size = len(x)
88
 
89
+ # Try to use cached tokenized sources for efficiency
90
+ # Cache key is the source text string
91
+ x_tokenized_list = []
92
+ all_cached = True
93
+ for source in x:
94
+ if source in self._tokenized_sources_cache:
95
+ x_tokenized_list.append(self._tokenized_sources_cache[source])
96
+ else:
97
+ all_cached = False
98
+ break
99
+
100
+ if all_cached and len(x_tokenized_list) > 0:
101
+ # All sources are cached - need to batch them together
102
+ # Stack the individual tokenized sources
103
+ x_tokenized = {
104
+ 'input_ids': torch.stack([item['input_ids'].squeeze(0) for item in x_tokenized_list]),
105
+ 'attention_mask': torch.stack([item['attention_mask'].squeeze(0) for item in x_tokenized_list])
106
+ }
107
+ else:
108
+ # Not all cached, tokenize the batch and cache individual items
109
+ x_strings = x # Keep reference to original strings for caching
110
+ x_tokenized = self.tokenizer(
111
+ x,
112
+ return_tensors="pt",
113
+ padding=True,
114
+ truncation=True,
115
+ max_length=512, # Reduced from 1024 - reviews rarely exceed 512 tokens
116
+ )
117
+ # Cache each source text individually for future use
118
+ for i, source_str in enumerate(x_strings):
119
+ if source_str not in self._tokenized_sources_cache:
120
+ self._tokenized_sources_cache[source_str] = {
121
+ 'input_ids': x_tokenized['input_ids'][i:i+1],
122
+ 'attention_mask': x_tokenized['attention_mask'][i:i+1]
123
+ }
124
+
125
+ x = x_tokenized
126
  y = self.tokenizer(
127
  y,
128
  return_tensors="pt",
129
  padding=True,
130
  truncation=True,
131
+ max_length=256, # Reduced from 1024 - sentences rarely exceed 256 tokens
132
  )
133
 
134
  # Move all tensors to the correct device
interface/interactive_processor.py CHANGED
@@ -53,7 +53,12 @@ class InteractiveReviewProcessor:
53
 
54
  # Load summarization model (for RSA)
55
  rsa_model_name = "sshleifer/distilbart-cnn-12-3"
56
- self.rsa_model = AutoModelForSeq2SeqLM.from_pretrained(rsa_model_name)
 
 
 
 
 
57
  self.rsa_tokenizer = AutoTokenizer.from_pretrained(rsa_model_name)
58
  self.rsa_model.to(self.device)
59
  self.rsa_model.eval()
@@ -205,6 +210,50 @@ class InteractiveReviewProcessor:
205
  for s in sentences
206
  ]
207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  def process_reviews(
209
  self,
210
  *reviews: str,
 
53
 
54
  # Load summarization model (for RSA)
55
  rsa_model_name = "sshleifer/distilbart-cnn-12-3"
56
+ self.rsa_model = AutoModelForSeq2SeqLM.from_pretrained(
57
+ rsa_model_name,
58
+ # Use float16 only on GPU (2x faster inference, 2x less memory)
59
+ # CPU doesn't support float16 well and would be slower
60
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
61
+ )
62
  self.rsa_tokenizer = AutoTokenizer.from_pretrained(rsa_model_name)
63
  self.rsa_model.to(self.device)
64
  self.rsa_model.eval()
 
210
  for s in sentences
211
  ]
212
 
213
+ def process_reviews_fast(self, *reviews: str) -> Dict:
214
+ """
215
+ Process reviews WITHOUT RSA (fast path: ~3-5 sec on CPU).
216
+
217
+ Returns polarity + topic scores immediately.
218
+ RSA can be computed separately in background.
219
+
220
+ Args:
221
+ reviews: Review texts (at least 2 required)
222
+
223
+ Returns:
224
+ Dictionary with polarity + topic scores (consensuality empty)
225
+ """
226
+ reviews = [r for r in reviews if r and r.strip()]
227
+ if len(reviews) < 2:
228
+ raise ValueError("At least two non-empty reviews are required")
229
+
230
+ # Tokenize reviews
231
+ sentence_lists = [[s for s in glimpse_tokenizer(r) if s.strip()] for r in reviews]
232
+
233
+ if any(len(sl) == 0 for sl in sentence_lists):
234
+ raise ValueError("One or more reviews have no valid sentences")
235
+
236
+ # Get unique sentences for scoring, excluding section headers
237
+ all_sentences = [s for s in set(s for sl in sentence_lists for s in sl) if not is_section_header(s)]
238
+
239
+ # Predict scores (skip consensuality - that comes async)
240
+ polarity_map = self.predict_polarity(all_sentences)
241
+ topic_map = self.predict_topic(all_sentences)
242
+
243
+ # Return with empty consensuality (will be updated async)
244
+ result = {
245
+ f"review{i+1}_sentences": sl for i, sl in enumerate(sentence_lists)
246
+ }
247
+ result.update({
248
+ "consensuality_scores": {},
249
+ "polarity_scores": polarity_map,
250
+ "topic_scores": topic_map,
251
+ })
252
+ result["most_common"] = []
253
+ result["most_unique"] = []
254
+
255
+ return result
256
+
257
  def process_reviews(
258
  self,
259
  *reviews: str,