# app.py — Hugging Face repository file (viewer path: "Text / app.py").
# Last commit by angkor96: "Update app.py" (05384b4, verified); file size 4.79 kB.
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# ==========================
# 1. Load model from Hugging Face
# ==========================
# Hugging Face Hub repository that provides both the fine-tuned seq2seq
# weights and the matching tokenizer. Another checkpoint id can be swapped
# in here, e.g. "Sedtha-019/khmer-summarization".
MODEL_NAME = "angkor96/khmer-news-summarization"

print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the weights and move them onto the chosen device in one step.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
print(f"βœ… Model loaded successfully on {device}!")
# ==========================
# 2. Summarization function
# ==========================
def summarize_khmer_text(text, max_length=150, min_length=None):
    """Summarize Khmer text with the module-level seq2seq model.

    The Gradio button passes three values positionally (text, max-length
    slider, min-length slider), so this function accepts all three.

    Args:
        text: Source text to summarize. ``None``, empty, or whitespace-only
            input returns a bilingual warning instead of calling the model.
            Input longer than 1024 tokens is truncated.
        max_length: Upper bound, in tokens, on the generated summary. It may
            arrive as a float from a slider, so it is coerced to ``int``.
        min_length: Optional lower bound, in tokens, on the generated
            summary. It is coerced to ``int`` and clamped so it never exceeds
            ``max_length``. ``None`` leaves the model's own default in place.

    Returns:
        The decoded summary, or a user-facing warning/error string. Errors
        are returned rather than raised so they appear in the output textbox.
    """
    if not text or not text.strip():
        return "⚠️ αžŸαžΌαž˜αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘ / Please enter text"

    try:
        max_length = int(max_length)
        generation_kwargs = {
            "max_length": max_length,
            "length_penalty": 2.0,
            "num_beams": 5,
            "early_stopping": True,
        }
        if min_length is not None:
            # The UI sliders allow min (up to 100) to exceed max (down to 50);
            # clamp so the generation constraints stay consistent.
            generation_kwargs["min_length"] = min(int(min_length), max_length)

        # A single sequence needs no padding. Padding to 1024 only made the
        # encoder process pad tokens, so pass the real attention mask instead
        # of leaving generate() to infer it.
        inputs = tokenizer(
            text,
            max_length=1024,
            truncation=True,
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            summary_ids = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
                **generation_kwargs,
            )

        return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    except Exception as e:
        # UI boundary: surface any failure as text in the output box.
        return f"❌ Error: {str(e)}"
