ming committed on
Commit
52f6c42
·
1 Parent(s): de8b12b

Enhance HF V2 summaries: increase token limits, improve length parameters

Browse files

- Increase default max_new_tokens from 64-128 to 256 minimum
- Add min_new_tokens floor (96-192) to prevent premature stopping
- Set length_penalty to 1.1 to encourage longer outputs
- Expand encoder max length to 2048 tokens (from 512/1024)
- Add repetition controls (no_repeat_ngram_size=3, repetition_penalty=1.05)
- Add chunking helper function for very long inputs
- Add defensive checks for torch availability in tests

app/services/hf_streaming_summarizer.py CHANGED
@@ -22,6 +22,32 @@ except ImportError:
22
  logger.warning("Transformers library not available. V2 endpoints will be disabled.")
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  class HFStreamingSummarizer:
26
  """Service for streaming text summarization using HuggingFace's lower-level API."""
27
 
@@ -169,17 +195,26 @@ class HFStreamingSummarizer:
169
  logger.info(f"Processing text of {text_length} chars with HuggingFace model: {settings.hf_model_id}")
170
 
171
  try:
172
- # Use provided parameters or defaults
173
- max_new_tokens = max_new_tokens or settings.hf_max_new_tokens
 
 
174
  temperature = temperature or settings.hf_temperature
175
  top_p = top_p or settings.hf_top_p
176
 
 
 
 
 
 
 
 
177
  # Build tokenized inputs (normalize return types across tokenizers)
178
  if "t5" in settings.hf_model_id.lower():
179
  full_prompt = f"summarize: {text}"
180
- inputs_raw = self.tokenizer(full_prompt, return_tensors="pt", max_length=512, truncation=True)
181
  elif "bart" in settings.hf_model_id.lower():
182
- inputs_raw = self.tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
183
  else:
