Azidan committed on
Commit
fdc9079
·
verified ·
1 Parent(s): 55b22ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -26
app.py CHANGED
@@ -8,22 +8,24 @@ import torch
8
  # =========================
9
  # Model setup (CPU-safe, Multi-language)
10
  # =========================
11
- # Use mBART for multilingual support (English + Arabic)
12
- SUMMARIZER_MODEL = "facebook/mbart-large-50-many-to-many-mmt"
13
- QA_MODEL = "google/flan-t5-base" # Better for question generation
 
14
 
15
  print("Loading models... This may take a minute on first run.")
16
 
17
- # Summarizer with mBART (supports Arabic)
18
- summarizer_tokenizer = AutoTokenizer.from_pretrained(SUMMARIZER_MODEL)
19
- summarizer_model = AutoModelForSeq2SeqLM.from_pretrained(SUMMARIZER_MODEL)
20
- summarizer = pipeline(
21
  "summarization",
22
- model=summarizer_model,
23
- tokenizer=summarizer_tokenizer,
24
  device=-1 # CPU only
25
  )
26
 
 
 
 
 
27
  # Question generator
28
  question_generator = pipeline(
29
  "text2text-generation",
@@ -31,7 +33,7 @@ question_generator = pipeline(
31
  device=-1 # CPU only
32
  )
33
 
34
- CHUNK_SIZE = 512 # Conservative for mBART
35
 
36
  # =========================
37
  # Language Detection
@@ -63,8 +65,14 @@ def clean_text(text: str) -> str:
63
  result.append(s.strip())
64
  return " ".join(result)
65
 
66
- def chunk_text(text: str, tokenizer):
67
  """Token-aware chunking to avoid model overflow."""
 
 
 
 
 
 
68
  tokens = tokenizer.encode(text, add_special_tokens=False)
69
  chunks = []
70
  for i in range(0, len(tokens), CHUNK_SIZE):
@@ -188,28 +196,34 @@ def summarize_long_text(text: str, summary_length: str, language: str, progress=
188
  headings_section = extract_possible_headings(text)
189
 
190
  progress(0.1, desc="Chunking text...")
191
- chunks = chunk_text(text, summarizer_tokenizer)
192
 
193
  summaries = []
194
  progress(0.2, desc="Summarizing chunks...")
195
 
196
- # Set language tokens for mBART
197
- src_lang = language
198
- tgt_lang = language
199
-
200
  for i in progress.tqdm(range(len(chunks))):
201
  chunk = chunks[i]
202
  try:
203
- # For mBART, we need to set source and target language
204
- summarizer_tokenizer.src_lang = src_lang
205
-
206
- summary = summarizer(
207
- chunk,
208
- max_length=length_params["max"],
209
- min_length=length_params["min"],
210
- do_sample=False,
211
- forced_bos_token_id=summarizer_tokenizer.lang_code_to_id[tgt_lang]
212
- )[0]["summary_text"]
 
 
 
 
 
 
 
 
 
 
213
 
214
  cleaned = clean_text(summary)
215
  chunk_label = f"**Chunk {i+1}:**" if language == "en_XX" else f"**الجزء {i+1}:**"
 
8
  # =========================
9
  # Model setup (CPU-safe, Multi-language)
10
  # =========================
11
+ # Use different models for English and Arabic
12
+ EN_SUMMARIZER_MODEL = "sshleifer/distilbart-cnn-12-6" # English summarization
13
+ AR_SUMMARIZER_MODEL = "csebuetnlp/mT5_multilingual_XLSum" # Multilingual (includes Arabic)
14
+ QA_MODEL = "google/flan-t5-small" # Question generation
15
 
16
  print("Loading models... This may take a minute on first run.")
17
 
18
+ # English summarizer
19
+ en_summarizer = pipeline(
 
 
20
  "summarization",
21
+ model=EN_SUMMARIZER_MODEL,
 
22
  device=-1 # CPU only
23
  )
24
 
25
+ # Multilingual summarizer (for Arabic and other languages)
26
+ ar_tokenizer = AutoTokenizer.from_pretrained(AR_SUMMARIZER_MODEL)
27
+ ar_model = AutoModelForSeq2SeqLM.from_pretrained(AR_SUMMARIZER_MODEL)
28
+
29
  # Question generator
30
  question_generator = pipeline(
31
  "text2text-generation",
 
33
  device=-1 # CPU only
34
  )
35
 
36
+ CHUNK_SIZE = 512 # Conservative chunk size
37
 
38
  # =========================
39
  # Language Detection
 
65
  result.append(s.strip())
66
  return " ".join(result)
67
 
68
+ def chunk_text(text: str, language: str):
69
  """Token-aware chunking to avoid model overflow."""
70
+ # Use appropriate tokenizer based on language
71
+ if language == "ar_AR":
72
+ tokenizer = ar_tokenizer
73
+ else:
74
+ tokenizer = AutoTokenizer.from_pretrained(EN_SUMMARIZER_MODEL)
75
+
76
  tokens = tokenizer.encode(text, add_special_tokens=False)
77
  chunks = []
78
  for i in range(0, len(tokens), CHUNK_SIZE):
 
196
  headings_section = extract_possible_headings(text)
197
 
198
  progress(0.1, desc="Chunking text...")
199
+ chunks = chunk_text(text, language)
200
 
201
  summaries = []
202
  progress(0.2, desc="Summarizing chunks...")
203
 
 
 
 
 
204
  for i in progress.tqdm(range(len(chunks))):
205
  chunk = chunks[i]
206
  try:
207
+ if language == "ar_AR":
208
+ # Use mT5 for Arabic
209
+ inputs = ar_tokenizer(chunk, return_tensors="pt", max_length=512, truncation=True)
210
+ summary_ids = ar_model.generate(
211
+ inputs["input_ids"],
212
+ max_length=length_params["max"],
213
+ min_length=length_params["min"],
214
+ length_penalty=2.0,
215
+ num_beams=4,
216
+ early_stopping=True
217
+ )
218
+ summary = ar_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
219
+ else:
220
+ # Use distilbart for English
221
+ summary = en_summarizer(
222
+ chunk,
223
+ max_length=length_params["max"],
224
+ min_length=length_params["min"],
225
+ do_sample=False
226
+ )[0]["summary_text"]
227
 
228
  cleaned = clean_text(summary)
229
  chunk_label = f"**Chunk {i+1}:**" if language == "en_XX" else f"**الجزء {i+1}:**"