sairika committed on
Commit
9126f59
·
verified ·
1 Parent(s): a8bef18

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -246
app.py CHANGED
@@ -3,267 +3,117 @@ from transformers import AutoTokenizer
3
  from peft import AutoPeftModelForSeq2SeqLM
4
  import torch
5
 
6
- # Define the Hugging Face repository ID
7
  repo_id = "sairika/FLAN-T5-Base-dialogsum-lora"
8
 
9
- # Sample dialogues for examples
10
- SAMPLE_DIALOGUES = [
11
- """Alice: Hi Bob, how was your meeting with the client today?
12
- Bob: It went really well! They loved our proposal and want to move forward.
13
- Alice: That's fantastic news! What's the next step?
14
- Bob: We need to prepare the contract and schedule a follow-up meeting next week.
15
- Alice: Great, I'll help you with the contract preparation.""",
16
 
17
- """Customer: I'm having trouble with my internet connection. It keeps dropping out.
18
- Support: I'm sorry to hear that. Let me help you troubleshoot this issue.
19
- Customer: I've already tried restarting my router.
20
- Support: Okay, let's check your signal strength. Can you tell me what lights are showing on your modem?
21
- Customer: There are green lights for power and internet, but the wifi light is blinking orange.
22
- Support: That indicates a wifi connectivity issue. Let's reset your wifi settings.""",
23
 
24
- """Mom: Have you finished your homework yet?
25
- Child: Almost done! I just have math left.
26
- Mom: Do you need any help with it?
27
- Child: Actually yes, I'm stuck on these algebra problems.
28
- Mom: Let me take a look. Oh, these are quadratic equations. Remember the formula we practiced?
29
- Child: Oh right! ax² + bx + c = 0. Thanks mom!"""
30
  ]
31
 
32
- # Load the tokenizer and the PEFT model
33
- try:
34
- tokenizer = AutoTokenizer.from_pretrained(repo_id)
35
- model = AutoPeftModelForSeq2SeqLM.from_pretrained(
36
- repo_id,
37
- device_map="auto",
38
- torch_dtype=torch.bfloat16
39
- )
40
- print(f"✅ Model and tokenizer loaded successfully from {repo_id}")
41
-
42
- def create_prompts(dialogues, model_type):
43
- """Create appropriate prompts based on model type"""
44
- if model_type in ['flan-t5', 't5']:
45
- prompts = [f"Summarize the following conversation.\n\n{dialogue}\n\nSummary: "
46
- for dialogue in dialogues]
47
- else:
48
- prompts = dialogues
49
- return prompts
50
-
51
- def generate_summary(dialogue, max_length, num_beams, length_penalty):
52
- """Generates a summary for a given dialogue with customizable parameters"""
53
- if not dialogue.strip():
54
- return "⚠️ Please enter a dialogue to summarize."
55
-
56
- try:
57
- model_type = 'flan-t5' if 'flan-t5' in repo_id else 't5' if 't5' in repo_id else 'bart'
58
- prompt = create_prompts([dialogue], model_type)[0]
59
-
60
- inputs = tokenizer(
61
- prompt,
62
- return_tensors="pt",
63
- padding=True,
64
- truncation=True,
65
- max_length=512
66
- ).to(model.device)
67
-
68
- model.eval()
69
- with torch.no_grad():
70
- outputs = model.generate(
71
- **inputs,
72
- max_new_tokens=int(max_length),
73
- num_beams=int(num_beams),
74
- length_penalty=float(length_penalty),
75
- early_stopping=True,
76
- do_sample=False
77
- )
78
-
79
- summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
80
- return f"📝 **Summary:** {summary}"
81
-
82
- except Exception as e:
83
- return f"❌ Error generating summary: {str(e)}"
84
 
85
- # Custom CSS for better styling
86
- custom_css = """
87
- .gradio-container {
88
- font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
89
- }
90
- .header {
91
- text-align: center;
92
- background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
93
- color: white;
94
- padding: 2rem;
95
- border-radius: 10px;
96
- margin-bottom: 2rem;
97
- }
98
- .footer {
99
- text-align: center;
100
- margin-top: 2rem;
101
- padding: 1rem;
102
- background-color: #f8f9fa;
103
- border-radius: 10px;
104
- }
105
- """
106
 
