Aditi132 committed on
Commit 3f9ad47 · verified · 1 Parent(s): 1b5fd52

Upload 2 files

Files changed (2)
  1. app.py +143 -0
  2. train.py +105 -0
app.py ADDED
@@ -0,0 +1,143 @@
+ import gradio as gr
+ import torch
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
+
+ MODEL_NAME = "t5-small"
+
+ print("Loading model...")
+ tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, legacy=False)
+ model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+ model.eval()
+ print(f"Model loaded on {device}!")
+
+ def simplify_legal_text(legal_text, max_length=512, num_beams=4):
+     if not legal_text or not legal_text.strip():
+         return "Please enter some legal text to simplify."
+
+     if len(legal_text) > 5000:
+         return "Text too long! Please keep input under 5,000 characters."
+
+     try:
+         input_text = f"summarize: {legal_text}"
+
+         # ✅ FIXED: Use tokenizer as callable
+         encoded = tokenizer(
+             input_text,
+             max_length=1024,
+             truncation=True,
+             return_tensors="pt"
+         )
+         inputs = encoded.input_ids.to(device)
+
+         with torch.no_grad():
+             outputs = model.generate(
+                 inputs,
+                 max_length=int(max_length),  # Gradio sliders may deliver floats; generate() expects ints
+                 num_beams=int(num_beams),
+                 early_stopping=True,
+                 do_sample=False,
+                 repetition_penalty=2.5,
+                 length_penalty=1.0,
+                 no_repeat_ngram_size=3
+             )
+
+         simplified_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return simplified_text
+
+     except Exception as e:
+         return f"Error: {str(e)}. Please try again with shorter text."
+
+
+ # Create Gradio interface
+ with gr.Blocks(title="Legal Text Simplifier", theme=gr.themes.Soft()) as demo:
+     gr.Markdown(
+         """
+         # ⚖️ Legal Text Simplifier
+
+         Transform complex legal language into simple, easy-to-understand text.
+
+         **How to use:**
+         1. Paste your legal text in the input box
+         2. Adjust settings if needed (optional)
+         3. Click "Simplify" to get your simplified version
+
+         **Tips:**
+         - Works best with paragraphs or short documents
+         - For very long texts, break them into smaller sections
+         - The model uses AI to preserve meaning while simplifying language
+         """
+     )
+
+     with gr.Row():
+         with gr.Column(scale=2):
+             legal_input = gr.Textbox(
+                 label="📝 Legal Text (Paste your complex legal text here)",
+                 placeholder="Enter legal text to simplify...",
+                 lines=10,
+                 value="The party of the first part hereby agrees to indemnify and hold harmless the party of the second part from any and all claims, damages, losses, costs, and expenses..."
+             )
+
+             with gr.Row():
+                 simplify_btn = gr.Button("✨ Simplify Text", variant="primary", size="lg")
+                 clear_btn = gr.Button("🗑️ Clear", size="lg")
+
+         with gr.Column(scale=1):
+             gr.Markdown("### ⚙️ Advanced Settings")
+             max_length = gr.Slider(
+                 minimum=100,
+                 maximum=1000,
+                 value=512,
+                 step=50,
+                 label="Max Output Length",
+                 info="Longer = more detailed, but may be slower"
+             )
+             num_beams = gr.Slider(
+                 minimum=1,
+                 maximum=8,
+                 value=4,
+                 step=1,
+                 label="Quality (Beam Search)",
+                 info="Higher = better quality, slower generation"
+             )
+
+     simplified_output = gr.Textbox(
+         label="✨ Simplified Text",
+         lines=10,
+         interactive=False,
+         placeholder="Your simplified text will appear here..."
+     )
+
+     gr.Markdown(
+         """
+         ---
+         ### 💡 Example
+
+         **Input:** "The party of the first part hereby agrees to indemnify and hold harmless..."
+
+         **Output:** "The first party agrees to protect the second party from any claims or losses..."
+
+         ---
+         *Powered by T5 Transformer Model | Deployed for free on Hugging Face Spaces*
+         """
+     )
+
+     # Connect the function to the interface
+     simplify_btn.click(
+         fn=simplify_legal_text,
+         inputs=[legal_input, max_length, num_beams],
+         outputs=simplified_output
+     )
+
+     # Clearing returns empty strings for both the input and output boxes
+     clear_btn.click(
+         fn=lambda: ("", ""),
+         outputs=[legal_input, simplified_output]
+     )
+
+ if __name__ == "__main__":
+     # ✅ FIXED: No server_name/port for Spaces compatibility
+     demo.launch()
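
Note: app.py loads the base t5-small checkpoint, while train.py below saves its fine-tuned weights to ./model_output. If the Space is meant to serve the fine-tuned model, a minimal sketch of the change, assuming the model_output directory is uploaded to the repo alongside app.py:

    # Hypothetical: point the app at the checkpoint written by train.py
    # instead of the base t5-small weights.
    MODEL_NAME = "./model_output"  # directory produced by trainer.save_model()
    tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
    model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)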
train.py ADDED
@@ -0,0 +1,105 @@
+ import os
+ import torch
+ from datasets import load_dataset
+ from transformers import (
+     T5ForConditionalGeneration,
+     T5Tokenizer,
+     Seq2SeqTrainingArguments,
+     Seq2SeqTrainer,
+     DataCollatorForSeq2Seq
+ )
+
+ # --- Configuration ---
+ MODEL_NAME = "t5-small"
+ OUTPUT_DIR = "./model_output"
+ MAX_INPUT_LENGTH = 1024
+ MAX_TARGET_LENGTH = 128
+ # Batch size can be raised slightly on GPU, but watch memory usage
+ BATCH_SIZE = 8
+ EPOCHS = 3
+
+ def main():
+     # Check for GPU
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     print(f"Using device: {device}")
+
+     if device == "cuda":
+         print(f"GPU Name: {torch.cuda.get_device_name(0)}")
+         print(f"Memory Allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
+     else:
+         print("WARNING: No GPU detected. Training will be slow on CPU.")
+
+     print(f"Loading model: {MODEL_NAME}...")
+     try:
+         tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, legacy=False)
+         model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
+         model.to(device)  # Move model to GPU immediately
+     except Exception as e:
+         print(f"Error loading model: {e}")
+         return
+
+     # --- Load Dataset ---
+     print("Loading 'billsum' dataset...")
+     # Using the small 'ca_test' split for a quick training cycle
+     dataset = load_dataset("billsum", split="ca_test")
+
+     # Split the ~1,200 ca_test examples 90/10 into train/eval
+     dataset = dataset.train_test_split(test_size=0.1)
+     train_dataset = dataset["train"]  # ~1,000 examples
+     eval_dataset = dataset["test"]    # ~100 examples
+
+     print(f"Training on {len(train_dataset)} examples...")
+
+     # --- Preprocessing ---
+     prefix = "summarize: "
+
+     def preprocess_function(examples):
+         inputs = [prefix + doc for doc in examples["text"]]
+         model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, truncation=True)
+
+         labels = tokenizer(text_target=examples["summary"], max_length=MAX_TARGET_LENGTH, truncation=True)
+         model_inputs["labels"] = labels["input_ids"]
+         return model_inputs
+
+     print("Tokenizing data...")
+     tokenized_train = train_dataset.map(preprocess_function, batched=True)
+     tokenized_eval = eval_dataset.map(preprocess_function, batched=True)
+
+     data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
+
+     # --- Training Args ---
+     training_args = Seq2SeqTrainingArguments(
+         output_dir=OUTPUT_DIR,
+         eval_strategy="epoch",  # ✅ Correct for transformers >= 4.40
+         learning_rate=2e-5,
+         per_device_train_batch_size=BATCH_SIZE,
+         per_device_eval_batch_size=BATCH_SIZE,
+         weight_decay=0.01,
+         save_total_limit=1,
+         num_train_epochs=EPOCHS,
+         predict_with_generate=True,
+         fp16=(device == "cuda"),  # Mixed precision on GPU
+         dataloader_num_workers=0,  # Safe for Windows
+         logging_steps=10,
+     )
+
+     trainer = Seq2SeqTrainer(
+         model=model,
+         args=training_args,
+         train_dataset=tokenized_train,
+         eval_dataset=tokenized_eval,
+         tokenizer=tokenizer,
+         data_collator=data_collator,
+     )
+
+     print("Starting training...")
+     trainer.train()
+
+     print("Saving model...")
+     trainer.save_model(OUTPUT_DIR)
+     tokenizer.save_pretrained(OUTPUT_DIR)
+     print(f"Model saved to {OUTPUT_DIR}")
+
+ if __name__ == "__main__":
+     main()
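
Once training finishes, a quick smoke test of the saved checkpoint can be run in the same environment. This is a minimal sketch; the input sentence is illustrative, not taken from the billsum data:

    # Hypothetical smoke test for the checkpoint saved to ./model_output
    import torch
    from transformers import T5ForConditionalGeneration, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("./model_output")
    model = T5ForConditionalGeneration.from_pretrained("./model_output")
    model.eval()

    text = ("summarize: The party of the first part hereby agrees to "
            "indemnify and hold harmless the party of the second part...")
    inputs = tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
    with torch.no_grad():
        output_ids = model.generate(inputs.input_ids, max_length=128, num_beams=4)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))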