msmaje committed on
Commit
4e7455f
·
verified ·
1 Parent(s): 1acf8f1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +469 -0
app.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import pandas as pd
4
+ import os
5
+ import tempfile
6
+ import time
7
+ import subprocess
8
+ from huggingface_hub import login, HfApi
9
+ from transformers import AutoTokenizer, BertForSequenceClassification
10
+ from datasets import load_dataset
11
+
12
# Module-level state shared by the callbacks defined in this file.

# Directory passed to the training script as its output directory and used
# as the default model location.
MODEL_PATH = "local-model"

# Class labels; a label's position in this list is its class index.
CATEGORIES = ["Online-Safety", "BroadBand", "TV-Radio"]
idx_to_category = dict(enumerate(CATEGORIES))

TOKEN = None  # last token handed to login_to_hf
TRAINING_LOGS = []  # log lines of the most recent training run
CURRENT_MODEL = None  # set by load_model
CURRENT_TOKENIZER = None  # set by load_model (and at startup)
20
+
21
def login_to_hf(token):
    """Authenticate with the Hugging Face Hub and report the outcome.

    The token is stored in the module-level ``TOKEN`` before the login
    attempt, so it is remembered even when authentication fails.

    Returns:
        A status line starting with ✅ on success, or with ❌ followed by
        the error text on failure. Exceptions are never propagated.
    """
    global TOKEN
    TOKEN = token

    try:
        login(token)
    except Exception as exc:
        return f"❌ Login failed: {exc}"
    else:
        return "✅ Successfully logged in to Hugging Face"
30
+
31
def validate_hub_model_id(username, model_name):
    """Validate the inputs and build a ``username/model-name`` Hub ID.

    The model name is lower-cased, spaces become hyphens, and every
    character that is not alphanumeric, ``-`` or ``_`` is dropped.

    Args:
        username: Hugging Face account or organisation name.
        model_name: Free-form model name entered by the user.

    Returns:
        ``(hub_model_id, None)`` on success, or ``(None, error_message)``
        when either value is missing or the name sanitises to nothing usable.
    """
    # Strip first so whitespace-only input counts as missing instead of
    # producing an ID such as "user/".
    username = (username or "").strip()
    model_name = (model_name or "").strip()
    if not username or not model_name:
        return None, "Please provide both username and model name"

    # Clean up the model name
    slug = model_name.lower().replace(" ", "-")
    slug = "".join(ch for ch in slug if ch.isalnum() or ch in "-_")

    # A name made only of punctuation (e.g. "!!!" or "---") is not usable.
    if not any(ch.isalnum() for ch in slug):
        return None, "Model name must contain at least one letter or digit"

    return f"{username}/{slug}", None
44
+
45
def load_model(model_path):
    """Load a tokenizer/classifier pair into the module-level globals.

    ``model_path`` is tried as a local path when it exists on disk, and as
    a Hugging Face Hub ID otherwise. If the Hub attempt fails, the plain
    ``bert-base-uncased`` checkpoint is loaded instead.

    Returns:
        A status line: ✅ when the requested model was loaded, ⚠️ when the
        base model was substituted, ❌ when nothing could be loaded.
    """

    def _load_from(source):
        # Tokenizer first, then model, so a failure part-way leaves the
        # globals in the same state the two inline assignments would.
        global CURRENT_MODEL, CURRENT_TOKENIZER
        CURRENT_TOKENIZER = AutoTokenizer.from_pretrained(source)
        CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
            source,
            num_labels=len(CATEGORIES),
        )

    try:
        # Local checkpoints take precedence over the Hub.
        if os.path.exists(model_path):
            _load_from(model_path)
            return f"✅ Model loaded from local path: {model_path}"

        try:
            _load_from(model_path)
        except Exception as hub_error:
            # Neither a local path nor a loadable Hub ID: use the base model.
            _load_from("bert-base-uncased")
            return f"⚠️ Failed to load specified model, using base BERT model instead. Error: {hub_error}"

        return f"✅ Model loaded from Hugging Face Hub: {model_path}"
    except Exception as exc:
        return f"❌ Failed to load model: {exc}"