107
- # Create the Gradio interface
108
- with gr.Blocks(css=custom_css, title="Dialogue Summarization Demo") as iface:
109
- # Header
110
- gr.HTML("""
111
- <div class="header">
112
- <h1>🤖 AI Dialogue Summarization</h1>
113
- <p>Transform lengthy conversations into concise, meaningful summaries using a fine-tuned FLAN-T5 model</p>
114
- </div>
115
- """)
116
-
117
- with gr.Row():
118
- with gr.Column(scale=2):
119
- dialogue_input = gr.Textbox(
120
- label="📝 Enter Your Dialogue",
121
- placeholder="Paste your conversation here...\n\nExample:\nPerson A: Hello, how are you?\nPerson B: I'm doing well, thanks for asking!",
122
- lines=10,
123
- max_lines=20
124
- )
125
-
126
- with gr.Row():
127
- summarize_btn = gr.Button("✨ Generate Summary", variant="primary", size="lg")
128
- clear_btn = gr.Button("🗑️ Clear", variant="secondary")
129
-
130
- with gr.Column(scale=1):
131
- gr.Markdown("### ⚙️ Generation Parameters")
132
- max_length = gr.Slider(
133
- minimum=50,
134
- maximum=256,
135
- value=128,
136
- step=10,
137
- label="Max Summary Length",
138
- info="Maximum number of tokens in the summary"
139
- )
140
- num_beams = gr.Slider(
141
- minimum=1,
142
- maximum=8,
143
- value=4,
144
- step=1,
145
- label="Number of Beams",
146
- info="Higher values = better quality but slower"
147
- )
148
- length_penalty = gr.Slider(
149
- minimum=0.1,
150
- maximum=2.0,
151
- value=0.6,
152
- step=0.1,
153
- label="Length Penalty",
154
- info="Controls summary length preference"
155
- )
156
-
157
- # Output
158
- summary_output = gr.Textbox(
159
- label="🎯 Generated Summary",
160
- lines=5,
161
- show_copy_button=True
162
- )
163
-
164
- # Examples section
165
- gr.Markdown("### 💡 Try These Examples")
166
- gr.Examples(
167
- examples=SAMPLE_DIALOGUES,
168
- inputs=dialogue_input,
169
- label="Click on any example to load it:"
170
- )
171
 
172
- # Model information
173
- gr.Markdown(f"""
174
- ### 📊 Model Information
175
- - **Model:** {repo_id}
176
- - **Base Architecture:** FLAN-T5-Base with LoRA fine-tuning
177
- - **Task:** Dialogue Summarization
178
- - **Training Dataset:** DialogSum
179
- """)
180
 
181
- # Footer
182
- gr.HTML("""
183
- <div class="footer">
184
- <p>Built with ❤️ using Gradio and Hugging Face Transformers</p>
185
- <p><em>This demo showcases AI-powered dialogue summarization capabilities</em></p>
186
- </div>
187
- """)
 
 
 
 
188
 
189
- # Event handlers
190
- summarize_btn.click(
191
- fn=generate_summary,
192
- inputs=[dialogue_input, max_length, num_beams, length_penalty],
193
- outputs=summary_output
194
- )
195
 
196
- clear_btn.click(
197
- fn=lambda: ("", ""),
198
- outputs=[dialogue_input, summary_output]
199
- )
200
-
201
- # Launch the interface
202
- iface.launch(
203
- share=True, # Creates a public link
204
- show_error=True,
205
- show_api=False,
206
- favicon_path=None,
207
- ssl_verify=False
208
- )
209
 
210
- except Exception as e:
211
- print(f"❌ Error loading model or setting up Gradio: {e}")
 
 
 
 
 
212
 
213
- # Enhanced error interface
214
- def error_message(dialogue):
215
- return f"""
216
- ❌ **Error Loading Model**
217
-
218
- Unfortunately, there was an error loading the model from `{repo_id}`.
219
-
220
- **Possible causes:**
221
- - Model repository not accessible
222
- - Insufficient memory/resources
223
- - Network connectivity issues
224
- - Invalid model format
225
-
226
- **Error details:** {str(e)}
227
 
228
- Please check the logs and try again later.
229
- """
 
 
 
 
230
 
231
- custom_css_error = """
232
- .error-container {
233
- background-color: #fee;
234
- border: 2px solid #fcc;
235
- border-radius: 10px;
236
- padding: 2rem;
237
- margin: 2rem 0;
238
- }
239
- """
240
 
241
- with gr.Blocks(css=custom_css_error, title="Model Loading Error") as error_iface:
242
- gr.HTML("""
243
- <div class="error-container">
244
- <h1>🚨 Model Loading Error</h1>
245
- <p>There was an issue loading the dialogue summarization model.</p>
246
- </div>
247
- """)
248
-
249
- dialogue_input = gr.Textbox(
250
- label="Enter Dialogue",
251
- placeholder="The model could not be loaded...",
252
- interactive=False
253
- )
254
-
255
- error_output = gr.Textbox(
256
- label="Error Details",
257
- value=error_message(""),
258
- interactive=False
259
- )
260
-
261
- gr.Markdown(f"""
262
- ### 🔧 Troubleshooting Steps:
263
- 1. Check if the model repository `{repo_id}` exists and is accessible
264
- 2. Verify you have sufficient system resources (RAM/GPU)
265
- 3. Ensure all required dependencies are installed
266
- 4. Check network connectivity to Hugging Face Hub
267
- """)
268
 
