# -*- coding: utf-8 -*-
"""app.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1vrXPgckgnH7_gNK5PBBBfI2BIAk44rC0
"""
# file: gradio_swahili_translation_app.py
import re
import torch
import gradio as gr
from langdetect import detect, DetectorFactory
from transformers import (
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
    MT5Tokenizer,
    AutoModelForSeq2SeqLM,
    MarianMTModel,
    MarianTokenizer
)

# langdetect is stochastic by default; pinning the seed makes detection
# results reproducible across runs.
DetectorFactory.seed = 0

# Run all models on GPU when one is available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ------------------------
# Load Summarization Models
# ------------------------
# NOTE: all four loads below download checkpoints from the Hugging Face Hub
# at import time — first run requires network access and can be slow.
mbart_model = MBartForConditionalGeneration.from_pretrained(
    "Thuyba/swahili-summarization-mbart"
).to(device)
mbart_tokenizer = MBart50TokenizerFast.from_pretrained(
    "Thuyba/swahili-summarization-mbart"
)
mt5_model = AutoModelForSeq2SeqLM.from_pretrained(
    "Thuyba/swahili-summarization-mt5"
).to(device)
mt5_tokenizer = MT5Tokenizer.from_pretrained(
    "Thuyba/swahili-summarization-mt5"
)

# ------------------------
# Load Translation Models (Helsinki-NLP Marian, Congo Swahili "swc" pair)
# ------------------------
sw_en_model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-swc-en").to(device)
sw_en_tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-swc-en")
en_sw_model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-swc").to(device)
en_sw_tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-swc")
# ------------------------
# Utilities: splitting & chunking
# ------------------------
def split_into_sentences(text):
    """Break *text* into sentences at ., !, ? boundaries.

    The terminating punctuation stays attached to each sentence. Returns
    a list of stripped, non-empty sentence strings; an empty list for
    blank or falsy input.
    """
    if not text or not text.strip():
        return []
    # Lookbehind keeps the delimiter with the sentence it ends.
    pieces = re.split(r'(?<=[.!?])\s+', text.strip())
    return [piece.strip() for piece in pieces if piece.strip()]
def chunk_sentences(sentences, max_chars=400):
    """Greedily pack sentences into chunks of at most ~max_chars characters.

    Sentences inside a chunk are joined with a single space. A sentence
    longer than max_chars becomes its own (oversized) chunk — it is never
    split. Keeps chunk sizes reasonable for small translation models.
    """
    chunks = []
    buffer = ""
    for sentence in sentences:
        if not buffer:
            buffer = sentence
        elif len(buffer) + 1 + len(sentence) <= max_chars:
            buffer = f"{buffer} {sentence}"
        else:
            # Current chunk is full — flush and start a new one.
            chunks.append(buffer.strip())
            buffer = sentence
    if buffer:
        chunks.append(buffer.strip())
    return chunks
# ------------------------
# Translation helpers
# ------------------------
def detect_lang_safe(text, default="sw"):
    """Best-effort language detection for *text*.

    Wraps langdetect's detect(); any failure (too little text, missing
    features, etc.) falls back to *default* instead of raising.
    """
    try:
        return detect(text)
    except Exception:
        return default