78
+
79
def predict_text(text, model_path):
    """Classify a single complaint and return a human-readable result.

    Args:
        text: The complaint to classify.
        model_path: Local path or Hugging Face Hub ID of the model to use.

    Returns:
        A string with the complaint and its predicted category. A note is
        appended when the text exceeded 512 tokens, and the loader's ⚠️
        message is appended when the base model had to be substituted.
        Every failure is returned as a string starting with ❌.
    """
    if not text or not text.strip():
        return "❌ Please enter a complaint to classify"

    # Reload only when the requested model differs from the one in memory.
    # Comparing against the MODEL_PATH constant would reload on every call
    # for other paths and could silently reuse a previously loaded model.
    # NOTE(review): relies on transformers exposing `name_or_path` on models
    # returned by from_pretrained - confirm for the pinned version.
    load_note = ""
    loaded_from = getattr(CURRENT_MODEL, "name_or_path", None)
    if CURRENT_MODEL is None or loaded_from != model_path:
        load_result = load_model(model_path)
        if load_result.startswith("❌"):
            return load_result
        if load_result.startswith("⚠️"):
            # Surface the fallback so the user knows which model answered.
            load_note = f"\n\n{load_result}"

    max_tokens = 512
    try:
        # Tokenize input
        inputs = CURRENT_TOKENIZER(
            text, return_tensors="pt", truncation=True, max_length=max_tokens
        )

        # Make prediction
        with torch.no_grad():
            outputs = CURRENT_MODEL(**inputs)
        predicted_idx = outputs.logits.argmax(dim=-1).item()
        predicted_category = idx_to_category[predicted_idx]

        # An untruncated second pass is the only way to learn the real length.
        full_ids = CURRENT_TOKENIZER(text, truncation=False)["input_ids"]
        truncation_warning = (
            "\n\n⚠️ Note: This complaint was truncated to fit BERT's 512 token limit."
            if len(full_ids) > max_tokens
            else ""
        )

        return (
            f"Complaint: {text}\n\nPredicted Category: {predicted_category}"
            f"{truncation_warning}{load_note}"
        )
    except Exception as e:
        return f"❌ Prediction failed: {e}"
109
+
110
def predict_csv(csv_file, model_path):
    """Classify the complaints in a CSV file and return a text report.

    Only the first 20 rows are classified and listed; a trailing line
    reports how many rows were left out.

    Args:
        csv_file: A path, a file-like object, or an upload object whose
            ``name`` attribute is the path of the CSV file.
        model_path: Local path or Hugging Face Hub ID of the model to use.

    Returns:
        The report as one string, or a string starting with ❌ on failure.
    """
    if csv_file is None:
        return "❌ Please upload a CSV file"

    # Validate the file before loading the model so bad input fails fast.
    try:
        csv_source = csv_file.name if hasattr(csv_file, "name") else csv_file
        df = pd.read_csv(csv_source)
    except Exception as e:
        return f"❌ CSV processing failed: {e}"

    if "complaint" not in df.columns:
        return "❌ CSV file must have a 'complaint' column"
    if df.empty:
        return "❌ CSV file has no rows to classify"

    # Reload only when the requested model differs from the one in memory.
    # Comparing against the MODEL_PATH constant could reuse a stale model.
    # NOTE(review): relies on transformers exposing `name_or_path` on models
    # returned by from_pretrained - confirm for the pinned version.
    results = []
    loaded_from = getattr(CURRENT_MODEL, "name_or_path", None)
    if CURRENT_MODEL is None or loaded_from != model_path:
        load_result = load_model(model_path)
        if load_result.startswith("❌"):
            return load_result
        if load_result.startswith("⚠️"):
            # Surface the fallback so the user knows which model answered.
            results.append(f"{load_result}\n")

    max_tokens = 512
    preview_limit = 20
    total = len(df)

    try:
        truncated_count = 0

        for number, raw in enumerate(df["complaint"].head(preview_limit), start=1):
            complaint = str(raw)

            # Check for truncation
            full_ids = CURRENT_TOKENIZER(complaint, truncation=False)["input_ids"]
            was_truncated = len(full_ids) > max_tokens
            if was_truncated:
                truncated_count += 1

            # Predict
            inputs = CURRENT_TOKENIZER(
                complaint, return_tensors="pt", truncation=True, max_length=max_tokens
            )
            with torch.no_grad():
                outputs = CURRENT_MODEL(**inputs)
            predicted_category = idx_to_category[outputs.logits.argmax(dim=-1).item()]

            truncation_mark = " ⚠️" if was_truncated else ""
            preview = complaint if len(complaint) <= 50 else complaint[:47] + "..."
            results.append(
                f"Complaint {number}{truncation_mark}: {preview}\n"
                f"Predicted Category: {predicted_category}\n"
            )

        # Mention omitted rows only when some exist (no "... and 0 more").
        if total > preview_limit:
            results.append(
                f"... and {total - preview_limit} more "
                f"(showing first {preview_limit} out of {total} complaints)"
            )

        if truncated_count > 0:
            results.append(
                f"\n⚠️ {truncated_count} complaints were truncated to fit BERT's 512 token limit."
            )

        return "\n".join(results)
    except Exception as e:
        return f"❌ CSV processing failed: {e}"
