# Hugging Face upload artifact (page residue, kept as comments so the file parses):
# Thuyba's picture
# Upload app.py
# 87313d8 verified
# -*- coding: utf-8 -*-
"""app.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1vrXPgckgnH7_gNK5PBBBfI2BIAk44rC0
"""
# file: gradio_swahili_translation_app.py
import re
import torch
import gradio as gr
from langdetect import detect, DetectorFactory
from transformers import (
MBartForConditionalGeneration,
MBart50TokenizerFast,
MT5Tokenizer,
AutoModelForSeq2SeqLM,
MarianMTModel,
MarianTokenizer
)
# langdetect is non-deterministic by default (it samples internally);
# fixing the seed makes repeated detections of the same text reproducible.
DetectorFactory.seed = 0
# Run every model on GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ------------------------
# Load Summarization Models
# ------------------------
# Two fine-tuned Swahili summarizers are loaded eagerly at import time.
# NOTE(review): these downloads hit the Hugging Face hub on first run.
mbart_model = MBartForConditionalGeneration.from_pretrained(
    "Thuyba/swahili-summarization-mbart"
).to(device)
mbart_tokenizer = MBart50TokenizerFast.from_pretrained(
    "Thuyba/swahili-summarization-mbart"
)
mt5_model = AutoModelForSeq2SeqLM.from_pretrained(
    "Thuyba/swahili-summarization-mt5"
).to(device)
mt5_tokenizer = MT5Tokenizer.from_pretrained(
    "Thuyba/swahili-summarization-mt5"
)
# ------------------------
# Load Translation Models (Helsinki-NLP Marian opus-mt models)
# ------------------------
# "swc" is the Congo Swahili code used by these OPUS models — presumably
# close enough to standard Swahili for this app; confirm if quality matters.
sw_en_model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-swc-en").to(device)
sw_en_tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-swc-en")
en_sw_model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-swc").to(device)
en_sw_tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-swc")
# ------------------------
# Utilities: splitting & chunking
# ------------------------
def split_into_sentences(text):
    """Break *text* into sentences at ., !, ? boundaries.

    Returns a list of stripped, non-empty sentence strings. Empty or
    whitespace-only input yields an empty list. The splitter is a simple
    lookbehind regex — good enough for news-style prose.
    """
    if not text or not text.strip():
        return []
    # Split after terminal punctuation that is followed by whitespace.
    pieces = (piece.strip() for piece in re.split(r'(?<=[.!?])\s+', text.strip()))
    return [piece for piece in pieces if piece]
def chunk_sentences(sentences, max_chars=400):
    """Greedily pack *sentences* into chunks of roughly <= *max_chars* chars.

    Sentences are joined with single spaces. A sentence longer than
    max_chars becomes its own chunk rather than being cut. Keeps chunk
    sizes reasonable for the small translation models.
    """
    chunks = []
    buffer = ""
    for sentence in sentences:
        candidate = sentence if not buffer else buffer + " " + sentence
        if buffer and len(candidate) > max_chars:
            # Current buffer is full — flush it and start over with this sentence.
            chunks.append(buffer.strip())
            buffer = sentence
        else:
            buffer = candidate
    if buffer:
        chunks.append(buffer.strip())
    return chunks
# ------------------------
# Translation helpers
# ------------------------
def detect_lang_safe(text, default="sw"):
    """Detect the language of *text* with langdetect.

    langdetect raises on very short / non-linguistic input, so any
    detection failure falls back to *default* instead of propagating.
    """
    try:
        return detect(text)
    except Exception:
        return default