269
- error_iface.launch(share=True)
 
 
 
 
 
 
 
 
3
  from peft import AutoPeftModelForSeq2SeqLM
4
  import torch
5
 
6
# Hugging Face Hub repository holding the LoRA-fine-tuned FLAN-T5 checkpoint.
repo_id = "sairika/FLAN-T5-Base-dialogsum-lora"

# Example conversations offered as one-click demo inputs in the UI.
# Each entry is a single string with newline-separated speaker turns.
examples = [
    (
        "Alice: How was your meeting today?\n"
        "Bob: It went great! The client approved our proposal.\n"
        "Alice: That's wonderful news!\n"
        "Bob: Yes, we start next Monday."
    ),
    (
        "Customer: My internet isn't working.\n"
        "Support: I can help with that. Have you tried restarting your router?\n"
        "Customer: Yes, but it's still not working.\n"
        "Support: Let me check your connection status."
    ),
    (
        "Teacher: Did you finish your homework?\n"
        "Student: Almost done, just the math problems left.\n"
        "Teacher: Need help with anything?\n"
        "Student: Yes, I'm stuck on question 5."
    ),
]
17
 
18
def load_model():
    """Load the tokenizer and the LoRA-adapted seq2seq model from the Hub.

    Returns:
        A ``(tokenizer, model, loaded)`` triple. On any failure the first
        two elements are ``None`` and ``loaded`` is ``False`` so the UI can
        degrade gracefully instead of crashing at import time.
    """
    try:
        tok = AutoTokenizer.from_pretrained(repo_id)
        lora_model = AutoPeftModelForSeq2SeqLM.from_pretrained(
            repo_id,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
    # Broad on purpose: any load failure (network, memory, bad repo)
    # should be reported, not propagated.
    except Exception as e:
        print(f"Error loading model: {e}")
        return None, None, False
    return tok, lora_model, True
31
+
32
def summarize_dialogue(dialogue, max_new_tokens=128, num_beams=4, length_penalty=0.6):
    """Summarize a multi-turn conversation with the fine-tuned FLAN-T5 model.

    Args:
        dialogue: Conversation text to summarize; may be ``None`` or empty.
        max_new_tokens: Upper bound on tokens generated for the summary.
        num_beams: Beam-search width (higher = better quality, slower).
        length_penalty: <1.0 favors shorter summaries, >1.0 longer ones.

    Returns:
        The generated summary string, or a human-readable message when the
        input is empty, the model failed to load, or generation errors out.
    """
    # Guard against None: Gradio can pass None when the textbox was never
    # touched, and None.strip() would raise AttributeError.
    if not dialogue or not dialogue.strip():
        return "Please enter a dialogue to summarize."

    # model_loaded is set at startup by load_model(); surface the failure
    # as a message instead of crashing the handler.
    if not model_loaded:
        return "Error: Model not loaded. Please check the logs."

    try:
        # Instruction-style prompt matching the FLAN-T5 fine-tuning format.
        prompt = f"Summarize the following conversation.\n\n{dialogue}\n\nSummary: "

        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(model.device)

        model.eval()
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=int(max_new_tokens),
                num_beams=int(num_beams),
                length_penalty=float(length_penalty),
                early_stopping=True,
                do_sample=False,
            )

        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    except Exception as e:
        return f"Error generating summary: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
71
 
72
# Load model and tokenizer once at module import; model_loaded gates the
# summarize handler.
tokenizer, model, model_loaded = load_model()

# Assemble the Gradio interface.
with gr.Blocks(title="Dialogue Summarization") as demo:
    gr.Markdown("# 🤖 Dialogue Summarization")
    gr.Markdown("Enter a conversation below to generate an AI summary using FLAN-T5.")

    with gr.Row():
        with gr.Column():
            dialogue_box = gr.Textbox(
                label="Dialogue",
                placeholder="Enter your conversation here...",
                lines=8,
            )
            generate_button = gr.Button("Generate Summary", variant="primary")
            reset_button = gr.Button("Clear")

        with gr.Column():
            summary_box = gr.Textbox(
                label="Summary",
                lines=6,
                show_copy_button=True,
            )

    # One-click sample inputs for quick demos.
    gr.Examples(
        examples=examples,
        inputs=dialogue_box,
        label="Try these examples:",
    )

    # Wire up events: generate a summary, or blank both text boxes.
    generate_button.click(
        fn=summarize_dialogue,
        inputs=dialogue_box,
        outputs=summary_box,
    )
    reset_button.click(
        fn=lambda: ("", ""),
        outputs=[dialogue_box, summary_box],
    )

# Launch only when executed as a script (HF Spaces imports `demo` directly).
if __name__ == "__main__":
    demo.launch(share=True)