|
|
import torch |
|
|
import gradio as gr |
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
|
|
|
|
|
# Hugging Face Hub identifier of the fine-tuned seq2seq model used below.
model_path = "protonx-models/protonx-legal-tc"

# Load the tokenizer and the encoder-decoder model from the Hub
# (downloads on first run, then served from the local HF cache).
tokenizer = AutoTokenizer.from_pretrained(model_path)

model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

# Prefer GPU when available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)

# Inference-only: disable dropout / switch norm layers to eval behavior.
model.eval()

# Upper bound used both for input truncation and for generated length.
# NOTE(review): 128000 far exceeds the context window of typical seq2seq
# models (often 512-1024 tokens) — the tokenizer/model will effectively
# cap it; confirm the intended limit for this checkpoint.
MAX_TOKENS = 128000
|
|
|
|
|
|
|
|
def generate(text):
    """Run beam-search generation on *text* and return the decoded string.

    Empty or whitespace-only input short-circuits to an empty string.
    The input is tokenized with truncation at MAX_TOKENS and moved to the
    module-level `device` before generation.
    """
    # Guard clause: nothing to generate for blank input.
    if not (text and text.strip()):
        return ""

    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_TOKENS,
    ).to(device)

    # No gradients needed at inference time.
    with torch.no_grad():
        generated_ids = model.generate(
            **encoded,
            num_beams=10,
            num_return_sequences=1,
            max_new_tokens=MAX_TOKENS,
            early_stopping=True,
        )

    # Single returned sequence; drop special tokens when decoding.
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
|
|
|
|
|
|
|
|
|
|
# Build the Gradio UI: one input textbox, a submit button, one output textbox.
with gr.Blocks(title="ProtonX Legal Text Generation") as demo:
    gr.Markdown("## 🧾 ProtonX Legal Text Generation")

    legal_input = gr.Textbox(
        label="Input text",
        placeholder="Nhập nội dung pháp lý...",
        lines=6,
    )
    run_button = gr.Button("Submit")
    result_box = gr.Textbox(label="Output", lines=6)

    # Wire the button: pass the input text through `generate` into the output box.
    run_button.click(fn=generate, inputs=legal_input, outputs=result_box)

# Start the Gradio server (blocking call).
demo.launch()
|
|
|