prasenjeet099 committed on
Commit
5541a23
·
verified ·
1 Parent(s): 70abc44

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -9
app.py CHANGED
@@ -4,7 +4,9 @@ import time
4
  import os
5
  import pandas as pd
6
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
7
- from datasets import load_dataset
 
 
8
  from tqdm import tqdm # For progress bar during training
9
 
10
  # Set up Streamlit page
@@ -20,6 +22,10 @@ hardware = st.sidebar.selectbox("Hardware", ["CPU", "Single GPU", "Multi-GPU", "
20
  model_choice = st.sidebar.selectbox("Choose Model", ["bert-base-uncased", "distilbert-base-uncased", "roberta-base"])
21
  dataset_source = st.sidebar.selectbox("Dataset Source", ["glue/sst2", "imdb", "ag_news", "Custom"])
22
 
 
 
 
 
23
  # Training Parameters
24
  epochs = st.sidebar.slider("Number of Epochs", 1, 10, 3)
25
  batch_size = st.sidebar.selectbox("Batch Size", [8, 16, 32, 64], index=1)
@@ -62,20 +68,36 @@ def train_model():
62
  model = AutoModelForSequenceClassification.from_pretrained(model_choice, num_labels=2) # Adjust num_labels as necessary
63
 
64
  # Load dataset
65
- dataset = load_dataset(dataset_source)
 
 
 
 
 
 
 
66
 
67
  # Tokenization function
68
  def tokenize_function(examples):
69
- return tokenizer(examples["text"], truncation=True, padding="max_length")
70
 
71
  tokenized_datasets = dataset.map(tokenize_function, batched=True)
72
 
73
- # Ensure that the dataset has the correct label column (adjust the label column name if necessary)
74
- if "label" not in tokenized_datasets["train"].features:
75
- raise ValueError("Dataset does not have a 'label' column for supervised training")
76
-
77
- train_dataset = tokenized_datasets["train"]
78
- eval_dataset = tokenized_datasets["validation"]
 
 
 
 
 
 
 
 
 
79
 
80
  # Checkpoint Handling
81
  if resume_training and os.path.exists(checkpoint_path):
@@ -107,6 +129,11 @@ def train_model():
107
 
108
  # Training Loop with Progress Bar
109
  metrics = []
 
 
 
 
 
110
  with open(log_file, "w") as log_file_handle:
111
  log_file_handle.write("Starting training...\n")
112
  log_file_handle.flush()
@@ -135,12 +162,49 @@ def train_model():
135
  metrics.append({"epoch": epoch+1, "loss": results["eval_loss"], "accuracy": results.get("eval_accuracy", 0)})
136
  pd.DataFrame(metrics).to_csv(metrics_file, index=False)
137
 
 
 
 
 
 
 
 
138
  # Update logs & metrics in UI
139
  log_area.text(log_text)
140
  st.line_chart(pd.DataFrame(metrics).set_index("epoch"))
141
 
142
  time.sleep(2)
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  # Start Training
145
  if start_train:
146
  train_model()
 
4
  import os
5
  import pandas as pd
6
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
7
+ from datasets import load_dataset, Dataset
8
+ from sklearn.metrics import confusion_matrix
9
+ from sklearn.model_selection import train_test_split
10
  from tqdm import tqdm # For progress bar during training
11
 
12
  # Set up Streamlit page
 
22
  model_choice = st.sidebar.selectbox("Choose Model", ["bert-base-uncased", "distilbert-base-uncased", "roberta-base"])
23
  dataset_source = st.sidebar.selectbox("Dataset Source", ["glue/sst2", "imdb", "ag_news", "Custom"])
24
 
25
+ # Column Mapping for custom datasets
26
+ text_column = st.sidebar.text_input("Text Column", "text")
27
+ label_column = st.sidebar.text_input("Label Column", "label")
28
+
29
  # Training Parameters
30
  epochs = st.sidebar.slider("Number of Epochs", 1, 10, 3)
31
  batch_size = st.sidebar.selectbox("Batch Size", [8, 16, 32, 64], index=1)
 
68
  model = AutoModelForSequenceClassification.from_pretrained(model_choice, num_labels=2) # Adjust num_labels as necessary
69
 
70
  # Load dataset
71
+ if dataset_source.lower() != "custom":
72
+ dataset = load_dataset(dataset_source)
73
+ else:
74
+ # Handle Custom Dataset
75
+ uploaded_file = st.file_uploader("Upload a CSV file", type=["csv"])
76
+ if uploaded_file is not None:
77
+ dataset_df = pd.read_csv(uploaded_file)
78
+ dataset = Dataset.from_pandas(dataset_df)
79
 
80
  # Tokenization function
81
  def tokenize_function(examples):
82
+ return tokenizer(examples[text_column], truncation=True, padding="max_length")
83
 
84
  tokenized_datasets = dataset.map(tokenize_function, batched=True)
85
 
86
+ # Handle missing or non-standard splits
87
+ if "train" in tokenized_datasets:
88
+ train_dataset = tokenized_datasets["train"]
89
+ else:
90
+ # Create a custom split if no train split exists
91
+ train_dataset = tokenized_datasets
92
+ train_dataset, eval_dataset = train_test_split(train_dataset, test_size=0.1)
93
+
94
+ # Check for validation or test split
95
+ if "validation" in tokenized_datasets:
96
+ eval_dataset = tokenized_datasets["validation"]
97
+ elif "test" in tokenized_datasets:
98
+ eval_dataset = tokenized_datasets["test"]
99
+ else:
100
+ raise ValueError("Dataset does not have a 'validation' or 'test' split.")
101
 
102
  # Checkpoint Handling
103
  if resume_training and os.path.exists(checkpoint_path):
 
129
 
130
  # Training Loop with Progress Bar
131
  metrics = []
132
+ loss_values = [] # To store loss values for plotting
133
+ accuracy_values = [] # To store accuracy values for plotting
134
+ all_preds = [] # To store predictions for confusion matrix
135
+ all_labels = [] # To store true labels for confusion matrix
136
+
137
  with open(log_file, "w") as log_file_handle:
138
  log_file_handle.write("Starting training...\n")
139
  log_file_handle.flush()
 
162
  metrics.append({"epoch": epoch+1, "loss": results["eval_loss"], "accuracy": results.get("eval_accuracy", 0)})
163
  pd.DataFrame(metrics).to_csv(metrics_file, index=False)
164
 
165
+ loss_values.append(results["eval_loss"])
166
+ accuracy_values.append(results.get("eval_accuracy", 0))
167
+
168
+ # Collect predictions and labels for confusion matrix
169
+ all_preds.extend(results.get("logits", []))
170
+ all_labels.extend(eval_dataset["label"])
171
+
172
  # Update logs & metrics in UI
173
  log_area.text(log_text)
174
  st.line_chart(pd.DataFrame(metrics).set_index("epoch"))
175
 
176
  time.sleep(2)
177
 
178
+ # After training, plot charts for loss, accuracy, and confusion matrix
179
+ plot_metrics(loss_values, accuracy_values)
180
+ plot_confusion_matrix(all_labels, all_preds)
181
+
182
+ def plot_metrics(loss_values, accuracy_values):
183
+ # Plot Loss Curve using Streamlit chart
184
+ metrics_df = pd.DataFrame({
185
+ "Epoch": range(1, len(loss_values) + 1),
186
+ "Loss": loss_values,
187
+ "Accuracy": accuracy_values
188
+ })
189
+
190
+ st.write("### Training Loss and Accuracy Curve")
191
+ st.line_chart(metrics_df.set_index("Epoch"))
192
+
193
def plot_confusion_matrix(true_labels, preds):
    """Compute and render a binary-classification confusion matrix in the UI.

    Args:
        true_labels: ground-truth class ids (iterable of ints).
        preds: per-example class logits/scores, shape (n_samples, 2).
    """
    # Local imports: the original body used torch, plt and sns, but none of
    # them is imported anywhere in this file, so the function raised
    # NameError at runtime. numpy.argmax replaces torch.argmax (equivalent
    # for this conversion) to avoid pulling in torch just for one argmax.
    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns

    # Convert logits to predicted class labels (argmax over the class axis).
    pred_labels = np.argmax(np.asarray(preds), axis=1)

    # Compute confusion matrix (rows = true labels, columns = predictions).
    cm = confusion_matrix(true_labels, pred_labels)

    # Plot confusion matrix as an annotated heatmap and hand it to Streamlit.
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=["Class 0", "Class 1"],
        yticklabels=["Class 0", "Class 1"],
        ax=ax,
    )
    ax.set_title("Confusion Matrix")
    ax.set_xlabel("Predicted Label")
    ax.set_ylabel("True Label")
    st.pyplot(fig)
207
+
208
  # Start Training
209
  if start_train:
210
  train_model()