"""Legaleaze: Streamlit front end for a LoRA-fine-tuned BART legal-text simplifier.

Run with ``streamlit run <this file>``. The base model is
``facebook/bart-large-cnn``; the LoRA adapter is loaded from ``./checkpoint``.
"""

import streamlit as st
import textstat
import torch
from peft import PeftModel
from transformers import BartForConditionalGeneration, BartTokenizer

BASE_MODEL = "facebook/bart-large-cnn"
ADAPTER_PATH = "./checkpoint"


@st.cache_resource
def load_model():
    """Load the tokenizer and the LoRA-adapted BART model.

    ``st.cache_resource`` keeps the returned objects alive across reruns and
    sessions, so the slow model initialisation happens only once. A failed
    load is not cached and will be retried on the next rerun.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is on CPU in eval mode.
    """
    base = BartForConditionalGeneration.from_pretrained(
        BASE_MODEL, torch_dtype=torch.float32, device_map=None
    )
    model = PeftModel.from_pretrained(base, ADAPTER_PATH)
    tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)
    model.to("cpu")
    model.eval()
    return tokenizer, model


def simplify(text, tokenizer, model, *, max_input_length=512,
             max_new_tokens=256, num_beams=4):
    """Return a plain-language rewrite of *text* generated by *model*.

    The text is prefixed with ``"simplify: "`` (NOTE(review): presumably the
    prompt format the adapter was trained with — confirm) and tokenized.
    Anything beyond *max_input_length* tokens is silently truncated.
    Decoding uses beam search with early stopping.

    Args:
        text: Source legal text.
        tokenizer: Tokenizer matching *model*.
        model: Seq2seq model exposing ``generate``.
        max_input_length: Token limit applied to the prompt (truncation).
        max_new_tokens: Upper bound on generated tokens.
        num_beams: Beam width for decoding.

    Returns:
        The decoded output with special tokens removed.
    """
    prompt = f"simplify: {text}"
    inputs = tokenizer(
        prompt, return_tensors="pt", max_length=max_input_length, truncation=True
    )
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            early_stopping=True,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def _format_improvement(orig_grade, simp_grade):
    """Format the relative FKGL reduction from *orig_grade* to *simp_grade*.

    FKGL can be zero or negative for very short or simple text. A relative
    change is then meaningless, and dividing by zero would crash, so "n/a"
    is returned instead.
    """
    if orig_grade <= 0:
        return "n/a"
    return f"{(orig_grade - simp_grade) / orig_grade * 100:.0f}%"


st.set_page_config(page_title="Legaleaze", layout="wide")
st.title("Legaleaze: Legal Text Simplifier")
st.caption("BART-Large + LoRA | 121k steps on asylum cases")

# Nothing below can work without the model, so report the failure and halt
# this run instead of rendering a half-working UI.
try:
    tokenizer, model = load_model()
except Exception as e:  # top-level UI boundary: surface any load failure
    st.error(f"Error: {e}")
    st.stop()

col1, col2 = st.columns(2)

with col1:
    st.subheader("Complex Legal Text")
    text = st.text_area(
        "Complex legal text",
        height=300,
        placeholder="Paste legal text here...",
        key="input",
        label_visibility="collapsed",
    )
    btn = st.button("Simplify", type="primary", use_container_width=True)

with col2:
    st.subheader("Simplified Output")

    if btn:
        if not text.strip():
            st.warning("Please paste some legal text to simplify.")
        else:
            try:
                with st.spinner("Simplifying (30s on CPU)..."):
                    result = simplify(text, tokenizer, model)
            except Exception as e:  # inference boundary: keep the UI alive
                st.error(f"Error: {e}")
            else:
                st.session_state["result"] = result
                st.session_state["original_text"] = text
                # Write the new result straight into the editable widget's
                # state. Depending on the Streamlit version, changing value=
                # on a keyed widget may not override what it already holds
                # (e.g. earlier user edits).
                st.session_state["output"] = result

    if "result" in st.session_state:
        if "output" not in st.session_state:
            st.session_state["output"] = st.session_state["result"]
        # Editable output; its content comes from st.session_state["output"].
        simplified = st.text_area(
            "Simplified text",
            height=300,
            key="output",
            label_visibility="collapsed",
        )

        if st.button("📋 Copy to Clipboard", use_container_width=True):
            # Server-side Python cannot write to the browser clipboard.
            # st.code renders the text with a built-in copy icon instead.
            st.caption("Hover over the box below and click its copy icon.")
            st.code(simplified, language=None)

        st.divider()
        m1, m2, m3 = st.columns(3)
        orig = textstat.flesch_kincaid_grade(st.session_state["original_text"])
        simp = textstat.flesch_kincaid_grade(simplified)
        m1.metric("Original FKGL", f"{orig:.1f}")
        m2.metric("Simplified FKGL", f"{simp:.1f}")
        m3.metric("Improvement", _format_improvement(orig, simp))
    else:
        st.info("Simplified text appears here")