Spaces:
Running
Running
ming
committed on
Commit
·
52f6c42
1
Parent(s):
de8b12b
Enhance HF V2 summaries: increase token limits, improve length parameters
Browse files
- Increase default max_new_tokens from 64-128 to 256 minimum
- Add min_new_tokens floor (96-192) to prevent premature stopping
- Set length_penalty to 1.1 to encourage longer outputs
- Expand encoder max length to 2048 tokens (from 512/1024)
- Add repetition controls (no_repeat_ngram_size=3, repetition_penalty=1.05)
- Add chunking helper function for very long inputs
- Add defensive checks for torch availability in tests
app/services/hf_streaming_summarizer.py
CHANGED
|
@@ -22,6 +22,32 @@ except ImportError:
|
|
| 22 |
logger.warning("Transformers library not available. V2 endpoints will be disabled.")
|
| 23 |
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
class HFStreamingSummarizer:
|
| 26 |
"""Service for streaming text summarization using HuggingFace's lower-level API."""
|
| 27 |
|
|
@@ -169,17 +195,26 @@ class HFStreamingSummarizer:
|
|
| 169 |
logger.info(f"Processing text of {text_length} chars with HuggingFace model: {settings.hf_model_id}")
|
| 170 |
|
| 171 |
try:
|
| 172 |
-
# Use provided parameters or defaults
|
| 173 |
-
|
|
|
|
|
|
|
| 174 |
temperature = temperature or settings.hf_temperature
|
| 175 |
top_p = top_p or settings.hf_top_p
|
| 176 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
# Build tokenized inputs (normalize return types across tokenizers)
|
| 178 |
if "t5" in settings.hf_model_id.lower():
|
| 179 |
full_prompt = f"summarize: {text}"
|
| 180 |
-
inputs_raw = self.tokenizer(full_prompt, return_tensors="pt", max_length=
|
| 181 |
elif "bart" in settings.hf_model_id.lower():
|
| 182 |
-
inputs_raw = self.tokenizer(text, return_tensors="pt", max_length=
|
| 183 |
else:
|
| 184 |
messages = [
|
| 185 |
{"role": "system", "content": prompt},
|
|
@@ -188,10 +223,7 @@ class HFStreamingSummarizer:
|
|
| 188 |
|
| 189 |
if hasattr(self.tokenizer, "apply_chat_template") and self.tokenizer.chat_template:
|
| 190 |
inputs_raw = self.tokenizer.apply_chat_template(
|
| 191 |
-
messages,
|
| 192 |
-
tokenize=True,
|
| 193 |
-
add_generation_prompt=True,
|
| 194 |
-
return_tensors="pt"
|
| 195 |
)
|
| 196 |
else:
|
| 197 |
full_prompt = f"{prompt}\n\n{text}"
|
|
@@ -211,14 +243,15 @@ class HFStreamingSummarizer:
|
|
| 211 |
|
| 212 |
# Ensure attention_mask only if missing AND input_ids is a Tensor
|
| 213 |
if "attention_mask" not in inputs and "input_ids" in inputs:
|
| 214 |
-
if
|
|
|
|
| 215 |
inputs["attention_mask"] = torch.ones_like(inputs["input_ids"])
|
| 216 |
|
| 217 |
# --- HARDEN: force singleton batch across all tensor fields ---
|
| 218 |
def _to_singleton_batch(d):
|
| 219 |
out = {}
|
| 220 |
for k, v in d.items():
|
| 221 |
-
if isinstance(v, torch.Tensor):
|
| 222 |
if v.dim() == 1: # [seq] -> [1, seq]
|
| 223 |
out[k] = v.unsqueeze(0)
|
| 224 |
elif v.dim() >= 2:
|
|
@@ -233,8 +266,8 @@ class HFStreamingSummarizer:
|
|
| 233 |
|
| 234 |
# Final assert: crash early with clear log if still batched
|
| 235 |
_iid = inputs.get("input_ids", None)
|
| 236 |
-
if isinstance(_iid, torch.Tensor) and _iid.dim() >= 2 and _iid.size(0) != 1:
|
| 237 |
-
_shapes = {k: tuple(v.shape) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
|
| 238 |
logger.error(f"Input still batched after normalization: shapes={_shapes}")
|
| 239 |
raise ValueError("SingletonBatchEnforceFailed: input_ids batch dimension != 1")
|
| 240 |
|
|
@@ -286,6 +319,13 @@ class HFStreamingSummarizer:
|
|
| 286 |
gen_kwargs["num_return_sequences"] = 1
|
| 287 |
gen_kwargs["num_beams"] = 1
|
| 288 |
gen_kwargs["num_beam_groups"] = 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
# Extra safety: remove any stray args that imply multiple sequences
|
| 290 |
for k in ("num_beam_groups", "num_beams", "num_return_sequences"):
|
| 291 |
# Reassert values in case something upstream re-injected them
|
|
|
|
| 22 |
logger.warning("Transformers library not available. V2 endpoints will be disabled.")
|
| 23 |
|
| 24 |
|
| 25 |
+
def _split_into_chunks(s: str, chunk_chars: int = 5000, overlap: int = 400) -> list[str]:
|
| 26 |
+
"""
|
| 27 |
+
Split text into overlapping chunks to handle very long inputs.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
s: Input text to split
|
| 31 |
+
chunk_chars: Target characters per chunk
|
| 32 |
+
overlap: Overlap between chunks in characters
|
| 33 |
+
|
| 34 |
+
Returns:
|
| 35 |
+
List of text chunks
|
| 36 |
+
"""
|
| 37 |
+
chunks = []
|
| 38 |
+
i = 0
|
| 39 |
+
n = len(s)
|
| 40 |
+
while i < n:
|
| 41 |
+
j = min(i + chunk_chars, n)
|
| 42 |
+
chunks.append(s[i:j])
|
| 43 |
+
if j >= n:
|
| 44 |
+
break
|
| 45 |
+
i = j - overlap
|
| 46 |
+
if i < 0:
|
| 47 |
+
i = 0
|
| 48 |
+
return chunks
|
| 49 |
+
|
| 50 |
+
|
| 51 |
class HFStreamingSummarizer:
|
| 52 |
"""Service for streaming text summarization using HuggingFace's lower-level API."""
|
| 53 |
|
|
|
|
| 195 |
logger.info(f"Processing text of {text_length} chars with HuggingFace model: {settings.hf_model_id}")
|
| 196 |
|
| 197 |
try:
|
| 198 |
+
# Use provided parameters or sensible defaults
|
| 199 |
+
# Aim for ~200–400 tokens summary by default.
|
| 200 |
+
# If settings.hf_max_new_tokens is small, override with 256.
|
| 201 |
+
max_new_tokens = max_new_tokens or max(getattr(settings, "hf_max_new_tokens", 0) or 0, 256)
|
| 202 |
temperature = temperature or settings.hf_temperature
|
| 203 |
top_p = top_p or settings.hf_top_p
|
| 204 |
|
| 205 |
+
# Determine a generous encoder max length (respect tokenizer.model_max_length)
|
| 206 |
+
model_max = getattr(self.tokenizer, "model_max_length", 1024)
|
| 207 |
+
# Handle case where model_max_length might be None, 0, or not a valid int
|
| 208 |
+
if not isinstance(model_max, int) or model_max <= 0:
|
| 209 |
+
model_max = 1024
|
| 210 |
+
enc_max_len = min(model_max, 2048) # cap to 2k to avoid OOM on small Spaces
|
| 211 |
+
|
| 212 |
# Build tokenized inputs (normalize return types across tokenizers)
|
| 213 |
if "t5" in settings.hf_model_id.lower():
|
| 214 |
full_prompt = f"summarize: {text}"
|
| 215 |
+
inputs_raw = self.tokenizer(full_prompt, return_tensors="pt", max_length=enc_max_len, truncation=True)
|
| 216 |
elif "bart" in settings.hf_model_id.lower():
|
| 217 |
+
inputs_raw = self.tokenizer(text, return_tensors="pt", max_length=enc_max_len, truncation=True)
|
| 218 |
else:
|
| 219 |
messages = [
|
| 220 |
{"role": "system", "content": prompt},
|
|
|
|
| 223 |
|
| 224 |
if hasattr(self.tokenizer, "apply_chat_template") and self.tokenizer.chat_template:
|
| 225 |
inputs_raw = self.tokenizer.apply_chat_template(
|
| 226 |
+
messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
|
|
|
|
|
|
|
|
|
|
| 227 |
)
|
| 228 |
else:
|
| 229 |
full_prompt = f"{prompt}\n\n{text}"
|
|
|
|
| 243 |
|
| 244 |
# Ensure attention_mask only if missing AND input_ids is a Tensor
|
| 245 |
if "attention_mask" not in inputs and "input_ids" in inputs:
|
| 246 |
+
# Check if torch is available and input is a tensor
|
| 247 |
+
if TRANSFORMERS_AVAILABLE and 'torch' in globals() and isinstance(inputs["input_ids"], torch.Tensor):
|
| 248 |
inputs["attention_mask"] = torch.ones_like(inputs["input_ids"])
|
| 249 |
|
| 250 |
# --- HARDEN: force singleton batch across all tensor fields ---
|
| 251 |
def _to_singleton_batch(d):
|
| 252 |
out = {}
|
| 253 |
for k, v in d.items():
|
| 254 |
+
if TRANSFORMERS_AVAILABLE and 'torch' in globals() and isinstance(v, torch.Tensor):
|
| 255 |
if v.dim() == 1: # [seq] -> [1, seq]
|
| 256 |
out[k] = v.unsqueeze(0)
|
| 257 |
elif v.dim() >= 2:
|
|
|
|
| 266 |
|
| 267 |
# Final assert: crash early with clear log if still batched
|
| 268 |
_iid = inputs.get("input_ids", None)
|
| 269 |
+
if TRANSFORMERS_AVAILABLE and 'torch' in globals() and isinstance(_iid, torch.Tensor) and _iid.dim() >= 2 and _iid.size(0) != 1:
|
| 270 |
+
_shapes = {k: tuple(v.shape) for k, v in inputs.items() if TRANSFORMERS_AVAILABLE and 'torch' in globals() and isinstance(v, torch.Tensor)}
|
| 271 |
logger.error(f"Input still batched after normalization: shapes={_shapes}")
|
| 272 |
raise ValueError("SingletonBatchEnforceFailed: input_ids batch dimension != 1")
|
| 273 |
|
|
|
|
| 319 |
gen_kwargs["num_return_sequences"] = 1
|
| 320 |
gen_kwargs["num_beams"] = 1
|
| 321 |
gen_kwargs["num_beam_groups"] = 1
|
| 322 |
+
# Ensure we don't stop too early; set a floor and slightly favor longer generations
|
| 323 |
+
gen_kwargs["min_new_tokens"] = max(96, min(192, max_new_tokens // 2)) # floor ~100–192
|
| 324 |
+
# length_penalty > 1.0 encourages longer outputs on encoder-decoder models
|
| 325 |
+
gen_kwargs["length_penalty"] = 1.1
|
| 326 |
+
# Reduce premature EOS in some checkpoints (optional)
|
| 327 |
+
gen_kwargs["no_repeat_ngram_size"] = 3
|
| 328 |
+
gen_kwargs["repetition_penalty"] = 1.05
|
| 329 |
# Extra safety: remove any stray args that imply multiple sequences
|
| 330 |
for k in ("num_beam_groups", "num_beams", "num_return_sequences"):
|
| 331 |
# Reassert values in case something upstream re-injected them
|