184
  messages = [
185
  {"role": "system", "content": prompt},
@@ -188,10 +223,7 @@ class HFStreamingSummarizer:
188
 
189
  if hasattr(self.tokenizer, "apply_chat_template") and self.tokenizer.chat_template:
190
  inputs_raw = self.tokenizer.apply_chat_template(
191
- messages,
192
- tokenize=True,
193
- add_generation_prompt=True,
194
- return_tensors="pt"
195
  )
196
  else:
197
  full_prompt = f"{prompt}\n\n{text}"
@@ -211,14 +243,15 @@ class HFStreamingSummarizer:
211
 
212
  # Ensure attention_mask only if missing AND input_ids is a Tensor
213
  if "attention_mask" not in inputs and "input_ids" in inputs:
214
- if isinstance(inputs["input_ids"], torch.Tensor):
 
215
  inputs["attention_mask"] = torch.ones_like(inputs["input_ids"])
216
 
217
  # --- HARDEN: force singleton batch across all tensor fields ---
218
  def _to_singleton_batch(d):
219
  out = {}
220
  for k, v in d.items():
221
- if isinstance(v, torch.Tensor):
222
  if v.dim() == 1: # [seq] -> [1, seq]
223
  out[k] = v.unsqueeze(0)
224
  elif v.dim() >= 2:
@@ -233,8 +266,8 @@ class HFStreamingSummarizer:
233
 
234
  # Final assert: crash early with clear log if still batched
235
  _iid = inputs.get("input_ids", None)
236
- if isinstance(_iid, torch.Tensor) and _iid.dim() >= 2 and _iid.size(0) != 1:
237
- _shapes = {k: tuple(v.shape) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
238
  logger.error(f"Input still batched after normalization: shapes={_shapes}")
239
  raise ValueError("SingletonBatchEnforceFailed: input_ids batch dimension != 1")
240
 
@@ -286,6 +319,13 @@ class HFStreamingSummarizer:
286
  gen_kwargs["num_return_sequences"] = 1
287
  gen_kwargs["num_beams"] = 1
288
  gen_kwargs["num_beam_groups"] = 1
 
 
 
 
 
 
 
289
  # Extra safety: remove any stray args that imply multiple sequences
290
  for k in ("num_beam_groups", "num_beams", "num_return_sequences"):
291
  # Reassert values in case something upstream re-injected them
 
22
  logger.warning("Transformers library not available. V2 endpoints will be disabled.")
23
 
24
 
25
+ def _split_into_chunks(s: str, chunk_chars: int = 5000, overlap: int = 400) -> list[str]:
26
+ """
27
+ Split text into overlapping chunks to handle very long inputs.
28
+
29
+ Args:
30
+ s: Input text to split
31
+ chunk_chars: Target characters per chunk
32
+ overlap: Overlap between chunks in characters
33
+
34
+ Returns:
35
+ List of text chunks
36
+ """
37
+ chunks = []
38
+ i = 0
39
+ n = len(s)
40
+ while i < n:
41
+ j = min(i + chunk_chars, n)
42
+ chunks.append(s[i:j])
43
+ if j >= n:
44
+ break
45
+ i = j - overlap
46
+ if i < 0:
47
+ i = 0
48
+ return chunks
49
+
50
+
51
  class HFStreamingSummarizer:
52
  """Service for streaming text summarization using HuggingFace's lower-level API."""
53
 
 
195
  logger.info(f"Processing text of {text_length} chars with HuggingFace model: {settings.hf_model_id}")
196
 
197
  try:
198
+ # Use provided parameters or sensible defaults
199
+ # Aim for ~200–400 tokens summary by default.
200
+ # If settings.hf_max_new_tokens is small, override with 256.
201
+ max_new_tokens = max_new_tokens or max(getattr(settings, "hf_max_new_tokens", 0) or 0, 256)
202
  temperature = temperature or settings.hf_temperature
203
  top_p = top_p or settings.hf_top_p
204
 
205
+ # Determine a generous encoder max length (respect tokenizer.model_max_length)
206
+ model_max = getattr(self.tokenizer, "model_max_length", 1024)
207
+ # Handle case where model_max_length might be None, 0, or not a valid int
208
+ if not isinstance(model_max, int) or model_max <= 0:
209
+ model_max = 1024
210
+ enc_max_len = min(model_max, 2048) # cap to 2k to avoid OOM on small Spaces
211
+
212
  # Build tokenized inputs (normalize return types across tokenizers)
213
  if "t5" in settings.hf_model_id.lower():
214
  full_prompt = f"summarize: {text}"
215
+ inputs_raw = self.tokenizer(full_prompt, return_tensors="pt", max_length=enc_max_len, truncation=True)
216
  elif "bart" in settings.hf_model_id.lower():
217
+ inputs_raw = self.tokenizer(text, return_tensors="pt", max_length=enc_max_len, truncation=True)
218
  else:
219
  messages = [
220
  {"role": "system", "content": prompt},
 
223
 
224
  if hasattr(self.tokenizer, "apply_chat_template") and self.tokenizer.chat_template:
225
  inputs_raw = self.tokenizer.apply_chat_template(
226
+ messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
 
 
 
227
  )
228
  else:
229
  full_prompt = f"{prompt}\n\n{text}"
 
243
 
244
  # Ensure attention_mask only if missing AND input_ids is a Tensor
245
  if "attention_mask" not in inputs and "input_ids" in inputs:
246
+ # Check if torch is available and input is a tensor
247
+ if TRANSFORMERS_AVAILABLE and 'torch' in globals() and isinstance(inputs["input_ids"], torch.Tensor):
248
  inputs["attention_mask"] = torch.ones_like(inputs["input_ids"])
249
 
250
  # --- HARDEN: force singleton batch across all tensor fields ---
251
  def _to_singleton_batch(d):
252
  out = {}
253
  for k, v in d.items():
254
+ if TRANSFORMERS_AVAILABLE and 'torch' in globals() and isinstance(v, torch.Tensor):
255
  if v.dim() == 1: # [seq] -> [1, seq]
256
  out[k] = v.unsqueeze(0)
257
  elif v.dim() >= 2:
 
266
 
267
  # Final assert: crash early with clear log if still batched
268
  _iid = inputs.get("input_ids", None)
269
+ if TRANSFORMERS_AVAILABLE and 'torch' in globals() and isinstance(_iid, torch.Tensor) and _iid.dim() >= 2 and _iid.size(0) != 1:
270
+ _shapes = {k: tuple(v.shape) for k, v in inputs.items() if TRANSFORMERS_AVAILABLE and 'torch' in globals() and isinstance(v, torch.Tensor)}
271
  logger.error(f"Input still batched after normalization: shapes={_shapes}")
272
  raise ValueError("SingletonBatchEnforceFailed: input_ids batch dimension != 1")
273
 
 
319
  gen_kwargs["num_return_sequences"] = 1
320
  gen_kwargs["num_beams"] = 1
321
  gen_kwargs["num_beam_groups"] = 1
322
+ # Ensure we don't stop too early; set a floor and slightly favor longer generations
323
+ gen_kwargs["min_new_tokens"] = max(96, min(192, max_new_tokens // 2)) # floor ~100–192
324
+ # length_penalty > 1.0 encourages longer outputs on encoder-decoder models
325
+ gen_kwargs["length_penalty"] = 1.1
326
+ # Reduce premature EOS in some checkpoints (optional)
327
+ gen_kwargs["no_repeat_ngram_size"] = 3
328
+ gen_kwargs["repetition_penalty"] = 1.05
329
  # Extra safety: remove any stray args that imply multiple sequences
330
  for k in ("num_beam_groups", "num_beams", "num_return_sequences"):
331
  # Reassert values in case something upstream re-injected them