164
+
165
def train_model(dataset_name, num_epochs, batch_size, learning_rate, hf_token,
                push_to_hub, username, model_name):
    """Run ``bert_finetune.py`` in a subprocess and stream its log.

    This is a generator: each yielded value is the complete log so far,
    joined with newlines, so a UI textbox can be refreshed on every line.

    Args:
        dataset_name: Hugging Face dataset ID to train on.
        num_epochs: Passed as ``--num_train_epochs``.
        batch_size: Passed as ``--batch_size``.
        learning_rate: Passed as ``--learning_rate``.
        hf_token: Optional Hub token; when given, a login is attempted first.
        push_to_hub: Whether the training script should push the result.
        username: Hub username, used only when ``push_to_hub`` is true.
        model_name: Hub model name, used only when ``push_to_hub`` is true.

    Yields:
        The accumulated log text after every new log line.
    """
    # Local import: keeps this function self-contained without touching the
    # file-level import block.
    import sys

    global TRAINING_LOGS

    TRAINING_LOGS = []  # Reset logs at the start of training

    # Reject a missing dataset here rather than letting the subprocess fail
    # later with a less helpful message.
    if not dataset_name or not dataset_name.strip():
        TRAINING_LOGS.append("❌ Please provide a dataset name")
        yield "\n".join(TRAINING_LOGS)
        return
    dataset_name = dataset_name.strip()

    if hf_token:
        TRAINING_LOGS.append(login_to_hf(hf_token))
        yield "\n".join(TRAINING_LOGS)

    # Validate hub model ID if pushing to hub
    hub_model_id = None
    if push_to_hub:
        hub_model_id, error = validate_hub_model_id(username, model_name)
        if error:
            TRAINING_LOGS.append(f"❌ {error}")
            yield "\n".join(TRAINING_LOGS)
            return

    # Use the interpreter running this app so the child process sees the
    # same environment and installed packages.
    cmd = [
        sys.executable or "python", "bert_finetune.py",
        "--dataset_name", dataset_name,
        "--model_id", "bert-base-uncased",
        "--output_dir", MODEL_PATH,
        "--feature_column", "complaint",
        "--label_column", "label_idx",
        "--num_labels", "3",
        "--num_train_epochs", str(num_epochs),
        "--batch_size", str(batch_size),
        "--learning_rate", str(learning_rate),
        "--max_length", "512",
    ]

    if push_to_hub and hub_model_id:
        cmd.extend(["--push_to_hub", "--hub_model_id", hub_model_id])
        if hf_token:
            # NOTE(review): a token on the command line is visible in the
            # process list; consider passing it through the environment if
            # bert_finetune.py supports that.
            cmd.extend(["--hf_token", hf_token])

    # Never echo the secret token into the user-visible log.
    shown_cmd = ["***" if hf_token and part == hf_token else part for part in cmd]
    TRAINING_LOGS.append(f"Starting training with command: {' '.join(shown_cmd)}")
    yield "\n".join(TRAINING_LOGS)

    try:
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # interleave errors with normal output
            universal_newlines=True,
            bufsize=1,  # line-buffered so progress appears as it is written
        )

        TRAINING_LOGS.append("Training started...")
        yield "\n".join(TRAINING_LOGS)

        # Iteration ends at EOF, i.e. once the child has closed its stdout.
        for line in process.stdout:
            TRAINING_LOGS.append(line.strip())
            yield "\n".join(TRAINING_LOGS)

        process.stdout.close()
        process.wait()

        if process.returncode == 0:
            TRAINING_LOGS.append("✅ Training completed successfully!")
            if push_to_hub and hub_model_id:
                TRAINING_LOGS.append(f"✅ Model pushed to Hugging Face Hub: {hub_model_id}")

            # Load the trained model
            TRAINING_LOGS.append("Loading trained model...")
            TRAINING_LOGS.append(load_model(MODEL_PATH))

            # Final success message
            TRAINING_LOGS.append("\n✨ All done! Your model is ready to use.")
        else:
            TRAINING_LOGS.append(f"❌ Training failed with return code {process.returncode}")

    except Exception as e:
        TRAINING_LOGS.append(f"❌ Error during training: {e}")

    yield "\n".join(TRAINING_LOGS)
