Azidan committed on
Commit
430cffe
·
verified ·
1 Parent(s): fdc9079

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -32
app.py CHANGED
@@ -8,19 +8,16 @@ import torch
8
  # =========================
9
  # Model setup (CPU-safe, Multi-language)
10
  # =========================
11
- # Use different models for English and Arabic
12
- EN_SUMMARIZER_MODEL = "sshleifer/distilbart-cnn-12-6" # English summarization
13
  AR_SUMMARIZER_MODEL = "csebuetnlp/mT5_multilingual_XLSum" # Multilingual (includes Arabic)
14
  QA_MODEL = "google/flan-t5-small" # Question generation
15
 
16
  print("Loading models... This may take a minute on first run.")
17
 
18
- # English summarizer
19
- en_summarizer = pipeline(
20
- "summarization",
21
- model=EN_SUMMARIZER_MODEL,
22
- device=-1 # CPU only
23
- )
24
 
25
  # Multilingual summarizer (for Arabic and other languages)
26
  ar_tokenizer = AutoTokenizer.from_pretrained(AR_SUMMARIZER_MODEL)
@@ -33,7 +30,7 @@ question_generator = pipeline(
33
  device=-1 # CPU only
34
  )
35
 
36
- CHUNK_SIZE = 512 # Conservative chunk size
37
 
38
  # =========================
39
  # Language Detection
@@ -68,10 +65,7 @@ def clean_text(text: str) -> str:
68
  def chunk_text(text: str, language: str):
69
  """Token-aware chunking to avoid model overflow."""
70
  # Use appropriate tokenizer based on language
71
- if language == "ar_AR":
72
- tokenizer = ar_tokenizer
73
- else:
74
- tokenizer = AutoTokenizer.from_pretrained(EN_SUMMARIZER_MODEL)
75
 
76
  tokens = tokenizer.encode(text, add_special_tokens=False)
77
  chunks = []
