msmaje committed (verified)
Commit 04f0e6e · Parent: 3f80a9a

Update app.py

Files changed (1):
  1. app.py +232 -19
app.py CHANGED
@@ -5,20 +5,166 @@ import os
 import tempfile
 import time
 import subprocess
+import json
 from huggingface_hub import login, HfApi
 from transformers import AutoTokenizer, BertForSequenceClassification
-from datasets import load_dataset
+from datasets import load_dataset, Dataset, DatasetDict
 
 # Global variables
 MODEL_PATH = "local-model"
 CATEGORIES = ['Online-Safety', 'BroadBand', 'TV-Radio']
 idx_to_category = {0: 'Online-Safety', 1: 'BroadBand', 2: 'TV-Radio'}
+category_to_idx = {'Online-Safety': 0, 'BroadBand': 1, 'TV-Radio': 2}
 TOKEN = None
 TRAINING_LOGS = []
 CURRENT_MODEL = None
 CURRENT_TOKENIZER = None
 
-def login_to_hf(token):
+# Local data files
+LOCAL_DATA_FILES = [
+    "merged-test-data.csv",
+    "test-category.csv",
+    "test-complaint.csv"
+]
+
+def get_available_datasets():
+    """Get a list of available local datasets."""
+    available_files = []
+    for file in LOCAL_DATA_FILES:
+        if os.path.exists(file):
+            try:
+                df = pd.read_csv(file)
+                available_files.append(f"{file} ({len(df)} rows)")
+            except Exception as e:
+                available_files.append(f"{file} (Error: {str(e)})")
+        else:
+            available_files.append(f"{file} (Not found)")
+
+    # Also pick up any other CSV files in the working directory
+    for file in os.listdir("."):
+        if file.endswith(".csv") and file not in LOCAL_DATA_FILES:
+            try:
+                df = pd.read_csv(file)
+                available_files.append(f"{file} ({len(df)} rows)")
+            except Exception:
+                available_files.append(f"{file} (Error reading)")
+
+    return available_files
+
+def load_and_prepare_local_dataset(file_path, text_column, label_column, test_size=0.2):
+    """Load a local CSV dataset and prepare it for training."""
+    try:
+        if not os.path.exists(file_path):
+            raise FileNotFoundError(f"Dataset file not found: {file_path}")
+
+        df = pd.read_csv(file_path)
+
+        # Verify the required columns exist
+        if text_column not in df.columns:
+            raise ValueError(f"Text column '{text_column}' not found. Available columns: {list(df.columns)}")
+        if label_column not in df.columns:
+            raise ValueError(f"Label column '{label_column}' not found. Available columns: {list(df.columns)}")
+
+        # Clean the data
+        df = df.dropna(subset=[text_column, label_column])
+        df[text_column] = df[text_column].astype(str)
+
+        # Handle both text and numeric label formats
+        if df[label_column].dtype == 'object':
+            # Labels are text: convert them to indices
+            unique_labels = df[label_column].unique()
+            if len(unique_labels) > len(CATEGORIES):
+                raise ValueError(f"Too many unique labels ({len(unique_labels)}). Expected at most {len(CATEGORIES)}")
+
+            # Map text labels to the known categories, auto-assigning unknowns
+            label_mapping = {}
+            for label in unique_labels:
+                if label in category_to_idx:
+                    label_mapping[label] = category_to_idx[label]
+                else:
+                    available_indices = set(range(len(CATEGORIES))) - set(label_mapping.values())
+                    if available_indices:
+                        label_mapping[label] = min(available_indices)
+                    else:
+                        raise ValueError(f"Cannot map label '{label}' to available categories")
+
+            df['label_idx'] = df[label_column].map(label_mapping)
+        else:
+            # Labels are already numeric
+            df['label_idx'] = df[label_column].astype(int)
+
+        # Verify label indices are in range
+        if df['label_idx'].min() < 0 or df['label_idx'].max() >= len(CATEGORIES):
+            raise ValueError(f"Label indices must be between 0 and {len(CATEGORIES) - 1}")
+
+        # Create a stratified train/validation split
+        from sklearn.model_selection import train_test_split
+        train_df, val_df = train_test_split(
+            df,
+            test_size=test_size,
+            random_state=42,
+            stratify=df['label_idx']
+        )
+
+        # Convert to Hugging Face datasets
+        train_dataset = Dataset.from_pandas(train_df[[text_column, 'label_idx']])
+        val_dataset = Dataset.from_pandas(val_df[[text_column, 'label_idx']])
+        dataset_dict = DatasetDict({
+            'train': train_dataset,
+            'validation': val_dataset
+        })
+
+        return dataset_dict, text_column, 'label_idx'
+
+    except Exception as e:
+        raise Exception(f"Error loading dataset: {str(e)}")
+
+def preview_dataset(file_path, text_column, label_column):
+    """Preview a dataset file."""
+    try:
+        if not file_path:
+            return "Please select a dataset file first."
+        if not os.path.exists(file_path):
+            return f"❌ File not found: {file_path}"
+
+        df = pd.read_csv(file_path)
+
+        preview_info = []
+        preview_info.append(f"📊 **Dataset Preview: {file_path}**")
+        preview_info.append(f"- **Total rows:** {len(df)}")
+        preview_info.append(f"- **Columns:** {list(df.columns)}")
+        preview_info.append("")
+
+        if text_column in df.columns:
+            preview_info.append(f"✅ **Text column '{text_column}' found**")
+            preview_info.append(f"- Sample text: {str(df[text_column].iloc[0])[:100]}...")
+        else:
+            preview_info.append(f"❌ **Text column '{text_column}' not found**")
+            return "\n".join(preview_info)
+
+        if label_column in df.columns:
+            preview_info.append(f"✅ **Label column '{label_column}' found**")
+            label_counts = df[label_column].value_counts()
+            preview_info.append("- **Label distribution:**")
+            for label, count in label_counts.items():
+                preview_info.append(f"  - {label}: {count} ({count / len(df) * 100:.1f}%)")
+        else:
+            preview_info.append(f"❌ **Label column '{label_column}' not found**")
+            return "\n".join(preview_info)
+
+        return "\n".join(preview_info)
+
+    except Exception as e:
+        return f"❌ Error previewing dataset: {str(e)}"
     """Login to Hugging Face"""
     global TOKEN
     TOKEN = token
@@ -162,9 +308,9 @@ def predict_csv(csv_file, model_path):
     except Exception as e:
         return f"❌ CSV processing failed: {str(e)}"
 
-def train_model(dataset_name, num_epochs, batch_size, learning_rate, hf_token,
-                push_to_hub, username, model_name):
-    """Start the model training process"""
+def train_model(dataset_file, text_column, label_column, num_epochs, batch_size,
+                learning_rate, hf_token, push_to_hub, username, model_name):
+    """Start the model training process with local data"""
     global TRAINING_LOGS, MODEL_PATH
 
     TRAINING_LOGS = []  # Reset logs at the start of training
@@ -184,14 +330,47 @@ def train_model(dataset_name, num_epochs, batch_size, learning_rate, hf_token,
     else:
         hub_model_id = None
 
-    # Create training command
+    # Validate dataset file
+    if not dataset_file or not os.path.exists(dataset_file):
+        TRAINING_LOGS.append(f"❌ Dataset file not found: {dataset_file}")
+        yield "\n".join(TRAINING_LOGS)
+        return
+
+    try:
+        # Load and prepare the dataset
+        TRAINING_LOGS.append(f"📊 Loading dataset from {dataset_file}...")
+        yield "\n".join(TRAINING_LOGS)
+
+        dataset_dict, final_text_col, final_label_col = load_and_prepare_local_dataset(
+            dataset_file, text_column, label_column
+        )
+
+        TRAINING_LOGS.append("✅ Dataset loaded successfully!")
+        TRAINING_LOGS.append(f"- Train samples: {len(dataset_dict['train'])}")
+        TRAINING_LOGS.append(f"- Validation samples: {len(dataset_dict['validation'])}")
+        yield "\n".join(TRAINING_LOGS)
+
+        # Save the prepared dataset to disk for the training script
+        temp_dataset_path = "temp_dataset"
+        os.makedirs(temp_dataset_path, exist_ok=True)
+        dataset_dict.save_to_disk(temp_dataset_path)
+
+        TRAINING_LOGS.append("💾 Dataset prepared for training...")
+        yield "\n".join(TRAINING_LOGS)
+
+    except Exception as e:
+        TRAINING_LOGS.append(f"❌ Error preparing dataset: {str(e)}")
+        yield "\n".join(TRAINING_LOGS)
+        return
+
+    # Create training command for the local dataset
     cmd = [
         "python", "bert_finetune.py",
-        "--dataset_name", dataset_name,
+        "--dataset_path", temp_dataset_path,  # local path instead of a Hub dataset name
         "--model_id", "bert-base-uncased",
         "--output_dir", MODEL_PATH,
-        "--feature_column", "complaint",
-        "--label_column", "label_idx",
+        "--feature_column", final_text_col,
+        "--label_column", final_label_col,
         "--num_labels", "3",
         "--num_train_epochs", str(num_epochs),
        "--batch_size", str(batch_size),
@@ -204,7 +383,7 @@ def train_model(dataset_name, num_epochs, batch_size, learning_rate, hf_token,
     if hf_token:
         cmd.extend(["--hf_token", hf_token])
 
-    TRAINING_LOGS.append(f"Starting training with command: {' '.join(cmd)}")
+    TRAINING_LOGS.append(f"🚀 Starting training with command: {' '.join(cmd)}")
     yield "\n".join(TRAINING_LOGS)
 
     try:
@@ -216,7 +395,7 @@ def train_model(dataset_name, num_epochs, batch_size, learning_rate, hf_token,
             bufsize=1
         )
 
-        TRAINING_LOGS.append("Training started...")
+        TRAINING_LOGS.append("🔄 Training started...")
         yield "\n".join(TRAINING_LOGS)
 
         while True:
@@ -232,13 +411,21 @@ def train_model(dataset_name, num_epochs, batch_size, learning_rate, hf_token,
         if process.returncode == 0:
             TRAINING_LOGS.append("✅ Training completed successfully!")
             if push_to_hub and hub_model_id:
-                TRAINING_LOGS.append(f"✅ Model pushed to Hugging Face Hub: {hub_model_id}")
+                TRAINING_LOGS.append(f"🤗 Model pushed to Hugging Face Hub: {hub_model_id}")
 
             # Load the trained model
-            TRAINING_LOGS.append("Loading trained model...")
+            TRAINING_LOGS.append("📥 Loading trained model...")
             load_result = load_model(MODEL_PATH)
             TRAINING_LOGS.append(load_result)
 
+            # Clean up temporary files
+            import shutil
+            try:
+                shutil.rmtree(temp_dataset_path)
+                TRAINING_LOGS.append("🧹 Cleaned up temporary files")
+            except Exception:
+                pass
+
             # Final success message
             TRAINING_LOGS.append("\n✨ All done! Your model is ready to use.")
         else:
@@ -289,14 +476,40 @@ with gr.Blocks(title="BERT Complaint Classifier") as app:
     with gr.Tabs():
         # Training Tab
         with gr.TabItem("Train Model"):
-            gr.Markdown("### Train a New Model")
-            gr.Markdown("Provide your dataset information and training parameters")
+            gr.Markdown("### Train a New Model with Local Data")
+            gr.Markdown("Select your local CSV file and configure training parameters")
 
-            dataset_name = gr.Textbox(
-                label="Dataset Name (from Hugging Face)",
-                placeholder="e.g., your-username/complaint-categories-dataset"
-            )
+            # Dataset selection and preview
+            with gr.Row():
+                with gr.Column(scale=2):
+                    dataset_file = gr.Dropdown(
+                        label="Select Dataset File",
+                        choices=[f for f in os.listdir(".") if f.endswith(".csv")],
+                        value=LOCAL_DATA_FILES[0] if LOCAL_DATA_FILES[0] in os.listdir(".") else None,
+                        allow_custom_value=True
+                    )
+
+                with gr.Column(scale=1):
+                    refresh_btn = gr.Button("🔄 Refresh Files", size="sm")
+
+            # Column configuration
+            with gr.Row():
+                text_column = gr.Textbox(
+                    label="Text Column Name",
+                    value="complaint",
+                    placeholder="e.g., complaint, text, description"
+                )
+                label_column = gr.Textbox(
+                    label="Label Column Name",
+                    value="category",
+                    placeholder="e.g., category, label, class"
+                )
+
+            # Dataset preview
+            preview_btn = gr.Button("📊 Preview Dataset", variant="secondary")
+            dataset_preview = gr.Markdown("Select a dataset file and click 'Preview Dataset' to see its structure.")
 
+            # Training parameters
             with gr.Row():
                 num_epochs = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of Epochs")
                 batch_size = gr.Slider(minimum=4, maximum=32, value=8, step=4, label="Batch Size")
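
A quick way to sanity-check the new loading path outside Gradio is to call the helpers directly. A minimal sketch (not part of this commit), assuming the helper functions from app.py are in scope and that merged-test-data.csv uses the Train Model tab's default column names (complaint for text, category for labels):

    # Sanity-check sketch (not in the commit). Assumes preview_dataset and
    # load_and_prepare_local_dataset from app.py are in scope, and that
    # merged-test-data.csv has 'complaint' and 'category' columns.
    print(preview_dataset("merged-test-data.csv", "complaint", "category"))

    dataset_dict, text_col, label_col = load_and_prepare_local_dataset(
        "merged-test-data.csv", "complaint", "category", test_size=0.2
    )
    print(len(dataset_dict["train"]), "train /", len(dataset_dict["validation"]), "validation rows")
    print(dataset_dict["train"][0])  # note: Dataset.from_pandas also keeps an '__index_level_0__' column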
 
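Note that the training command now passes --dataset_path (the directory written by dataset_dict.save_to_disk) where it previously passed --dataset_name. bert_finetune.py is not part of this commit, so it needs a matching flag; a minimal sketch of how that entry point could consume it, assuming datasets.load_from_disk:

    # Sketch only: bert_finetune.py is not shown in this commit, so this is an
    # assumption about how the new --dataset_path flag could be handled there.
    import argparse
    from datasets import load_dataset, load_from_disk

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", default=None)   # directory from DatasetDict.save_to_disk()
    parser.add_argument("--dataset_name", default=None)   # legacy: Hugging Face Hub dataset id
    parser.add_argument("--feature_column", default="complaint")
    parser.add_argument("--label_column", default="label_idx")
    args, _ = parser.parse_known_args()  # tolerate the script's other flags

    if args.dataset_path:
        ds = load_from_disk(args.dataset_path)  # DatasetDict with 'train' and 'validation' splits
    else:
        ds = load_dataset(args.dataset_name)

    train_texts = ds["train"][args.feature_column]
    train_labels = ds["train"][args.label_column]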