Spaces:
Sleeping
Sleeping
File size: 5,090 Bytes
a11f9ec 4b22381 a11f9ec 4b22381 a11f9ec 95cf44e a11f9ec 4b22381 95cf44e 53fadea 4b22381 95cf44e a11f9ec 4b22381 a11f9ec 4b22381 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import (
MBartForConditionalGeneration, MBart50Tokenizer,
MT5ForConditionalGeneration, T5Tokenizer
)
import torch
from peft import PeftModel
# ==========================
# 1. Load model from Hugging Face
# ==========================
# Hub repository id of the fine-tuned Khmer news summarization checkpoint.
MODEL_NAME = "angkor96/khmer-mT5-news-summarization" # e.g., "Sedtha-019/khmer-summarization"
print("Loading model and tokenizer...")
# Download (or load from local cache) the tokenizer and seq2seq weights.
# AutoTokenizer/AutoModelForSeq2SeqLM resolve the concrete classes from the
# repo's config (an mT5 variant, per the repo name — confirm on the Hub).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
# Alternative loading path (MBart base + PEFT adapter) — intentionally disabled.
# base = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")
# model = PeftModel.from_pretrained(base, MODEL_NAME)
# Move to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
# NOTE(review): the leading "β" looks like a mangled emoji (likely a check mark) — confirm against the original source.
print(f"βModel loaded successfully on {device}!")
# ==========================
# 2. Summarization function
# ==========================
def summarize_khmer_text(text, max_length=150, min_length=40):
    """
    Summarize Khmer text with the module-level seq2seq model.

    Args:
        text: Khmer source text to summarize.
        max_length: Maximum token length of the generated summary.
        min_length: Minimum token length of the generated summary.

    Returns:
        The decoded summary string, or a bilingual warning/error message
        when the input is missing, too short, or generation fails.
    """
    # Guard clauses: empty / whitespace-only input, then too-short input.
    if not text or not text.strip():
        return "β οΈ ααΌααααααΌαα’ααααα / Please enter text"
    if len(text.strip()) < 20:
        return "β οΈ α’αααααααααΈααα / Text is too short to summarize"
    try:
        # Tokenize the article, truncating to the model's input limit.
        # No fixed-length padding: a single input needs none, and padding
        # to 1024 tokens only wastes compute.
        inputs = tokenizer(
            text,
            max_length=1024,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        # Beam-search generation; attention_mask is passed explicitly so
        # any padding tokens are never attended to (the original omitted
        # it while padding to max_length, letting the model attend pads).
        with torch.no_grad():
            summary_ids = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=max_length,
                min_length=min_length,
                length_penalty=2.0,
                num_beams=4,
                early_stopping=True,
                no_repeat_ngram_size=3,
            )
        # Decode the best beam, dropping special tokens (<pad>, </s>, ...).
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return summary
    except Exception as e:
        # Best-effort UI handler: surface the error as the output text
        # instead of crashing the Gradio callback.
        return f"β Error: {str(e)}"
# ==========================
# 3. Gradio UI
# ==========================
# ==========================
# 3. Gradio UI
# ==========================
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Bilingual (Khmer / English) page header.
    gr.Markdown(
        """
        # π°π Khmer Text Summarization
        ### ααααΌαα’αααααααααα α αΎαααα½αααΆαααΆααααααααααααααααααααααα·
        Enter Khmer text and get an automatic summary
        """
    )
    with gr.Row():
        with gr.Column():
            # Left column: source text plus generation-length controls.
            input_text = gr.Textbox(
                lines=10,
                placeholder="ααααΌαα’ααααααααααααααΈααα...\nEnter Khmer text here...",
                label="π α’αααααααΎα / Original Text"
            )
            with gr.Row():
                # Sliders map directly onto summarize_khmer_text's
                # max_length / min_length parameters.
                max_len = gr.Slider(
                    minimum=50,
                    maximum=300,
                    value=150,
                    step=10,
                    label="Maximum Summary Length"
                )
                min_len = gr.Slider(
                    minimum=20,
                    maximum=100,
                    value=40,
                    step=10,
                    label="Minimum Summary Length"
                )
            submit_btn = gr.Button("π Summarize / αααααα", variant="primary")
        with gr.Column():
            # Right column: read-only summary output.
            output_text = gr.Textbox(
                lines=10,
                label="π αααααα / Summary"
            )
    # Examples
    # Pre-filled demo inputs: [text, max_length, min_length] per row.
    gr.Examples(
        examples=[
            ["ααααααααααα»ααΆααΆααααααααα·ααΆαααααααΌαααααα·ααααααΌααααααααααααααα α’αΆααΆαααααααααααΆαααΈααααααΎααααα»αααααααααΈα©αααααΈα‘α₯α α’ααααααααααΆααααΆααααααΆααααααααααα’αααΆααααα½ααααααα·αααααα", 100, 30],
            ["ααΆαα’ααααααΆααΌαααααΆαααααΉαααααΆαααααααΆααααΆαα’αα·ααααααααΆαα·α αα·ααααΆαα»αα·ααααααααΈαααααΌααααααΆαααα’α·ααα»αα ααααΌααααααααΆααα½ααΆααΈααααΆαααααα»αααΆααααααΎαα’ααΆαααα»ααΆαα", 80, 25],
        ],
        inputs=[input_text, max_len, min_len],
    )
    # Connect button
    # Clicking runs the summarizer with the textbox + slider values.
    submit_btn.click(
        fn=summarize_khmer_text,
        inputs=[input_text, max_len, min_len],
        outputs=output_text
    )
# ==========================
# 4. Launch
# ==========================
if __name__ == "__main__":
    # share=True additionally exposes a temporary public gradio.live URL.
    demo.launch(share=True)
|