def translate_chunk(chunk, src_is_swahili):
    """Translate a single chunk of text.

    When *src_is_swahili* is True the sw->en Marian model is used,
    otherwise the en->sw model. Returns the decoded translation string.
    """
    if src_is_swahili:
        tokenizer, model = sw_en_tokenizer, sw_en_model
    else:
        tokenizer, model = en_sw_tokenizer, en_sw_model
    # truncation=True guards against over-long inputs blowing up the model.
    encoded = tokenizer(chunk, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        generated = model.generate(**encoded, max_length=512, num_beams=4, early_stopping=True)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
def better_translate_text(text, assume_lang=None, max_chunk_chars=400):
    """Robust chunk-wise translator.

    Pipeline:
      - split *text* into sentences, group them into <= *max_chunk_chars* chunks
      - detect language per chunk (falling back to *assume_lang* or "sw")
      - apply a Swahili-keyword heuristic when detection says non-Swahili
      - translate chunk-by-chunk, keeping the original chunk on failure

    Returns the translated chunks joined with single spaces ("" for empty input).
    Note: paragraph breaks in the input are NOT preserved by the join.
    """
    if not text or not text.strip():
        return ""
    sentences = split_into_sentences(text)
    # Fallback: if sentence splitting produced nothing, translate the whole text.
    chunks = chunk_sentences(sentences, max_chars=max_chunk_chars) if sentences else [text.strip()]
    # Common Swahili words: any hit on a chunk overrides a non-"sw" detection,
    # which langdetect often gets wrong on short text. Built once, outside the
    # loop (the original rebuilt it per chunk and contained a duplicate "kwa").
    sw_common = frozenset({
        "na", "kwa", "ya", "katika", "raia", "alikuwa",
        "ame", "alifariki", "makamu", "rais",
    })
    translated_chunks = []
    for chunk in chunks:
        # Detect per chunk to stay robust against mixed-language input.
        lang = detect_lang_safe(chunk, default=assume_lang or "sw")
        src_is_swahili = lang == "sw"
        if not src_is_swahili and sw_common & set(chunk.lower().split()):
            src_is_swahili = True
        try:
            translated = translate_chunk(chunk, src_is_swahili)
        except Exception:
            # On any generation failure, keep the original chunk
            # (safer than emitting gibberish or dropping text).
            translated = chunk
        translated_chunks.append(translated)
    return " ".join(translated_chunks)
# ------------------------
# Summarization functions (unchanged)
# ------------------------
def summarize_text(input_text, model_choice):
    """Summarize *input_text* with the selected fine-tuned model.

    *model_choice* is "mBART-50" or "mT5" (the dropdown values). Empty
    input returns ""; an unknown choice returns "Model not supported.".
    """
    if not input_text or not input_text.strip():
        return ""
    if model_choice == "mBART-50":
        # Both source and target are Swahili (summarization, not translation).
        mbart_tokenizer.src_lang = "sw_KE"
        mbart_tokenizer.tgt_lang = "sw_KE"
        encoded = mbart_tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True).to(device)
        with torch.no_grad():
            generated = mbart_model.generate(**encoded, max_length=128, num_beams=4, early_stopping=True)
        return mbart_tokenizer.decode(generated[0], skip_special_tokens=True)
    if model_choice == "mT5":
        # mT5 was fine-tuned with the T5-style "summarize: " task prefix.
        prompt_ids = mt5_tokenizer("summarize: " + input_text, return_tensors="pt", max_length=512, truncation=True).input_ids.to(device)
        with torch.no_grad():
            generated = mt5_model.generate(input_ids=prompt_ids, max_length=128, num_beams=4, early_stopping=True)
        return mt5_tokenizer.decode(generated[0], skip_special_tokens=True)
    return "Model not supported."
# ------------------------
# Combined translate both input & output
# ------------------------
def translate_both(input_text, output_text):
    """Translate both textbox values independently.

    Each field gets its own language detection (used only as the fallback
    hint for per-chunk detection inside better_translate_text). Returns the
    (translated_input, translated_output) pair for Gradio to write back.
    """
    def _translate_field(field):
        # Detect once per field for speed; None hint when the field is blank.
        hint = detect_lang_safe(field, default=None) if field and field.strip() else None
        return better_translate_text(field or "", assume_lang=hint)

    return _translate_field(input_text), _translate_field(output_text)
# ------------------------
# Gradio UI
# ------------------------
# Minimal styling: orange action buttons, width-capped centered layout.
css = """
.orange-btn { background-color: orange !important; color: white !important; font-weight: 600; }
.gradio-container { max-width: 980px; margin: auto; }
"""
with gr.Blocks(css=css) as demo:
    gr.Markdown("## Kiswahili Text Summarization")
    with gr.Row():
        input_text = gr.Textbox(lines=8, placeholder="Andika Maandishi / write Paragraph here...", label="Maandishi ya Kuingiza")
        output_text = gr.Textbox(lines=8, label="Muhtasari / Tafsiri")
    model_choice = gr.Dropdown(choices=["mBART-50", "mT5"], label="Choose Model/ Chagua Modeli ya Summarization", value="mBART-50")
    with gr.Row():
        summarize_btn = gr.Button("Summarize", elem_classes="orange-btn")
        translate_btn = gr.Button("Translate / Rudisha Kiswahili", elem_classes="orange-btn")
        clear_btn = gr.Button("Clear", elem_classes="orange-btn")
    # Wiring: Summarize fills the output box; Translate rewrites BOTH boxes
    # in place (input and output each translated independently); Clear
    # empties both fields.
    summarize_btn.click(summarize_text, inputs=[input_text, model_choice], outputs=output_text)
    translate_btn.click(translate_both, inputs=[input_text, output_text], outputs=[input_text, output_text])
    clear_btn.click(lambda: ("", ""), outputs=[input_text, output_text])
demo.launch()