the-carnage commited on
Commit
216c20d
Β·
1 Parent(s): bd0a950

Fix repetition: return original text for very short inputs, use greedy decoding, enforce min < max length

Browse files
Files changed (1) hide show
  1. app.py +33 -4
app.py CHANGED
@@ -42,10 +42,39 @@ def summarize_text(text, min_Len, max_Len):
42
  input_text = "summarize: " + text[:4000]
43
  inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
44
  input_token_count = inputs.input_ids.shape[1]
45
- # Cap lengths to avoid repetition when input is shorter than requested output
46
- effective_max = min(max_Len, max(input_token_count - 1, 10))
47
- effective_min = min(min_Len, max(effective_max // 2, 5))
48
- summary_ids = model.generate(inputs.input_ids, max_length=effective_max, min_length=effective_min, length_penalty=2.0, num_beams=4, early_stopping=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
50
 
51
  tab1, tab2, tab3 = st.tabs(["πŸ“ Text", "πŸ–ΌοΈ Image", "πŸ“„ PDF"])
 
42
  input_text = "summarize: " + text[:4000]
43
  inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
44
  input_token_count = inputs.input_ids.shape[1]
45
+
46
+ # For very short inputs, just return the original text
47
+ if input_token_count < 15:
48
+ return text.strip()
49
+
50
+ # Cap lengths to avoid repetition - max should not exceed input length
51
+ effective_max = min(max_Len, max(int(input_token_count * 0.6), 20))
52
+ effective_min = 5 # Minimum 5 tokens for a summary
53
+
54
+ # Ensure min < max
55
+ if effective_min >= effective_max:
56
+ effective_min = max(1, effective_max - 5)
57
+
58
+ # Use simpler generation for short inputs
59
+ if input_token_count < 50:
60
+ summary_ids = model.generate(
61
+ inputs.input_ids,
62
+ max_length=effective_max,
63
+ min_length=effective_min,
64
+ do_sample=False, # Deterministic
65
+ num_beams=1, # No beam search for short inputs
66
+ early_stopping=True
67
+ )
68
+ else:
69
+ summary_ids = model.generate(
70
+ inputs.input_ids,
71
+ max_length=effective_max,
72
+ min_length=effective_min,
73
+ length_penalty=2.0,
74
+ num_beams=4,
75
+ early_stopping=True
76
+ )
77
+
78
  return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
79
 
80
  tab1, tab2, tab3 = st.tabs(["πŸ“ Text", "πŸ–ΌοΈ Image", "πŸ“„ PDF"])