@@ -180,17 +174,17 @@ def summarize_long_text(text: str, summary_length: str, language: str, progress=
180
  if not text or len(text.strip()) == 0:
181
  return "No text provided." if language == "en_XX" else "لم يتم تقديم نص."
182
 
183
- # Length mapping
184
  length_map = {
185
- "Short (25%)": {"max": 150, "min": 40},
186
- "Medium (50%)": {"max": 250, "min": 80},
187
- "Long (75%)": {"max": 400, "min": 120},
188
- "قصير (25%)": {"max": 150, "min": 40},
189
- "متوسط (50%)": {"max": 250, "min": 80},
190
- "طويل (75%)": {"max": 400, "min": 120}
191
  }
192
 
193
- length_params = length_map.get(summary_length, {"max": 250, "min": 80})
194
 
195
  progress(0, desc="Extracting headings...")
196
  headings_section = extract_possible_headings(text)
@@ -205,7 +199,7 @@ def summarize_long_text(text: str, summary_length: str, language: str, progress=
205
  chunk = chunks[i]
206
  try:
207
  if language == "ar_AR":
208
- # Use mT5 for Arabic
209
  inputs = ar_tokenizer(chunk, return_tensors="pt", max_length=512, truncation=True)
210
  summary_ids = ar_model.generate(
211
  inputs["input_ids"],
@@ -217,26 +211,34 @@ def summarize_long_text(text: str, summary_length: str, language: str, progress=
217
  )
218
  summary = ar_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
219
  else:
220
- # Use distilbart for English
221
- summary = en_summarizer(
222
- chunk,
 
 
223
  max_length=length_params["max"],
224
  min_length=length_params["min"],
225
- do_sample=False
226
- )[0]["summary_text"]
 
 
227
 
228
  cleaned = clean_text(summary)
229
- chunk_label = f"**Chunk {i+1}:**" if language == "en_XX" else f"**الجزء {i+1}:**"
230
- summaries.append(f"{chunk_label} {cleaned}")
 
231
  except Exception as e:
232
  print(f"Error in chunk {i}: {str(e)}")
233
- pass # skip problematic chunks
234
 
235
  # Format summaries
236
  header = "### ๐Ÿ“ Detailed Summary\n\n" if language == "en_XX" else "### ๐Ÿ“ ู…ู„ุฎุต ุชูุตูŠู„ูŠ\n\n"
237
  summary_md = header
238
- for s in summaries:
239
- summary_md += f"- {s}\n"
 
 
 
240
 
241
  progress(0.8, desc="Generating questions...")
242
  questions = generate_questions(summary_md, language)
 
8
  # =========================
9
  # Model setup (CPU-safe, Multi-language)
10
  # =========================
11
+ # Use T5-based models that support text2text-generation
12
+ EN_SUMMARIZER_MODEL = "google/flan-t5-base" # English - works with text2text
13
  AR_SUMMARIZER_MODEL = "csebuetnlp/mT5_multilingual_XLSum" # Multilingual (includes Arabic)
14
  QA_MODEL = "google/flan-t5-small" # Question generation
15
 
16
  print("Loading models... This may take a minute on first run.")
17
 
18
+ # English summarizer using text2text-generation
19
+ en_tokenizer = AutoTokenizer.from_pretrained(EN_SUMMARIZER_MODEL)
20
+ en_model = AutoModelForSeq2SeqLM.from_pretrained(EN_SUMMARIZER_MODEL)
 
 
 
21
 
22
  # Multilingual summarizer (for Arabic and other languages)
23
  ar_tokenizer = AutoTokenizer.from_pretrained(AR_SUMMARIZER_MODEL)
 
30
  device=-1 # CPU only
31
  )
32
 
33
+ CHUNK_SIZE = 400 # Conservative chunk size for T5 models
34
 
35
  # =========================
36
  # Language Detection
 
65
  def chunk_text(text: str, language: str):
66
  """Token-aware chunking to avoid model overflow."""
67
  # Use appropriate tokenizer based on language
68
+ tokenizer = ar_tokenizer if language == "ar_AR" else en_tokenizer
 
 
 
69
 
70
  tokens = tokenizer.encode(text, add_special_tokens=False)
71
  chunks = []
 
174
  if not text or len(text.strip()) == 0:
175
  return "No text provided." if language == "en_XX" else "لم يتم تقديم نص."
176
 
177
+ # Length mapping (for T5 models, these are approximate)
178
  length_map = {
179
+ "Short (25%)": {"max": 128, "min": 30},
180
+ "Medium (50%)": {"max": 256, "min": 60},
181
+ "Long (75%)": {"max": 400, "min": 100},
182
+ "قصير (25%)": {"max": 128, "min": 30},
183
+ "متوسط (50%)": {"max": 256, "min": 60},
184
+ "طويل (75%)": {"max": 400, "min": 100}
185
  }
186
 
187
+ length_params = length_map.get(summary_length, {"max": 256, "min": 60})
188
 
189
  progress(0, desc="Extracting headings...")
190
  headings_section = extract_possible_headings(text)
 
199
  chunk = chunks[i]
200
  try:
201
  if language == "ar_AR":
202
+ # Use mT5 for Arabic with direct model inference
203
  inputs = ar_tokenizer(chunk, return_tensors="pt", max_length=512, truncation=True)
204
  summary_ids = ar_model.generate(
205
  inputs["input_ids"],
 
211
  )
212
  summary = ar_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
213
  else:
214
+ # Use FLAN-T5 for English with summarization prompt
215
+ prompt = f"Summarize the following text in detail:\n\n{chunk}"
216
+ inputs = en_tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
217
+ summary_ids = en_model.generate(
218
+ inputs["input_ids"],
219
  max_length=length_params["max"],
220
  min_length=length_params["min"],
221
+ num_beams=4,
222
+ early_stopping=True
223
+ )
224
+ summary = en_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
225
 
226
  cleaned = clean_text(summary)
227
+ if cleaned: # Only add non-empty summaries
228
+ chunk_label = f"**Chunk {i+1}:**" if language == "en_XX" else f"**الجزء {i+1}:**"
229
+ summaries.append(f"{chunk_label} {cleaned}")
230
  except Exception as e:
231
  print(f"Error in chunk {i}: {str(e)}")
232
+ continue # skip problematic chunks
233
 
234
  # Format summaries
235
  header = "### ๐Ÿ“ Detailed Summary\n\n" if language == "en_XX" else "### ๐Ÿ“ ู…ู„ุฎุต ุชูุตูŠู„ูŠ\n\n"
236
  summary_md = header
237
+ if summaries:
238
+ for s in summaries:
239
+ summary_md += f"- {s}\n"
240
+ else:
241
+ summary_md += "Unable to generate summary. Please try with different text.\n"
242
 
243
  progress(0.8, desc="Generating questions...")
244
  questions = generate_questions(summary_md, language)