sairika committed on
Commit
a8bef18
·
verified ·
1 Parent(s): 8a9a1e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +245 -55
app.py CHANGED
@@ -4,76 +4,266 @@ from peft import AutoPeftModelForSeq2SeqLM
4
  import torch
5
 
6
  # Define the Hugging Face repository ID
7
- repo_id = "sairika/FLAN-T5-Base-dialogsum-lora" # Replace 'your-username' with your Hugging Face username
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  # Load the tokenizer and the PEFT model
10
  try:
11
  tokenizer = AutoTokenizer.from_pretrained(repo_id)
12
- model = AutoPeftModelForSeq2SeqLM.from_pretrained(repo_id, device_map="auto", torch_dtype=torch.bfloat16)
 
 
 
 
13
  print(f"βœ… Model and tokenizer loaded successfully from {repo_id}")
14
-
15
- # Define the create_prompts function (copied from the notebook)
16
  def create_prompts(dialogues, model_type):
17
  """Create appropriate prompts based on model type"""
18
  if model_type in ['flan-t5', 't5']:
19
- # T5-style models work better with explicit instructions
20
  prompts = [f"Summarize the following conversation.\n\n{dialogue}\n\nSummary: "
21
  for dialogue in dialogues]
22
  else:
23
- # BART-style models can work with direct input
24
  prompts = dialogues
25
  return prompts
26
-
27
- # Define the summary generation function
28
- def generate_summary(dialogue):
29
- """Generates a summary for a given dialogue"""
30
- model_type = 'flan-t5' if 'flan-t5' in repo_id else 't5' if 't5' in repo_id else 'bart'
31
- prompt = create_prompts([dialogue], model_type)[0]
32
-
33
- inputs = tokenizer(
34
- prompt,
35
- return_tensors="pt",
36
- padding=True,
37
- truncation=True,
38
- max_length=512
39
- ).to(model.device)
40
-
41
- model.eval()
42
- with torch.no_grad():
43
- outputs = model.generate(
44
- **inputs,
45
- max_new_tokens=128,
46
- num_beams=4,
47
- length_penalty=0.6,
48
- early_stopping=True,
49
- do_sample=False
50
- )
51
-
52
- summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
53
- return summary
54
-
55
- # Create Gradio interface
56
- iface = gr.Interface(
57
- fn=generate_summary,
58
- inputs=gr.Textbox(label="Enter Dialogue:"),
59
- outputs=gr.Textbox(label="Generated Summary:"),
60
- title="Dialogue Summarization with Fine-tuned LoRA Model",
61
- description=f"Summarize conversations using the fine-tuned model from {repo_id}"
62
- )
63
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  # Launch the interface
65
- iface.launch()
 
 
 
 
 
 
66
 
67
  except Exception as e:
68
  print(f"❌ Error loading model or setting up Gradio: {e}")
69
- # Provide a fallback or error message in the Gradio interface if possible
 
70
  def error_message(dialogue):
71
- return f"Error: Could not load the model or generate summary. Details: {e}"
72
- error_iface = gr.Interface(
73
- fn=error_message,
74
- inputs=gr.Textbox(label="Enter Dialogue:"),
75
- outputs=gr.Textbox(label="Error:"),
76
- title="Error Loading Model",
77
- description=f"There was an error loading the model from {repo_id}. Please check the logs."
78
- )
79
- error_iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import torch
5
 
6
  # Define the Hugging Face repository ID
7
+ repo_id = "sairika/FLAN-T5-Base-dialogsum-lora"
8
+
9
+ # Sample dialogues for examples
10
+ SAMPLE_DIALOGUES = [
11
+ """Alice: Hi Bob, how was your meeting with the client today?
12
+ Bob: It went really well! They loved our proposal and want to move forward.
13
+ Alice: That's fantastic news! What's the next step?
14
+ Bob: We need to prepare the contract and schedule a follow-up meeting next week.
15
+ Alice: Great, I'll help you with the contract preparation.""",
16
+
17
+ """Customer: I'm having trouble with my internet connection. It keeps dropping out.
18
+ Support: I'm sorry to hear that. Let me help you troubleshoot this issue.
19
+ Customer: I've already tried restarting my router.
20
+ Support: Okay, let's check your signal strength. Can you tell me what lights are showing on your modem?
21
+ Customer: There are green lights for power and internet, but the wifi light is blinking orange.
22
+ Support: That indicates a wifi connectivity issue. Let's reset your wifi settings.""",
23
+
24
+ """Mom: Have you finished your homework yet?
25
+ Child: Almost done! I just have math left.
26
+ Mom: Do you need any help with it?
27
+ Child: Actually yes, I'm stuck on these algebra problems.
28
+ Mom: Let me take a look. Oh, these are quadratic equations. Remember the formula we practiced?
29
+ Child: Oh right! axΒ² + bx + c = 0. Thanks mom!"""
30
+ ]
31
 