251
+
252
def push_to_hub_after_training(model_path, username, model_name, token):
    """Upload an already-trained local model and tokenizer to the Hub.

    All inputs are checked before any login is attempted, so a missing
    model directory or an invalid name does not trigger a login.

    Args:
        model_path: Local directory holding the trained model.
        username: Hub username that will own the repository.
        model_name: Repository name; sanitised by ``validate_hub_model_id``.
        token: Hugging Face access token.

    Returns:
        A status line starting with ✅ on success or ❌ on any failure.
    """
    # Pasted tokens often carry stray whitespace; treat blank as missing.
    token = (token or "").strip()
    if not token:
        return "❌ Please provide a Hugging Face token"

    try:
        hub_model_id, error = validate_hub_model_id(username, model_name)
        if error:
            return f"❌ {error}"

        if not os.path.exists(model_path):
            return "❌ No trained model found. Please train a model first."

        # Log in only once we know there is something to push.
        login(token)

        try:
            model = BertForSequenceClassification.from_pretrained(model_path)
            tokenizer = AutoTokenizer.from_pretrained(model_path)
        except Exception as e:
            return f"❌ Failed to load model: {e}"

        # Push to Hub
        try:
            model.push_to_hub(hub_model_id)
            tokenizer.push_to_hub(hub_model_id)
        except Exception as e:
            return f"❌ Failed to push to Hub: {e}"

        return f"✅ Successfully pushed model to {hub_model_id}"

    except Exception as e:
        # Catch-all for login and validation failures.
        return f"❌ Error: {e}"
