Update app.py
app.py CHANGED
@@ -8,22 +8,24 @@ import torch
 # =========================
 # Model setup (CPU-safe, Multi-language)
 # =========================
-# Use
-
-
+# Use different models for English and Arabic
+EN_SUMMARIZER_MODEL = "sshleifer/distilbart-cnn-12-6" # English summarization
+AR_SUMMARIZER_MODEL = "csebuetnlp/mT5_multilingual_XLSum" # Multilingual (includes Arabic)
+QA_MODEL = "google/flan-t5-small" # Question generation
 
 print("Loading models... This may take a minute on first run.")
 
-#
-summarizer_tokenizer = AutoTokenizer.from_pretrained(SUMMARIZER_MODEL)
-summarizer_model = AutoModelForSeq2SeqLM.from_pretrained(SUMMARIZER_MODEL)
-summarizer = pipeline(
+# English summarizer
+en_summarizer = pipeline(
     "summarization",
-    model=summarizer_model,
-    tokenizer=summarizer_tokenizer,
+    model=EN_SUMMARIZER_MODEL,
     device=-1 # CPU only
 )
 
+# Multilingual summarizer (for Arabic and other languages)
+ar_tokenizer = AutoTokenizer.from_pretrained(AR_SUMMARIZER_MODEL)
+ar_model = AutoModelForSeq2SeqLM.from_pretrained(AR_SUMMARIZER_MODEL)
+
 # Question generator
 question_generator = pipeline(
     "text2text-generation",
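For context, a minimal sketch of what this two-model setup amounts to at load time. The checkpoint names are the ones introduced in this commit; the sentencepiece and protobuf packages are assumed to be installed for the mT5 tokenizer, and the smoke-test text is illustrative:

# Sketch: load both summarizers on CPU, as the new setup does.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

en_summarizer = pipeline(
    "summarization",
    model="sshleifer/distilbart-cnn-12-6",
    device=-1,  # CPU only
)

ar_tokenizer = AutoTokenizer.from_pretrained("csebuetnlp/mT5_multilingual_XLSum")
ar_model = AutoModelForSeq2SeqLM.from_pretrained("csebuetnlp/mT5_multilingual_XLSum")

# Quick smoke test of the English path.
text = "Transformers provides thousands of pretrained models. " * 10
print(en_summarizer(text, max_length=40, min_length=10, do_sample=False)[0]["summary_text"])

Loading the English side through pipeline() but keeping the Arabic side as a raw tokenizer/model pair matches how the rest of the diff uses them: the pipeline API for straightforward English summarization, and manual generate() calls where Arabic needs model-specific decoding settings.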
@@ -31,7 +33,7 @@ question_generator = pipeline(
     device=-1 # CPU only
 )
 
-CHUNK_SIZE = 512 # Conservative
+CHUNK_SIZE = 512 # Conservative chunk size
 
 # =========================
 # Language Detection
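A note on the 512-token figure: distilbart-cnn-12-6 inherits BART's 1024-position encoder, while mT5 uses relative position buckets and has no hard positional limit, so 512 leaves headroom for special tokens and keeps per-chunk CPU latency bounded. A quick way to inspect the tokenizer-reported bounds (mT5 reports a very large sentinel value rather than a real limit):

# Sketch: inspect tokenizer-reported input limits for both checkpoints.
from transformers import AutoTokenizer

for name in ("sshleifer/distilbart-cnn-12-6", "csebuetnlp/mT5_multilingual_XLSum"):
    tok = AutoTokenizer.from_pretrained(name)
    print(f"{name}: model_max_length={tok.model_max_length}")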
@@ -63,8 +65,14 @@ def clean_text(text: str) -> str:
         result.append(s.strip())
     return " ".join(result)
 
-def chunk_text(text: str, tokenizer):
+def chunk_text(text: str, language: str):
     """Token-aware chunking to avoid model overflow."""
+    # Use appropriate tokenizer based on language
+    if language == "ar_AR":
+        tokenizer = ar_tokenizer
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(EN_SUMMARIZER_MODEL)
+
     tokens = tokenizer.encode(text, add_special_tokens=False)
     chunks = []
     for i in range(0, len(tokens), CHUNK_SIZE):
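One design note on the new chunk_text: the English branch re-instantiates its tokenizer from the cache on every call, which adds avoidable latency. A hypothetical refinement, not part of this commit, hoists it to module scope; the loop body shown is an assumption about how the unchanged remainder of the function decodes fixed-size token windows back to text:

# Hypothetical refinement (not in this commit): build the English tokenizer
# once instead of inside every chunk_text() call.
from transformers import AutoTokenizer

CHUNK_SIZE = 512
EN_SUMMARIZER_MODEL = "sshleifer/distilbart-cnn-12-6"
AR_SUMMARIZER_MODEL = "csebuetnlp/mT5_multilingual_XLSum"
en_tokenizer = AutoTokenizer.from_pretrained(EN_SUMMARIZER_MODEL)
ar_tokenizer = AutoTokenizer.from_pretrained(AR_SUMMARIZER_MODEL)

def chunk_text(text: str, language: str):
    """Token-aware chunking to avoid model overflow."""
    tokenizer = ar_tokenizer if language == "ar_AR" else en_tokenizer
    tokens = tokenizer.encode(text, add_special_tokens=False)
    # Assumed continuation: decode fixed-size token windows back to text.
    return [
        tokenizer.decode(tokens[i:i + CHUNK_SIZE], skip_special_tokens=True)
        for i in range(0, len(tokens), CHUNK_SIZE)
    ]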
@@ -188,28 +196,34 @@ def summarize_long_text(text: str, summary_length: str, language: str, progress=
     headings_section = extract_possible_headings(text)
 
     progress(0.1, desc="Chunking text...")
-    chunks = chunk_text(text, summarizer_tokenizer)
+    chunks = chunk_text(text, language)
 
     summaries = []
     progress(0.2, desc="Summarizing chunks...")
 
-    # Set language tokens for mBART
-    src_lang = language
-    tgt_lang = language
-
     for i in progress.tqdm(range(len(chunks))):
        chunk = chunks[i]
        try:
-
-
-
-
-
-
-
-
-
-
+            if language == "ar_AR":
+                # Use mT5 for Arabic
+                inputs = ar_tokenizer(chunk, return_tensors="pt", max_length=512, truncation=True)
+                summary_ids = ar_model.generate(
+                    inputs["input_ids"],
+                    max_length=length_params["max"],
+                    min_length=length_params["min"],
+                    length_penalty=2.0,
+                    num_beams=4,
+                    early_stopping=True
+                )
+                summary = ar_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+            else:
+                # Use distilbart for English
+                summary = en_summarizer(
+                    chunk,
+                    max_length=length_params["max"],
+                    min_length=length_params["min"],
+                    do_sample=False
+                )[0]["summary_text"]
 
             cleaned = clean_text(summary)
             chunk_label = f"**Chunk {i+1}:**" if language == "en_XX" else f"**الجزء {i+1}:**"
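For reference, the Arabic branch in isolation. length_params is computed elsewhere in app.py from summary_length, so the dict below is an assumed stand-in with the same keys, and the input string is a placeholder:

# Standalone sketch of the mT5/XLSum generate() call used for Arabic chunks.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("csebuetnlp/mT5_multilingual_XLSum")
model = AutoModelForSeq2SeqLM.from_pretrained("csebuetnlp/mT5_multilingual_XLSum")

length_params = {"min": 30, "max": 120}  # assumed shape: {"min": int, "max": int}
chunk = "ضع هنا فقرة عربية طويلة لتلخيصها."  # placeholder Arabic input ("put a long Arabic paragraph here to summarize")

inputs = tok(chunk, return_tensors="pt", max_length=512, truncation=True)
summary_ids = model.generate(
    inputs["input_ids"],
    max_length=length_params["max"],
    min_length=length_params["min"],
    length_penalty=2.0,   # >1 nudges beam search toward longer outputs
    num_beams=4,
    early_stopping=True,
)
print(tok.decode(summary_ids[0], skip_special_tokens=True))

Keeping the English path on the pipeline API while the Arabic path uses a manual generate() call lets both branches share length_params while applying model-specific decoding settings (beam search and length penalty) where the multilingual model needs them.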