"""Gradio demo for the ProtonX legal text-generation model.

Loads a seq2seq checkpoint once at import time, then serves a simple
text-in / text-out web UI.
"""
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# ===== Model setup =====
model_path = "protonx-models/protonx-legal-tc"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # inference only: disable dropout / training-mode layers

# Used both as the input truncation limit and as the generation budget.
# NOTE(review): 128k *new* tokens with num_beams=10 is an enormous decode
# budget — confirm this constant was meant for max_new_tokens too, and
# that the model's positional limit actually supports 128k input tokens.
MAX_TOKENS = 128000


# ===== Inference function =====
def generate(text):
    """Run beam-search generation on `text` and return the decoded string.

    Args:
        text: Raw user input; may be None or whitespace-only.

    Returns:
        The decoded model output, or "" for empty input (the model is
        not invoked in that case).
    """
    if not text or not text.strip():
        return ""

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_TOKENS,
    ).to(device)

    with torch.no_grad():  # no autograd bookkeeping during generation
        outputs = model.generate(
            **inputs,
            num_beams=10,
            num_return_sequences=1,
            max_new_tokens=MAX_TOKENS,
            early_stopping=True,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# ===== Gradio UI =====
with gr.Blocks(title="ProtonX Legal Text Generation") as demo:
    gr.Markdown("## 🧾 ProtonX Legal Text Generation")

    input_text = gr.Textbox(
        label="Input text",
        placeholder="Nhập nội dung pháp lý...",
        lines=6,
    )
    submit_btn = gr.Button("Submit")
    output_text = gr.Textbox(
        label="Output",
        lines=6,
    )

    submit_btn.click(
        fn=generate,
        inputs=input_text,
        outputs=output_text,
    )

demo.launch()