|
|
import streamlit as st |
|
|
import torch |
|
|
from transformers import BartTokenizer, BartForConditionalGeneration |
|
|
from peft import PeftModel |
|
|
import textstat |
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load the LoRA-adapted BART model and its tokenizer.

    Cached via ``st.cache_resource`` so the heavyweight load happens once
    per server process, not on every Streamlit rerun.

    Returns:
        tuple: ``(tokenizer, model)`` — the BART tokenizer and the
        PEFT-wrapped model, moved to CPU and switched to eval mode.
    """
    # Base checkpoint: plain float32 on CPU, no automatic device placement.
    backbone = BartForConditionalGeneration.from_pretrained(
        "facebook/bart-large-cnn",
        torch_dtype=torch.float32,
        device_map=None,
    )

    # Wrap the base model with the fine-tuned LoRA adapter stored locally.
    peft_model = PeftModel.from_pretrained(backbone, "./checkpoint")
    tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")

    peft_model.to("cpu")
    peft_model.eval()
    return tok, peft_model
|
|
|
|
|
def simplify(text, tokenizer, model):
    """Simplify *text* with the fine-tuned seq2seq model.

    Args:
        text: Raw legal text to simplify.
        tokenizer: Tokenizer matching *model* (callable, with ``decode``).
        model: Seq2seq model exposing ``generate``.

    Returns:
        str: The beam-searched simplification, special tokens stripped.
    """
    # Task prefix the adapter was fine-tuned with; input truncated to the
    # encoder's 512-token budget.
    encoded = tokenizer(
        f"simplify: {text}",
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )

    # No gradient tracking needed for inference.
    with torch.inference_mode():
        generated = model.generate(
            **encoded,
            max_new_tokens=256,
            num_beams=4,
            early_stopping=True,
        )

    return tokenizer.decode(generated[0], skip_special_tokens=True)
|
|
|
|
|
# --- Page setup ------------------------------------------------------------
st.set_page_config(page_title="Legaleaze", layout="wide")
st.title("Legaleaze: Legal Text Simplifier")
st.caption("BART-Large + LoRA | 121k steps on asylum cases")

# Top-level boundary: any failure (missing ./checkpoint, download error,
# OOM, ...) is surfaced in the UI instead of a blank page.
try:
    tokenizer, model = load_model()

    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Complex Legal Text")
        # Non-empty label (hidden) avoids Streamlit's empty-label
        # accessibility warning; the explicit key keeps widget state stable.
        text = st.text_area(
            "Complex legal text",
            height=300,
            placeholder="Paste legal text here...",
            key="input",
            label_visibility="collapsed",
        )
        btn = st.button("Simplify", type="primary", use_container_width=True)

    with col2:
        st.subheader("Simplified Output")
        if btn and text.strip():
            with st.spinner("Simplifying (30s on CPU)..."):
                result = simplify(text, tokenizer, model)
            # Persist across reruns so the output survives later clicks.
            st.session_state['result'] = result
            st.session_state['original_text'] = text

        if 'result' in st.session_state:
            simplified = st.text_area(
                "Simplified output",
                value=st.session_state['result'],
                height=300,
                key="output",
                label_visibility="collapsed",
            )

            if st.button("📋 Copy to Clipboard", use_container_width=True):
                st.write("Copy the text above manually (browser limitation)")

            st.divider()
            m1, m2, m3 = st.columns(3)
            # Flesch-Kincaid grade level before and after simplification.
            orig = textstat.flesch_kincaid_grade(st.session_state['original_text'])
            simp = textstat.flesch_kincaid_grade(simplified)
            m1.metric("Original FKGL", f"{orig:.1f}")
            m2.metric("Simplified FKGL", f"{simp:.1f}")
            # BUG FIX: FKGL can be 0.0 (or negative) for very simple text;
            # dividing by it crashed the metrics row with ZeroDivisionError.
            if orig:
                m3.metric("Improvement", f"{((orig - simp) / orig * 100):.0f}%")
            else:
                m3.metric("Improvement", "n/a")
        else:
            st.info("Simplified text appears here")

except Exception as e:
    st.error(f"Error: {e}")