def translate_chunk(chunk, src_is_swahili):
    """Translate one chunk of text with the appropriate Marian model.

    src_is_swahili selects the direction: True => Swahili -> English,
    False => English -> Swahili.
    """
    if src_is_swahili:
        model, tokenizer = sw_en_model, sw_en_tokenizer
    else:
        model, tokenizer = en_sw_model, en_sw_tokenizer
    # Truncate to the model's max input length to avoid tokenizer errors.
    encoded = tokenizer(chunk, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        generated = model.generate(**encoded, max_length=512, num_beams=4, early_stopping=True)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
# Common Swahili function words used as a heuristic override: langdetect is
# unreliable on short text, so seeing any of these forces sw -> en direction.
# Built once at module level instead of once per chunk (the original rebuilt
# the set — with a duplicated "kwa" — inside the loop).
_SW_COMMON_WORDS = frozenset({
    "na", "kwa", "ya", "katika", "raia", "alikuwa", "ame",
    "alifariki", "makamu", "rais",
})


def better_translate_text(text, assume_lang=None, max_chunk_chars=400):
    """Chunked, language-aware translation pipeline.

    Steps:
      - split *text* into sentences
      - group sentences into chunks of at most ~max_chunk_chars characters
      - detect the language of each chunk (falling back to *assume_lang*
        or "sw" when detection fails)
      - translate chunk-by-chunk and join the results with spaces

    Returns "" for blank input. If translating a chunk fails, the original
    chunk is kept verbatim — safer than emitting gibberish.
    """
    if not text or not text.strip():
        return ""

    sentences = split_into_sentences(text)
    if sentences:
        chunks = chunk_sentences(sentences, max_chars=max_chunk_chars)
    else:
        # Fallback: treat the whole text as a single chunk.
        chunks = [text.strip()]

    translated_chunks = []
    for chunk in chunks:
        # Detect per chunk to stay robust against mixed-language input.
        lang = detect_lang_safe(chunk, default=assume_lang or "sw")
        src_is_swahili = (lang == "sw")
        # Heuristic override for short/ambiguous chunks misdetected as
        # non-Swahili: the presence of common Swahili words wins.
        if not src_is_swahili and any(w in chunk.lower().split() for w in _SW_COMMON_WORDS):
            src_is_swahili = True
        try:
            translated_chunks.append(translate_chunk(chunk, src_is_swahili))
        except Exception:
            # On any generation failure keep the original chunk.
            translated_chunks.append(chunk)

    # NOTE(review): chunks are rejoined with single spaces, so original
    # paragraph breaks are NOT preserved (the original comment claimed
    # otherwise). Splitting on blank lines first would fix this.
    return " ".join(translated_chunks)
# ------------------------
# Summarization functions
# ------------------------
def summarize_text(input_text, model_choice):
    """Summarize *input_text* with the model named by *model_choice*.

    model_choice is "mBART-50" or "mT5" (the Dropdown choices). Returns ""
    for blank input and "Model not supported." for any other choice.
    """
    if not input_text or not input_text.strip():
        return ""
    if model_choice == "mBART-50":
        tokenizer = mbart_tokenizer
        model = mbart_model
        # Swahili -> Swahili summarization.
        tokenizer.src_lang = "sw_KE"
        tokenizer.tgt_lang = "sw_KE"
        # NOTE(review): mBART-50 generation normally needs
        # forced_bos_token_id=tokenizer.lang_code_to_id["sw_KE"] to pin the
        # output language — confirm the fine-tuned checkpoint behaves
        # correctly without it.
        inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True).to(device)
        with torch.no_grad():
            summary_ids = model.generate(**inputs, max_length=128, num_beams=4, early_stopping=True)
        return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    elif model_choice == "mT5":
        tokenizer = mt5_tokenizer
        model = mt5_model
        # T5-family models take a task prefix in the input text.
        input_ids = tokenizer("summarize: " + input_text, return_tensors="pt", max_length=512, truncation=True).input_ids.to(device)
        with torch.no_grad():
            output_ids = model.generate(input_ids=input_ids, max_length=128, num_beams=4, early_stopping=True)
        return tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return "Model not supported."
# ------------------------
# Combined translate both input & output
# ------------------------
def translate_both(input_text, output_text):
    """Translate the input and output textboxes independently.

    Each field's language is detected once up front and handed to the
    chunked pipeline as assume_lang, so per-chunk detection has a
    sensible fallback. Returns the pair (new_input, new_output).
    """
    def _translate_field(value):
        # Only attempt detection on non-blank text; otherwise no hint.
        if value and value.strip():
            hint = detect_lang_safe(value, default=None)
        else:
            hint = None
        return better_translate_text(value or "", assume_lang=hint)

    return _translate_field(input_text), _translate_field(output_text)
# ------------------------
# Gradio UI
# ------------------------
css = """
.orange-btn { background-color: orange !important; color: white !important; font-weight: 600; }
.gradio-container { max-width: 980px; margin: auto; }
"""
with gr.Blocks(css=css) as demo:
    gr.Markdown("## Kiswahili Text Summarization")
    with gr.Row():
        # Left: source paragraph; right: summary / translation result.
        input_text = gr.Textbox(lines=8, placeholder="Andika Maandishi / write Paragraph here...", label="Maandishi ya Kuingiza")
        output_text = gr.Textbox(lines=8, label="Muhtasari / Tafsiri")
    model_choice = gr.Dropdown(choices=["mBART-50", "mT5"], label="Choose Model/ Chagua Modeli ya Summarization", value="mBART-50")
    with gr.Row():
        summarize_btn = gr.Button("Summarize", elem_classes="orange-btn")
        translate_btn = gr.Button("Translate / Rudisha Kiswahili", elem_classes="orange-btn")
        clear_btn = gr.Button("Clear", elem_classes="orange-btn")
    # Wire the buttons: Summarize fills the output box; Translate rewrites
    # both boxes in place; Clear empties both.
    summarize_btn.click(summarize_text, inputs=[input_text, model_choice], outputs=output_text)
    translate_btn.click(translate_both, inputs=[input_text, output_text], outputs=[input_text, output_text])
    clear_btn.click(lambda: ("", ""), outputs=[input_text, output_text])

# NOTE(review): consider guarding this with `if __name__ == "__main__":` so
# the module can be imported (e.g. for testing) without starting the server.
demo.launch()