ckharche commited on
Commit
19e70cb
·
verified ·
1 Parent(s): abb6a0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -48
app.py CHANGED
@@ -6,66 +6,41 @@ import textstat
6
 
7
  @st.cache_resource
8
  def load_model():
 
9
  base = BartForConditionalGeneration.from_pretrained(
10
  "facebook/bart-large-cnn",
11
- torch_dtype=torch.float16,
12
- device_map="auto"
13
  )
 
14
  model = PeftModel.from_pretrained(base, "ckharche/legaleaze-bart-121k")
15
  tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
 
 
16
  model.eval()
17
  return tokenizer, model
18
 
19
  def simplify(text, tokenizer, model):
20
  prompt = f"simplify: {text}"
21
- inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(model.device)
 
22
  with torch.inference_mode():
23
- outputs = model.generate(**inputs, max_new_tokens=256, num_beams=4, early_stopping=True, length_penalty=0.75)
 
24
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
25
 
26
- # UI
27
- st.set_page_config(page_title="Legaleaze", layout="wide")
28
- st.title("Legaleaze: Legal Text Simplifier")
29
- st.caption("BART-Large + LoRA | Trained on 121k steps (53k asylum cases)")
30
 
31
- try:
32
- tokenizer, model = load_model()
33
-
34
- col1, col2 = st.columns(2)
35
-
36
- with col1:
37
- st.subheader("Complex Legal Text")
38
- legal_text = st.text_area("", height=300, placeholder="Paste legal text here...", key="input")
39
- simplify_btn = st.button("Simplify", type="primary", use_container_width=True)
40
-
41
- with col2:
42
- st.subheader("Simplified Output")
43
- if simplify_btn and legal_text.strip():
44
- with st.spinner("Simplifying..."):
45
- simplified = simplify(legal_text, tokenizer, model)
46
- st.text_area("", value=simplified, height=300, disabled=True, key="output")
47
-
48
- # Metrics
49
- st.divider()
50
- m1, m2, m3 = st.columns(3)
51
- orig_fkgl = textstat.flesch_kincaid_grade(legal_text)
52
- simp_fkgl = textstat.flesch_kincaid_grade(simplified)
53
- improvement = ((orig_fkgl - simp_fkgl) / orig_fkgl) * 100
54
-
55
- m1.metric("Original Grade Level", f"{orig_fkgl:.1f}")
56
- m2.metric("Simplified Grade Level", f"{simp_fkgl:.1f}")
57
- m3.metric("Readability ↑", f"{improvement:.0f}%", delta=f"-{orig_fkgl - simp_fkgl:.1f} grades")
58
- else:
59
- st.info("Output appears here")
60
-
61
- with st.expander("ℹ️ Model Details"):
62
- st.markdown("""
63
- - **Architecture**: BART-Large-CNN (406M params) + LoRA (16M trainable)
64
- - **Training**: 121k steps on H100/H200 GPUs (Northeastern HPC)
65
- - **Dataset**: 53k Canadian asylum case documents
66
- - **Performance**: FKGL ↓35% | BERTScore 0.89 | ROUGE-L 0.48
67
- """)
68
 
69
- except Exception as e:
70
- st.error(f"Model loading failed: {e}")
71
- st.info("This Space requires GPU runtime. Contact admin if issue persists.")
 
6
 
7
  @st.cache_resource
8
  def load_model():
9
+ # Load to CPU explicitly
10
  base = BartForConditionalGeneration.from_pretrained(
11
  "facebook/bart-large-cnn",
12
+ torch_dtype=torch.float32,
13
+ device_map=None # Don't use auto
14
  )
15
+
16
  model = PeftModel.from_pretrained(base, "ckharche/legaleaze-bart-121k")
17
  tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
18
+
19
+ model.to("cpu")
20
  model.eval()
21
  return tokenizer, model
22
 
23
  def simplify(text, tokenizer, model):
24
  prompt = f"simplify: {text}"
25
+ inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
26
+
27
  with torch.inference_mode():
28
+ outputs = model.generate(**inputs, max_new_tokens=256, num_beams=4, early_stopping=True)
29
+
30
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
31
 
32
+ # Rest of your UI code...
33
+ st.title("âš–ï¸ Legaleaze")
34
+ tokenizer, model = load_model()
 
35
 
36
+ col1, col2 = st.columns(2)
37
+ with col1:
38
+ text = st.text_area("Complex Legal Text", height=300)
39
+ if st.button("Simplify"):
40
+ with st.spinner("Processing (20-30s on CPU)..."):
41
+ result = simplify(text, tokenizer, model)
42
+ st.session_state['result'] = result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
+ with col2:
45
+ if 'result' in st.session_state:
46
+ st.text_area("Simplified", st.session_state['result'], height=300)