32
  # Load the tokenizer and the PEFT model
33
  try:
34
  tokenizer = AutoTokenizer.from_pretrained(repo_id)
35
+ model = AutoPeftModelForSeq2SeqLM.from_pretrained(
36
+ repo_id,
37
+ device_map="auto",
38
+ torch_dtype=torch.bfloat16
39
+ )
40
  print(f"βœ… Model and tokenizer loaded successfully from {repo_id}")
41
+
 
42
def create_prompts(dialogues, model_type):
    """Build model-appropriate summarization prompts.

    T5-family models ('flan-t5', 't5') receive an explicit instruction
    wrapper; any other model type (e.g. BART) gets the raw dialogues
    passed through unchanged.
    """
    if model_type not in ('flan-t5', 't5'):
        # BART-style models summarize the raw input directly.
        return dialogues
    template = "Summarize the following conversation.\n\n{}\n\nSummary: "
    return [template.format(dialogue) for dialogue in dialogues]
50
+
51
+ def generate_summary(dialogue, max_length, num_beams, length_penalty):
52
+ """Generates a summary for a given dialogue with customizable parameters"""
53
+ if not dialogue.strip():
54
+ return "⚠️ Please enter a dialogue to summarize."
55
+
56
+ try:
57
+ model_type = 'flan-t5' if 'flan-t5' in repo_id else 't5' if 't5' in repo_id else 'bart'
58
+ prompt = create_prompts([dialogue], model_type)[0]
59
+
60
+ inputs = tokenizer(
61
+ prompt,
62
+ return_tensors="pt",
63
+ padding=True,
64
+ truncation=True,
65
+ max_length=512
66
+ ).to(model.device)
67
+
68
+ model.eval()
69
+ with torch.no_grad():
70
+ outputs = model.generate(
71
+ **inputs,
72
+ max_new_tokens=int(max_length),
73
+ num_beams=int(num_beams),
74
+ length_penalty=float(length_penalty),
75
+ early_stopping=True,
76
+ do_sample=False
77
+ )
78
+
79
+ summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
80
+ return f"πŸ“ **Summary:** {summary}"
81
+
82
+ except Exception as e:
83
+ return f"❌ Error generating summary: {str(e)}"
84
+
85
+ # Custom CSS for better styling
86
+ custom_css = """
87
+ .gradio-container {
88
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
89
+ }
90
+ .header {
91
+ text-align: center;
92
+ background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
93
+ color: white;
94
+ padding: 2rem;
95
+ border-radius: 10px;
96
+ margin-bottom: 2rem;
97
+ }
98
+ .footer {
99
+ text-align: center;
100
+ margin-top: 2rem;
101
+ padding: 1rem;
102
+ background-color: #f8f9fa;
103
+ border-radius: 10px;
104
+ }
105
+ """
106
+
107
+ # Create the Gradio interface
108
+ with gr.Blocks(css=custom_css, title="Dialogue Summarization Demo") as iface:
109
+ # Header
110
+ gr.HTML("""
111
+ <div class="header">
112
+ <h1>πŸ€– AI Dialogue Summarization</h1>
113
+ <p>Transform lengthy conversations into concise, meaningful summaries using a fine-tuned FLAN-T5 model</p>
114
+ </div>
115
+ """)
116
+
117
+ with gr.Row():
118
+ with gr.Column(scale=2):
119
+ dialogue_input = gr.Textbox(
120
+ label="πŸ“ Enter Your Dialogue",
121
+ placeholder="Paste your conversation here...\n\nExample:\nPerson A: Hello, how are you?\nPerson B: I'm doing well, thanks for asking!",
122
+ lines=10,
123
+ max_lines=20
124
+ )
125
+
126
+ with gr.Row():
127
+ summarize_btn = gr.Button("✨ Generate Summary", variant="primary", size="lg")
128
+ clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary")
129
+
130
+ with gr.Column(scale=1):
131
+ gr.Markdown("### βš™οΈ Generation Parameters")
132
+ max_length = gr.Slider(
133
+ minimum=50,
134
+ maximum=256,
135
+ value=128,
136
+ step=10,
137
+ label="Max Summary Length",
138
+ info="Maximum number of tokens in the summary"
139
+ )
140
+ num_beams = gr.Slider(
141
+ minimum=1,
142
+ maximum=8,
143
+ value=4,
144
+ step=1,
145
+ label="Number of Beams",
146
+ info="Higher values = better quality but slower"
147
+ )
148
+ length_penalty = gr.Slider(
149
+ minimum=0.1,
150
+ maximum=2.0,
151
+ value=0.6,
152
+ step=0.1,
153
+ label="Length Penalty",
154
+ info="Controls summary length preference"
155
+ )
156
+
157
+ # Output
158
+ summary_output = gr.Textbox(
159
+ label="🎯 Generated Summary",
160
+ lines=5,
161
+ show_copy_button=True
162
+ )
163
+
164
+ # Examples section
165
+ gr.Markdown("### πŸ’‘ Try These Examples")
166
+ gr.Examples(
167
+ examples=SAMPLE_DIALOGUES,
168
+ inputs=dialogue_input,
169
+ label="Click on any example to load it:"
170
+ )
171
+
172
+ # Model information
173
+ gr.Markdown(f"""
174
+ ### πŸ“Š Model Information
175
+ - **Model:** {repo_id}
176
+ - **Base Architecture:** FLAN-T5-Base with LoRA fine-tuning
177
+ - **Task:** Dialogue Summarization
178
+ - **Training Dataset:** DialogSum
179
+ """)
180
+
181
+ # Footer
182
+ gr.HTML("""
183
+ <div class="footer">
184
+ <p>Built with ❀️ using Gradio and Hugging Face Transformers</p>
185
+ <p><em>This demo showcases AI-powered dialogue summarization capabilities</em></p>
186
+ </div>
187
+ """)
188
+
189
+ # Event handlers
190
+ summarize_btn.click(
191
+ fn=generate_summary,
192
+ inputs=[dialogue_input, max_length, num_beams, length_penalty],
193
+ outputs=summary_output
194
+ )
195
+
196
+ clear_btn.click(
197
+ fn=lambda: ("", ""),
198
+ outputs=[dialogue_input, summary_output]
199
+ )
200
+
201
  # Launch the interface
