Update app.py
app.py
CHANGED
@@ -8,19 +8,16 @@ import torch
 # =========================
 # Model setup (CPU-safe, Multi-language)
 # =========================
-# Use
-EN_SUMMARIZER_MODEL = "
+# Use T5-based models that support text2text-generation
+EN_SUMMARIZER_MODEL = "google/flan-t5-base"  # English - works with text2text
 AR_SUMMARIZER_MODEL = "csebuetnlp/mT5_multilingual_XLSum"  # Multilingual (includes Arabic)
 QA_MODEL = "google/flan-t5-small"  # Question generation
 
 print("Loading models... This may take a minute on first run.")
 
-# English summarizer
-
-
-    model=EN_SUMMARIZER_MODEL,
-    device=-1  # CPU only
-)
+# English summarizer using text2text-generation
+en_tokenizer = AutoTokenizer.from_pretrained(EN_SUMMARIZER_MODEL)
+en_model = AutoModelForSeq2SeqLM.from_pretrained(EN_SUMMARIZER_MODEL)
 
 # Multilingual summarizer (for Arabic and other languages)
 ar_tokenizer = AutoTokenizer.from_pretrained(AR_SUMMARIZER_MODEL)
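As a quick sanity check, the new English path can be exercised on its own before wiring it into the app. This is a minimal sketch, assuming transformers and torch are installed and that app.py already imports AutoTokenizer and AutoModelForSeq2SeqLM from transformers; the variable names and sample text are placeholders, and the call pattern mirrors the one summarize_long_text uses further down in this diff.

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tok = AutoTokenizer.from_pretrained("google/flan-t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")

# Prompt-style summarization, same as the else-branch added later in this diff
text = "FLAN-T5 is an instruction-tuned T5 model that handles text2text-generation tasks."
inputs = tok("Summarize the following text in detail:\n\n" + text,
             return_tensors="pt", max_length=512, truncation=True)
ids = model.generate(inputs["input_ids"], max_length=64, min_length=5,
                     num_beams=4, early_stopping=True)
print(tok.decode(ids[0], skip_special_tokens=True))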
@@ -33,7 +30,7 @@ question_generator = pipeline(
     device=-1  # CPU only
 )
 
-CHUNK_SIZE =
+CHUNK_SIZE = 400  # Conservative chunk size for T5 models
 
 # =========================
 # Language Detection
@@ -68,10 +65,7 @@ def clean_text(text: str) -> str:
 def chunk_text(text: str, language: str):
     """Token-aware chunking to avoid model overflow."""
     # Use appropriate tokenizer based on language
-    if language == "ar_AR"
-        tokenizer = ar_tokenizer
-    else:
-        tokenizer = AutoTokenizer.from_pretrained(EN_SUMMARIZER_MODEL)
+    tokenizer = ar_tokenizer if language == "ar_AR" else en_tokenizer
 
     tokens = tokenizer.encode(text, add_special_tokens=False)
     chunks = []
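The rest of chunk_text is unchanged and therefore not shown in this hunk. For context, token-aware chunking against a fixed window such as CHUNK_SIZE = 400 typically looks like the following sketch; this is an illustration of the technique, not the app's exact implementation, and chunk_by_tokens is a hypothetical name.

def chunk_by_tokens(text: str, tokenizer, chunk_size: int = 400):
    # Encode once, slice the token ids into fixed-size windows,
    # then decode each window back to text for the summarizer.
    tokens = tokenizer.encode(text, add_special_tokens=False)
    chunks = []
    for start in range(0, len(tokens), chunk_size):
        window = tokens[start:start + chunk_size]
        chunks.append(tokenizer.decode(window, skip_special_tokens=True))
    return chunks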
@@ -180,17 +174,17 @@ summarize_long_text(text: str, summary_length: str, language: str, progress=
     if not text or len(text.strip()) == 0:
         return "No text provided." if language == "en_XX" else "لم يتم تقديم نص."
 
-    # Length mapping
+    # Length mapping (for T5 models, these are approximate)
     length_map = {
-        "Short (25%)": {"max":
-        "Medium (50%)": {"max":
-        "Long (75%)": {"max": 400, "min":
-        "قصير (25%)": {"max":
-        "متوسط (50%)": {"max":
-        "طويل (75%)": {"max": 400, "min":
+        "Short (25%)": {"max": 128, "min": 30},
+        "Medium (50%)": {"max": 256, "min": 60},
+        "Long (75%)": {"max": 400, "min": 100},
+        "قصير (25%)": {"max": 128, "min": 30},
+        "متوسط (50%)": {"max": 256, "min": 60},
+        "طويل (75%)": {"max": 400, "min": 100}
     }
 
-    length_params = length_map.get(summary_length, {"max":
+    length_params = length_map.get(summary_length, {"max": 256, "min": 60})
 
     progress(0, desc="Extracting headings...")
     headings_section = extract_possible_headings(text)
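For example, the lookup resolves a UI label to generation bounds and falls back to the medium setting for any unrecognized label:

length_map.get("Long (75%)", {"max": 256, "min": 60})  # -> {"max": 400, "min": 100}
length_map.get("Unknown", {"max": 256, "min": 60})      # -> falls back to {"max": 256, "min": 60}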
@@ -205,7 +199,7 @@ summarize_long_text(text: str, summary_length: str, language: str, progress=
         chunk = chunks[i]
         try:
             if language == "ar_AR":
-                # Use mT5 for Arabic
+                # Use mT5 for Arabic with direct model inference
                 inputs = ar_tokenizer(chunk, return_tensors="pt", max_length=512, truncation=True)
                 summary_ids = ar_model.generate(
                     inputs["input_ids"],
@@ -217,26 +211,34 @@ summarize_long_text(text: str, summary_length: str, language: str, progress=
                 )
                 summary = ar_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
             else:
-                # Use
-
-
+                # Use FLAN-T5 for English with summarization prompt
+                prompt = f"Summarize the following text in detail:\n\n{chunk}"
+                inputs = en_tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
+                summary_ids = en_model.generate(
+                    inputs["input_ids"],
                     max_length=length_params["max"],
                     min_length=length_params["min"],
-
-
+                    num_beams=4,
+                    early_stopping=True
+                )
+                summary = en_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
 
             cleaned = clean_text(summary)
-
-
+            if cleaned:  # Only add non-empty summaries
+                chunk_label = f"**Chunk {i+1}:**" if language == "en_XX" else f"**الجزء {i+1}:**"
+                summaries.append(f"{chunk_label} {cleaned}")
         except Exception as e:
             print(f"Error in chunk {i}: {str(e)}")
-
+            continue  # skip problematic chunks
 
     # Format summaries
     header = "### ๐ Detailed Summary\n\n" if language == "en_XX" else "### ๐ ملخص تفصيلي\n\n"
     summary_md = header
-
-
+    if summaries:
+        for s in summaries:
+            summary_md += f"- {s}\n"
+    else:
+        summary_md += "Unable to generate summary. Please try with different text.\n"
 
     progress(0.8, desc="Generating questions...")
     questions = generate_questions(summary_md, language)