283
+
284
+ # Create the Gradio Interface
285
# Two-tab UI: "Train Model" drives train_model / push_to_hub_after_training,
# "Classify Complaints" drives predict_text / predict_csv.
# NOTE(review): the original indentation was lost; the layout nesting below
# (what sits inside the accordion, groups and columns) is reconstructed from
# the component names and should be checked against the intended layout.
with gr.Blocks(title="BERT Complaint Classifier") as app:
    gr.Markdown("# BERT Complaint Category Classifier")
    gr.Markdown("A simple tool to train and use a BERT model for classifying customer complaints")

    # Server-side value for the manual push. Unlike a hidden textbox, the
    # browser cannot edit it.
    trained_model_dir = gr.State(MODEL_PATH)

    with gr.Tabs():
        # Training Tab
        with gr.TabItem("Train Model"):
            gr.Markdown("### Train a New Model")
            gr.Markdown("Provide your dataset information and training parameters")

            dataset_name = gr.Textbox(
                label="Dataset Name (from Hugging Face)",
                placeholder="e.g., your-username/complaint-categories-dataset"
            )

            with gr.Row():
                num_epochs = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of Epochs")
                batch_size = gr.Slider(minimum=4, maximum=32, value=8, step=4, label="Batch Size")
                learning_rate = gr.Slider(minimum=1e-5, maximum=5e-5, value=2e-5, step=1e-5, label="Learning Rate")

            with gr.Accordion("Hugging Face Hub Settings", open=False):
                hf_token = gr.Textbox(
                    label="Hugging Face Token (required for pushing to Hub)",
                    type="password"
                )

                gr.Markdown(
                    "### Choose when to push to Hub:\n"
                    "1. During Training: Model will be pushed automatically when training completes\n"
                    "2. After Training: You can push the trained model manually later"
                )

                # During Training Push
                with gr.Group():
                    push_to_hub = gr.Checkbox(
                        label="Push Model to Hub during training",
                        value=False
                    )

                    with gr.Column(visible=False) as hub_settings:
                        username = gr.Textbox(
                            label="Hugging Face Username",
                            placeholder="e.g., huggingface-username"
                        )
                        model_name = gr.Textbox(
                            label="Model Name",
                            placeholder="e.g., bert-complaint-classifier"
                        )

                # Post-Training Push
                with gr.Group():
                    post_train_push = gr.Checkbox(
                        label="Push trained model to Hub after training",
                        value=False
                    )

                    with gr.Column(visible=False) as post_train_settings:
                        post_train_username = gr.Textbox(
                            label="Hugging Face Username",
                            placeholder="e.g., huggingface-username"
                        )
                        post_train_model_name = gr.Textbox(
                            label="Model Name",
                            placeholder="e.g., bert-complaint-classifier"
                        )
                        post_train_token = gr.Textbox(
                            label="Hugging Face Token (if different from above)",
                            type="password"
                        )
                        post_train_push_btn = gr.Button(
                            "Push Model to Hub",
                            variant="secondary"
                        )
                        post_train_status = gr.Textbox(label="Upload Status")

                # Show/hide settings based on checkboxes
                push_to_hub.change(
                    lambda x: gr.update(visible=x),
                    inputs=push_to_hub,
                    outputs=hub_settings
                )

                post_train_push.change(
                    lambda x: gr.update(visible=x),
                    inputs=post_train_push,
                    outputs=post_train_settings
                )

            gr.Markdown("### BERT Model Note")
            gr.Markdown("⚠️ BERT has a maximum sequence length of 512 tokens. Complaints longer than this will be truncated.")

            train_btn = gr.Button("Start Training", variant="primary")
            training_output = gr.Textbox(label="Training Progress", lines=10)

            # Connect the buttons
            post_train_push_btn.click(
                push_to_hub_after_training,
                inputs=[
                    trained_model_dir,
                    post_train_username,
                    post_train_model_name,
                    post_train_token
                ],
                outputs=post_train_status
            )

            # train_model is a generator, so the output textbox is refreshed
            # on every value it yields.
            train_btn.click(
                train_model,
                inputs=[
                    dataset_name,
                    num_epochs,
                    batch_size,
                    learning_rate,
                    hf_token,
                    push_to_hub,
                    username,
                    model_name
                ],
                outputs=training_output,
                show_progress="full"
            )

        # Classification Tab
        with gr.TabItem("Classify Complaints"):
            gr.Markdown("### Classify Customer Complaints")

            # Default to the training output directory via the shared constant.
            model_path = gr.Textbox(
                label="Model Path or Hugging Face ID",
                value=MODEL_PATH,
                placeholder="e.g., local-model or your-username/bert-complaint-classifier"
            )

            with gr.Tabs():
                # Single Complaint Classification
                with gr.TabItem("Single Complaint"):
                    text_input = gr.Textbox(
                        label="Complaint Text",
                        lines=5,
                        placeholder="Enter a customer complaint here..."
                    )

                    classify_btn = gr.Button("Classify", variant="primary")
                    token_info = gr.Markdown("Note: BERT has a 512 token limit. Longer complaints will be truncated.")
                    text_output = gr.Textbox(label="Classification Result", lines=5)

                    # Token counter
                    def count_tokens(text):
                        """Return a Markdown line with the token count of ``text``."""
                        if not text:
                            return "Enter text to see token count"
                        if CURRENT_TOKENIZER is None:
                            # Say why no count is shown instead of asking
                            # for text that was already entered.
                            return "Token count unavailable: no tokenizer is loaded yet"
                        tokens = CURRENT_TOKENIZER(text, truncation=False)
                        count = len(tokens['input_ids'])
                        if count > 512:
                            return f"⚠️ **Token count: {count}/512** - Text will be truncated for BERT"
                        return f"Token count: {count}/512"

                    text_input.change(
                        fn=count_tokens,
                        inputs=text_input,
                        outputs=token_info
                    )

                    classify_btn.click(
                        predict_text,
                        inputs=[text_input, model_path],
                        outputs=text_output
                    )

                # Batch Processing
                with gr.TabItem("Batch Processing"):
                    gr.Markdown("Upload a CSV file with a 'complaint' column")
                    csv_input = gr.File(label="Upload CSV", file_types=[".csv"])
                    batch_classify_btn = gr.Button("Classify All", variant="primary")
                    csv_output = gr.Textbox(label="Classification Results", lines=15)

                    batch_classify_btn.click(
                        predict_csv,
                        inputs=[csv_input, model_path],
                        outputs=csv_output
                    )
463
+
464
+ # Launch the app
465
if __name__ == "__main__":
    # Preload a tokenizer so the token counter works before any model has
    # been loaded. This is best-effort: a failed load must not stop the app
    # from starting, because count_tokens tolerates a missing tokenizer and
    # load_model sets one when a model is loaded.
    if CURRENT_TOKENIZER is None:
        try:
            CURRENT_TOKENIZER = AutoTokenizer.from_pretrained("bert-base-uncased")
        except Exception as exc:
            print(f"⚠️ Could not preload the bert-base-uncased tokenizer: {exc}")
    app.launch()