# ==========================
# 3. Gradio UI
# ==========================
# Build the Gradio interface:
# - left column: input text, two length sliders and the submit button;
# - right column: the summary output;
# - below: clickable examples, then the button wiring.
# Components are declared inside nested Row/Column contexts, so statement
# order here determines the layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Bilingual (Khmer / English) page header.
    gr.Markdown(
"""
# πŸ‡°πŸ‡­ Khmer Text Summarization
### αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžš αž αžΎαž™αž‘αž‘αž½αž›αž”αžΆαž“αž€αžΆαžšαžŸαž„αŸ’αžαŸαž”αžŠαŸ„αž™αžŸαŸ’αžœαŸαž™αž”αŸ’αžšαžœαžαŸ’αžαž·
Enter Khmer text and get an automatic summary
"""
    )
    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column():
            input_text = gr.Textbox(
                lines=10,
                placeholder="αž”αž‰αŸ’αž…αžΌαž›αž’αžαŸ’αžαž”αž‘αžαŸ’αž˜αŸ‚αžšαž“αŸ…αž‘αžΈαž“αŸαŸ‡...\nEnter Khmer text here...",
                label="πŸ“ αž’αžαŸ’αžαž”αž‘αžŠαžΎαž˜ / Original Text"
            )
            with gr.Row():
                # Length bounds forwarded to the summarizer as its 2nd and
                # 3rd positional arguments (see the click wiring below).
                # NOTE(review): min can be set above max (100 vs 50); the
                # summarizer must tolerate min > max.
                max_len = gr.Slider(
                    minimum=50,
                    maximum=300,
                    value=150,
                    step=10,
                    label="Maximum Summary Length"
                )
                min_len = gr.Slider(
                    minimum=20,
                    maximum=100,
                    value=40,
                    step=10,
                    label="Minimum Summary Length"
                )
            submit_btn = gr.Button("πŸ”„ Summarize / αžŸαž„αŸ’αžαŸαž”", variant="primary")
        # Right column: generated summary.
        with gr.Column():
            output_text = gr.Textbox(
                lines=10,
                label="πŸ“‹ αžŸαž„αŸ’αžαŸαž” / Summary"
            )
    # Clickable examples: each row fills [input_text, max_len, min_len] in
    # that order (text, max length, min length).
    gr.Examples(
        examples=[
            ["αž”αŸ’αžšαž‘αŸαžŸαž€αž˜αŸ’αž–αž»αž‡αžΆαž˜αžΆαž“αž”αŸ’αžšαžœαžαŸ’αžαž·αžŸαžΆαžŸαŸ’αžšαŸ’αžαž™αžΌαžšαž›αž„αŸ‹αž“αž·αž„αžŸαž˜αŸ’αž”αžΌαžšαž”αŸ‚αž”αžŠαŸ„αž™αžœαž”αŸ’αž”αž’αž˜αŸŒαŸ” αž’αžΆαžŽαžΆαž…αž€αŸ’αžšαžαŸ’αž˜αŸ‚αžšαž”αžΆαž“αžšαžΈαž€αž…αž˜αŸ’αžšαžΎαž“αž€αŸ’αž“αž»αž„αžŸαžαžœαžαŸ’αžŸαž‘αžΈαŸ©αžŠαž›αŸ‹αž‘αžΈαŸ‘αŸ₯αŸ” αž’αž„αŸ’αž‚αžšαžœαžαŸ’αžαž‡αžΆαžŸαŸ’αž“αžΆαžŠαŸƒαžŸαŸ’αžαžΆαž”αžαŸ’αž™αž€αž˜αŸ’αž˜αžŠαŸαž’αžŸαŸ’αž…αžΆαžšαŸ’αž™αž˜αž½αž™αžšαž”αžŸαŸ‹αž–αž·αž—αž–αž›αŸ„αž€αŸ”", 100, 30],
            ["αž€αžΆαžšαž’αž”αŸ‹αžšαŸ†αž‡αžΆαž˜αžΌαž›αžŠαŸ’αž‹αžΆαž“αž‚αŸ’αžšαžΉαŸ‡αžŸαŸ†αžαžΆαž“αŸ‹αžŸαž˜αŸ’αžšαžΆαž”αŸ‹αž€αžΆαžšαž’αž—αž·αžœαžŒαŸ’αžαž“αŸαž‡αžΆαžαž·αŸ” αžŸαž·αžŸαŸ’αžŸαžΆαž“αž»αžŸαž·αžŸαŸ’αžŸαž‚αž”αŸ’αž”αžΈαžšαŸ€αž“αžŸαžΌαžαŸ’αžšαž™αŸ‰αžΆαž„αžŸαŸ’αž’αž·αžαžšαž»αŸ†αŸ” αž‚αŸ’αžšαžΌαž”αž„αŸ’αžšαŸ€αž“αž˜αžΆαž“αžαž½αž“αžΆαž‘αžΈαžŸαŸ†αžαžΆαž“αŸ‹αž€αŸ’αž“αž»αž„αž€αžΆαžšαž”αž„αŸ’αž€αžΎαžαž’αž“αžΆαž‚αžαž€αž»αž˜αžΆαžšαŸ”", 80, 25],
        ],
        inputs=[input_text, max_len, min_len],
    )
    # Wire the button: the three input values are passed positionally to
    # summarize_khmer_text, and its return value fills the output textbox.
    # NOTE(review): summarize_khmer_text must accept three positional
    # arguments to match these inputs.
    submit_btn.click(
        fn=summarize_khmer_text,
        inputs=[input_text, max_len, min_len],
        outputs=output_text
    )
# ==========================
# 4. Launch
# ==========================
if __name__ == "__main__":
    # Start the server only when run as a script, not when imported.
    # share=True is forwarded unchanged to Gradio's launch().
    launch_kwargs = dict(share=True)
    demo.launch(**launch_kwargs)