# air83 / app.py — Hugging Face Space (uploaded by NguyenTan, commit 1e01b49, verified)
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# ===== Model setup =====
# Seq2seq checkpoint pulled from the Hugging Face Hub.
model_path = "protonx-models/protonx-legal-tc"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
# Prefer GPU when available; fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # inference only — disables dropout etc.
# Upper bound used both for input truncation and for generated tokens.
# NOTE(review): 128000 looks far above typical seq2seq context windows —
# confirm the checkpoint actually supports sequences this long.
MAX_TOKENS = 128000
# ===== Inference function =====
def generate(text):
    """Run beam-search generation on *text* and return the decoded top beam.

    Args:
        text: Raw input string (legal text). ``None``, empty, or
            whitespace-only input short-circuits without calling the model.

    Returns:
        The best generated sequence as a plain string (special tokens
        stripped), or ``""`` for blank input.
    """
    # Guard clause: never invoke the model on empty input.
    if not text or not text.strip():
        return ""
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,        # clip over-long inputs to MAX_TOKENS
        max_length=MAX_TOKENS,
    ).to(device)
    # inference_mode is a stricter, slightly faster variant of no_grad
    # for code paths that never need autograd.
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            num_beams=10,               # beam search over 10 hypotheses
            num_return_sequences=1,     # keep only the best beam
            max_new_tokens=MAX_TOKENS,
            early_stopping=True,        # stop beams at EOS
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# ===== Gradio UI =====
# Minimal one-input / one-output demo wired to generate().
with gr.Blocks(title="ProtonX Legal Text Generation") as demo:
    gr.Markdown("## 🧾 ProtonX Legal Text Generation")

    # Input widget for the legal text to process.
    input_text = gr.Textbox(
        label="Input text",
        placeholder="Nhập nội dung pháp lý...",
        lines=6,
    )
    submit_btn = gr.Button("Submit")
    # Read-only box showing the model's generated output.
    output_text = gr.Textbox(label="Output", lines=6)

    # Clicking the button feeds the input box through generate().
    submit_btn.click(fn=generate, inputs=input_text, outputs=output_text)

demo.launch()