# Hugging Face Space: Khmer Text Summarization (Gradio app)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import (
    MBartForConditionalGeneration, MBart50Tokenizer,
    MT5ForConditionalGeneration, T5Tokenizer
)
import torch
from peft import PeftModel

# ==========================
# 1. Load model from Hugging Face
# ==========================
# Fine-tuned mT5 checkpoint for Khmer news summarization.
# NOTE(review): the MBart/MT5/Peft imports above are currently unused
# (leftovers from an earlier adapter-based setup) — confirm before removing.
MODEL_NAME = "angkor96/khmer-mT5-news-summarization"

print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

# Run on GPU when available; fall back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
print(f"β Model loaded successfully on {device}!")
# ==========================
# 2. Summarization function
# ==========================
def summarize_khmer_text(text, max_length=150, min_length=40):
    """Summarize a Khmer news article with the loaded seq2seq model.

    Args:
        text: Source Khmer text to summarize.
        max_length: Upper bound (in tokens) for the generated summary.
        min_length: Lower bound (in tokens) for the generated summary.

    Returns:
        The generated summary string, or a warning/error message string
        when the input is empty, too short, or generation fails.
    """
    # Guard clauses: empty/whitespace-only input, then too-short input.
    if not text or text.strip() == "":
        return "β οΈ ααΌααααα αΌαα’ααααα / Please enter text"
    if len(text.strip()) < 20:
        return "β οΈ α’αααααααααΈααα / Text is too short to summarize"

    # The UI sliders allow min_length (up to 100) to exceed max_length
    # (down to 50); clamp so generate() always gets a valid range.
    min_length = min(int(min_length), int(max_length))

    try:
        # No padding: a single-sequence batch needs none, and the old
        # padding="max_length" padded every input out to 1024 tokens.
        inputs = tokenizer(
            text,
            max_length=1024,
            truncation=True,
            return_tensors="pt"
        ).to(device)

        # Generate the summary with beam search.
        with torch.no_grad():
            summary_ids = model.generate(
                inputs["input_ids"],
                # Pass the mask explicitly instead of relying on generate()
                # inferring it from pad tokens.
                attention_mask=inputs["attention_mask"],
                max_length=max_length,
                min_length=min_length,
                length_penalty=2.0,
                num_beams=4,
                early_stopping=True,
                no_repeat_ngram_size=3
            )

        # Decode the best beam, dropping special tokens.
        return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    except Exception as e:
        # Surface the failure in the UI textbox rather than crashing the app.
        return f"β Error: {str(e)}"
# ==========================
# 3. Gradio UI
# ==========================
# Two-column layout: input + length sliders + button on the left,
# summary output on the right.  All literal UI strings below are
# mojibake from a bad encoding pass (Khmer text and emoji) —
# NOTE(review): re-enter the original UTF-8 characters; they are kept
# byte-identical here because they are runtime strings.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Header / title banner.
    gr.Markdown(
        """
        # π°π Khmer Text Summarization
        ### αααα αΌαα’αααααααααα α αΎαααα½αααΆαααΆααααααααααααααααααααααα·
        Enter Khmer text and get an automatic summary
        """
    )
    with gr.Row():
        with gr.Column():
            # Source-article input box.
            input_text = gr.Textbox(
                lines=10,
                placeholder="αααα αΌαα’αααααααααααα ααΈααα...\nEnter Khmer text here...",
                label="π α’αααααααΎα / Original Text"
            )
            with gr.Row():
                # Generation-length controls, fed to summarize_khmer_text.
                max_len = gr.Slider(
                    minimum=50,
                    maximum=300,
                    value=150,
                    step=10,
                    label="Maximum Summary Length"
                )
                min_len = gr.Slider(
                    minimum=20,
                    maximum=100,
                    value=40,
                    step=10,
                    label="Minimum Summary Length"
                )
            submit_btn = gr.Button("π Summarize / αααααα", variant="primary")
        with gr.Column():
            # Generated-summary output box.
            output_text = gr.Textbox(
                lines=10,
                label="π αααααα / Summary"
            )
    # Click-to-fill examples: [text, max_length, min_length] triples.
    gr.Examples(
        examples=[
            ["ααααααααααα»ααΆααΆααααααααα·ααΆαααααααΌαααααα·ααααααΌααααααααααααααα α’αΆααΆα ααααααααααΆαααΈαα ααααΎααααα»αααααααααΈα©αααααΈα‘α₯α α’ααααααααααΆααααΆααααααΆααααααααααα’ααα αΆααααα½ααααααα·αααααα", 100, 30],
            ["ααΆαα’ααααααΆααΌαααααΆαααααΉαααααΆαααααααΆααααΆαα’αα·ααααααααΆαα·α αα·ααααΆαα»αα·ααααααααΈαααααΌααααααΆαααα’α·ααα»αα ααααΌααααααααΆααα½ααΆααΈααααΆαααααα»αααΆααααααΎαα’ααΆαααα»ααΆαα", 80, 25],
        ],
        inputs=[input_text, max_len, min_len],
    )
    # Wire the button to the summarization function.
    submit_btn.click(
        fn=summarize_khmer_text,
        inputs=[input_text, max_len, min_len],
        outputs=output_text
    )
# ==========================
# 4. Launch
# ==========================
if __name__ == "__main__":
    # share=True requests a public gradio.live tunnel; inside a hosted
    # Space this is unnecessary — NOTE(review): confirm it is intended.
    demo.launch(share=True)