202
+ iface.launch(
203
+ share=True, # Creates a public link
204
+ show_error=True,
205
+ show_api=False,
206
+ favicon_path=None,
207
+ ssl_verify=False
208
+ )
209
 
210
  except Exception as e:
211
  print(f"❌ Error loading model or setting up Gradio: {e}")
212
+
213
+ # Enhanced error interface
214
def error_message(dialogue):
    """Return a static markdown error report for the fallback UI.

    Fix: the heading contained a mis-encoded character ("❌"-style
    mojibake); restored to the intended ❌ emoji.

    The *dialogue* argument is accepted only for interface compatibility
    and is ignored; the message is built from the enclosing except
    block's exception ``e`` and the module-level ``repo_id``.
    """
    return f"""
❌ **Error Loading Model**

Unfortunately, there was an error loading the model from `{repo_id}`.

**Possible causes:**
- Model repository not accessible
- Insufficient memory/resources
- Network connectivity issues
- Invalid model format

**Error details:** {str(e)}

Please check the logs and try again later.
"""
230
+
231
+ custom_css_error = """
232
+ .error-container {
233
+ background-color: #fee;
234
+ border: 2px solid #fcc;
235
+ border-radius: 10px;
236
+ padding: 2rem;
237
+ margin: 2rem 0;
238
+ }
239
+ """
240
+
241
+ with gr.Blocks(css=custom_css_error, title="Model Loading Error") as error_iface:
242
+ gr.HTML("""
243
+ <div class="error-container">
244
+ <h1>🚨 Model Loading Error</h1>
245
+ <p>There was an issue loading the dialogue summarization model.</p>
246
+ </div>
247
+ """)
248
+
249
+ dialogue_input = gr.Textbox(
250
+ label="Enter Dialogue",
251
+ placeholder="The model could not be loaded...",
252
+ interactive=False
253
+ )
254
+
255
+ error_output = gr.Textbox(
256
+ label="Error Details",
257
+ value=error_message(""),
258
+ interactive=False
259
+ )
260
+
261
+ gr.Markdown(f"""
262
+ ### πŸ”§ Troubleshooting Steps:
263
+ 1. Check if the model repository `{repo_id}` exists and is accessible
264
+ 2. Verify you have sufficient system resources (RAM/GPU)
265
+ 3. Ensure all required dependencies are installed
266
+ 4. Check network connectivity to Hugging Face Hub
267
+ """)
268
+
269
+ error_iface.launch(share=True)