robzjgman commited on
Commit
257f890
·
verified ·
1 Parent(s): b075a4c

Upload 4 files

Browse files
6 _ Fine-Tuning (Gemma)/Specific Models/LLM trained Gemma Model/gemini_delivery_model.py ADDED
@@ -0,0 +1,751 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import torch
3
+ from torch.utils.data import Dataset, DataLoader
4
+ from torch import nn
5
+ from transformers import AutoTokenizer, GemmaModel
6
+ from peft import LoraConfig, get_peft_model, TaskType
7
+ from sklearn.model_selection import train_test_split
8
+ from sklearn.metrics import classification_report, hamming_loss, accuracy_score, precision_score, recall_score, f1_score
9
+ import numpy as np
10
+ import random
11
+ import matplotlib.pyplot as plt
12
+ import os
13
+
14
+ # For UTF-8 characters in output
15
+ import sys
16
+ sys.stdout.reconfigure(encoding='utf-8')
17
+
18
+ # Set random seeds for reproducibility
19
+ seed_value = 42
20
+ random.seed(seed_value)
21
+ np.random.seed(seed_value)
22
+ torch.manual_seed(seed_value)
23
+ if torch.cuda.is_available():
24
+ torch.cuda.manual_seed_all(seed_value)
25
+
26
+ # Parameters
27
+ MODEL_ID = 'google/gemma-3-1b-pt'
28
+ BATCH_SIZE = 8
29
+ EPOCHS = 10
30
+ LR = 5e-5
31
+
32
+ # Load data - delivery-specific
33
+ print("Loading training data from delivery_train_dataset.csv...")
34
+ train_df = pd.read_csv('datasets/gemini/delivery_train_dataset.csv')
35
+ print("Loading test data from Test_delivery_dataset.csv...")
36
+ test_df = pd.read_csv('datasets/test_delivery_dataset.csv')
37
+
38
+ # Define label columns (Delivery sub-aspects)
39
+ label_cols = [
40
+ 'Condition_DEL',
41
+ 'Correctness_DEL',
42
+ 'Timeliness_DEL',
43
+ 'General_DEL'
44
+ ]
45
+
46
+ # Prepare training data with 80/20 train/validation split
47
+ train_X_full = train_df['Review'].astype(str).tolist()
48
+ train_Y_full = train_df[label_cols].values.astype(np.float32)
49
+
50
+ train_X, val_X, train_Y, val_Y = train_test_split(
51
+ train_X_full, train_Y_full,
52
+ test_size=0.2,
53
+ random_state=42
54
+ )
55
+
56
+ # Prepare test data
57
+ test_X = test_df['Review'].astype(str).tolist()
58
+ test_Y = test_df[label_cols].values.astype(np.float32)
59
+
60
+ print(f"\nDataset sizes:")
61
+ print(f"Training samples: {len(train_X)}")
62
+ print(f"Validation samples: {len(val_X)}")
63
+ print(f"Test samples: {len(test_X)}")
64
+ print(f"Number of labels: {len(label_cols)}")
65
+
66
+ # Compute class weights for imbalanced dataset
67
+ def compute_class_weights(labels, label_names):
68
+ """
69
+ Compute class weights for multi-label classification
70
+ using the inverse of class frequency.
71
+
72
+ Args:
73
+ labels: numpy array of shape (n_samples, n_labels)
74
+ label_names: list of label column names
75
+
76
+ Returns:
77
+ pos_weight: torch tensor of positive class weights
78
+ """
79
+ n_samples = labels.shape[0]
80
+ n_labels = labels.shape[1]
81
+
82
+ pos_weights = []
83
+
84
+ print("\n" + "="*60)
85
+ print("CLASS IMBALANCE ANALYSIS")
86
+ print("="*60)
87
+
88
+ for i, label_name in enumerate(label_names):
89
+ pos_count = np.sum(labels[:, i] == 1)
90
+ neg_count = np.sum(labels[:, i] == 0)
91
+
92
+ # Calculate positive class weight (ratio of negative to positive)
93
+ if pos_count > 0:
94
+ raw_ratio = neg_count / pos_count
95
+ # Apply square root dampening to avoid extreme weights
96
+ pos_weight = np.sqrt(raw_ratio)
97
+ else:
98
+ pos_weight = 1.0
99
+
100
+ pos_weights.append(pos_weight)
101
+
102
+ print(f"\n{label_name}:")
103
+ print(f" Positive samples: {pos_count} ({pos_count/n_samples*100:.2f}%)")
104
+ print(f" Negative samples: {neg_count} ({neg_count/n_samples*100:.2f}%)")
105
+ print(f" Raw imbalance ratio (neg/pos): {neg_count/pos_count if pos_count > 0 else 1.0:.4f}")
106
+ print(f" Dampened weight (sqrt of ratio): {pos_weight:.4f}")
107
+
108
+ print("="*60 + "\n")
109
+
110
+ return torch.FloatTensor(pos_weights)
111
+
112
+ def find_optimal_thresholds(model, dataloader, label_cols, device):
113
+ """
114
+ Find optimal decision threshold for each class independently
115
+ by maximizing F1-score on the validation set.
116
+
117
+ Args:
118
+ model: trained model
119
+ dataloader: validation data loader
120
+ label_cols: list of label column names
121
+ device: torch device
122
+
123
+ Returns:
124
+ optimal_thresholds: numpy array of optimal thresholds for each class
125
+ """
126
+ from sklearn.metrics import f1_score
127
+
128
+ print("\n" + "="*60)
129
+ print("OPTIMIZING DECISION THRESHOLDS")
130
+ print("="*60)
131
+
132
+ # Collect all predictions and labels
133
+ model.eval()
134
+ all_probs = []
135
+ all_labels = []
136
+
137
+ with torch.no_grad():
138
+ for input_ids, attention_mask, labels in dataloader:
139
+ input_ids = input_ids.to(device)
140
+ attention_mask = attention_mask.to(device)
141
+ logits = model(input_ids, attention_mask)
142
+ probs = torch.sigmoid(logits).cpu().numpy()
143
+ all_probs.append(probs)
144
+ all_labels.append(labels.cpu().numpy())
145
+
146
+ all_probs = np.vstack(all_probs)
147
+ all_labels = np.vstack(all_labels)
148
+
149
+ # Find optimal threshold for each class
150
+ optimal_thresholds = []
151
+ threshold_range = np.arange(0.1, 0.91, 0.05) # 0.1 to 0.9 in steps of 0.05
152
+
153
+ for i, label_name in enumerate(label_cols):
154
+ best_threshold = 0.5
155
+ best_f1 = 0.0
156
+
157
+ for threshold in threshold_range:
158
+ preds = (all_probs[:, i] > threshold).astype(int)
159
+ f1 = f1_score(all_labels[:, i], preds, zero_division=0)
160
+
161
+ if f1 > best_f1:
162
+ best_f1 = f1
163
+ best_threshold = threshold
164
+
165
+ optimal_thresholds.append(best_threshold)
166
+ print(f"\n{label_name}:")
167
+ print(f" Optimal threshold: {best_threshold:.2f}")
168
+ print(f" Best F1-score: {best_f1:.4f}")
169
+ print(f" (Default 0.5 threshold F1: {f1_score(all_labels[:, i], (all_probs[:, i] > 0.5).astype(int), zero_division=0):.4f})")
170
+
171
+ print("="*60 + "\n")
172
+
173
+ return np.array(optimal_thresholds)
174
+
175
+ def predict_with_thresholds(model, dataloader, thresholds, device):
176
+ """
177
+ Make predictions using custom thresholds for each class.
178
+
179
+ Args:
180
+ model: trained model
181
+ dataloader: data loader
182
+ thresholds: numpy array of thresholds for each class
183
+ device: torch device
184
+
185
+ Returns:
186
+ predictions: numpy array of predictions
187
+ labels: numpy array of true labels
188
+ """
189
+ model.eval()
190
+ all_preds = []
191
+ all_labels = []
192
+
193
+ with torch.no_grad():
194
+ for input_ids, attention_mask, labels in dataloader:
195
+ input_ids = input_ids.to(device)
196
+ attention_mask = attention_mask.to(device)
197
+ logits = model(input_ids, attention_mask)
198
+ probs = torch.sigmoid(logits).cpu().numpy()
199
+
200
+ # Apply custom thresholds for each class
201
+ preds = np.zeros_like(probs, dtype=int)
202
+ for i in range(len(thresholds)):
203
+ preds[:, i] = (probs[:, i] > thresholds[i]).astype(int)
204
+
205
+ all_preds.append(preds)
206
+ all_labels.append(labels.cpu().numpy())
207
+
208
+ return np.vstack(all_preds), np.vstack(all_labels)
209
+
210
+ # Dataset class
211
+ class ReviewDataset(Dataset):
212
+ def __init__(self, texts, labels):
213
+ self.texts = texts
214
+ self.labels = labels
215
+
216
+ def __len__(self):
217
+ return len(self.texts)
218
+
219
+ def __getitem__(self, idx):
220
+ encoding = tokenizer(
221
+ self.texts[idx],
222
+ padding='max_length',
223
+ truncation=True,
224
+ max_length=256,
225
+ return_tensors='pt'
226
+ )
227
+ input_ids = encoding['input_ids'].squeeze()
228
+ attention_mask = encoding['attention_mask'].squeeze()
229
+ label = torch.FloatTensor(self.labels[idx])
230
+ return input_ids, attention_mask, label
231
+
232
+ # Initialize tokenizer
233
+ print("\nInitializing tokenizer...")
234
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=True)
235
+
236
+ # Create datasets
237
+ train_dataset = ReviewDataset(train_X, train_Y)
238
+ val_dataset = ReviewDataset(val_X, val_Y)
239
+ test_dataset = ReviewDataset(test_X, test_Y)
240
+
241
+ # Create data loaders
242
+ train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
243
+ val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
244
+ test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
245
+
246
+ # Compute class weights based on training data
247
+ print("Computing class weights for imbalanced dataset...")
248
+ pos_weights = compute_class_weights(train_Y, label_cols)
249
+
250
+ # Initialize model with LoRA
251
+ print("Initializing model with LoRA...")
252
+ backbone = GemmaModel.from_pretrained(MODEL_ID, token=True, dtype=torch.bfloat16)
253
+
254
+ lora_config = LoraConfig(
255
+ task_type=TaskType.FEATURE_EXTRACTION,
256
+ r=8,
257
+ lora_alpha=16,
258
+ lora_dropout=0.05,
259
+ target_modules=["q_proj", "v_proj"]
260
+ )
261
+ backbone = get_peft_model(backbone, lora_config)
262
+
263
+ # Classifier model
264
+ class GemmaClassifier(nn.Module):
265
+ def __init__(self, backbone, num_labels):
266
+ super().__init__()
267
+ self.backbone = backbone
268
+ self.pooler = nn.AdaptiveAvgPool1d(1)
269
+ self.classifier = nn.Linear(backbone.config.hidden_size, num_labels)
270
+
271
+ def forward(self, input_ids, attention_mask):
272
+ output = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
273
+ hidden = output.last_hidden_state
274
+ pooled = self.pooler(hidden.permute(0, 2, 1)).squeeze(-1)
275
+ logits = self.classifier(pooled.float())
276
+ return logits
277
+
278
+ # Initialize model, optimizer, and loss function
279
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
280
+ print(f"Using device: {device}")
281
+
282
+ model = GemmaClassifier(backbone, len(label_cols)).to(device)
283
+ optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
284
+ # Use computed pos_weight to handle class imbalance
285
+ criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights.to(device))
286
+ print(f"\nInitialized BCEWithLogitsLoss with pos_weight: {pos_weights.cpu().numpy()}")
287
+
288
+ # Initialize loss tracking
289
+ train_losses = []
290
+ val_losses = []
291
+ train_batch_losses = [] # Per-batch training losses
292
+ val_batch_losses = [] # Per-batch validation losses
293
+
294
+ # Early stopping variables
295
+ best_val_loss = float('inf')
296
+ best_epoch = 0
297
+ best_model_state = None
298
+ patience = 5 # Number of epochs to wait for improvement
299
+ patience_counter = 0
300
+
301
+ # Training loop
302
+ print("\n" + "="*60)
303
+ print("TRAINING")
304
+ print("="*60)
305
+
306
+ for epoch in range(EPOCHS):
307
+ model.train()
308
+ total_loss = 0
309
+ batch_count = 0
310
+
311
+ for input_ids, attention_mask, labels in train_loader:
312
+ input_ids = input_ids.to(device)
313
+ attention_mask = attention_mask.to(device)
314
+ labels = labels.to(device)
315
+
316
+ optimizer.zero_grad()
317
+ logits = model(input_ids, attention_mask)
318
+ loss = criterion(logits, labels)
319
+ loss.backward()
320
+ optimizer.step()
321
+
322
+ total_loss += loss.item()
323
+ batch_count += 1
324
+ train_batch_losses.append(loss.item()) # Store per-batch loss
325
+
326
+ # Print progress every 100 batches
327
+ if batch_count % 100 == 0:
328
+ print(f" Epoch {epoch+1} | Batch {batch_count}/{len(train_loader)} | Current Loss: {loss.item():.4f}")
329
+
330
+ avg_train_loss = total_loss / len(train_loader)
331
+ train_losses.append(avg_train_loss)
332
+ print(f"\nEpoch {epoch+1}/{EPOCHS} completed")
333
+ print(f"Average Training Loss: {avg_train_loss:.4f}")
334
+
335
+ # Validation on validation set
336
+ model.eval()
337
+ val_loss = 0
338
+ with torch.no_grad():
339
+ for input_ids, attention_mask, labels in val_loader:
340
+ input_ids = input_ids.to(device)
341
+ attention_mask = attention_mask.to(device)
342
+ labels = labels.to(device)
343
+
344
+ logits = model(input_ids, attention_mask)
345
+ loss = criterion(logits, labels)
346
+ val_loss += loss.item()
347
+ val_batch_losses.append(loss.item()) # Store per-batch validation loss
348
+
349
+ avg_val_loss = val_loss / len(val_loader)
350
+ val_losses.append(avg_val_loss)
351
+ print(f"Validation Loss: {avg_val_loss:.4f}")
352
+
353
+ # Early stopping check
354
+ if avg_val_loss < best_val_loss:
355
+ best_val_loss = avg_val_loss
356
+ best_epoch = epoch + 1
357
+ best_model_state = model.state_dict().copy()
358
+ patience_counter = 0
359
+ print(f"✓ New best validation loss: {best_val_loss:.4f} (Epoch {best_epoch})")
360
+ else:
361
+ patience_counter += 1
362
+ print(f" No improvement for {patience_counter} epoch(s)")
363
+ if patience_counter >= patience:
364
+ print(f"\nEarly stopping triggered! Best validation loss: {best_val_loss:.4f} at epoch {best_epoch}")
365
+ break
366
+
367
+ print("-" * 60)
368
+
369
+ # Load best model state
370
+ if best_model_state is not None:
371
+ print(f"\nLoading best model from epoch {best_epoch} with validation loss: {best_val_loss:.4f}")
372
+ model.load_state_dict(best_model_state)
373
+ else:
374
+ print("\nNo best model found, using final model state")
375
+
376
+ # Optimize decision thresholds using validation set
377
+ print("Finding optimal decision thresholds for each class...")
378
+ optimal_thresholds = find_optimal_thresholds(model, val_loader, label_cols, device)
379
+ print(f"Optimal thresholds: {optimal_thresholds}")
380
+
381
+ # SAVE MODEL AFTER TRAINING
382
+ # SAVE_PATH = "gemma_delivery_specific.pt"
383
+ # torch.save(model.state_dict(), SAVE_PATH)
384
+ # print(f"\nModel saved to: {SAVE_PATH}")
385
+ SAVE_DIR = r"C:\temp\new_models" # make sure this folder exists
386
+ os.makedirs(SAVE_DIR, exist_ok=True)
387
+ SAVE_PATH = os.path.join(SAVE_DIR, "gemma_delivery_specific.pt")
388
+ torch.save(model.to('cpu').state_dict(), SAVE_PATH)
389
+ model.to(device) # Move model back to device after saving
390
+ print(f"\nModel saved to: {SAVE_PATH}")
391
+
392
+ # Plot training and validation loss
393
+ print("\n" + "="*60)
394
+ print("PLOTTING TRAINING CURVES")
395
+ print("="*60)
396
+
397
+ plt.figure(figsize=(10, 6))
398
+ epochs_range = range(1, EPOCHS + 1)
399
+
400
+ plt.plot(epochs_range, train_losses, 'b-o', label='Training Loss', linewidth=2, markersize=8)
401
+ plt.plot(epochs_range, val_losses, 'r-s', label='Validation Loss', linewidth=2, markersize=8)
402
+
403
+ plt.xlabel('Epoch', fontsize=12)
404
+ plt.ylabel('Loss', fontsize=12)
405
+ plt.title('Training and Validation Loss Over Epochs', fontsize=14, fontweight='bold')
406
+ plt.legend(fontsize=10)
407
+ plt.grid(True, alpha=0.3)
408
+ plt.tight_layout()
409
+
410
+ # Save the plot
411
+ plot_path = 'training_loss_plot_delivery.png'
412
+ plt.savefig(plot_path, dpi=300, bbox_inches='tight')
413
+ print(f"Training loss plot saved to: {plot_path}")
414
+
415
+ # Display loss values
416
+ print("\nLoss values per epoch:")
417
+ print("-" * 40)
418
+ for i, (train_loss, val_loss) in enumerate(zip(train_losses, val_losses), 1):
419
+ print(f"Epoch {i}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
420
+ print("-" * 40)
421
+
422
+ # Plot detailed per-batch loss curves
423
+ print("\nGenerating detailed per-batch loss plot...")
424
+
425
+ # Create figure with two subplots
426
+ fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10))
427
+
428
+ # Calculate moving average for smoothing (window size = 50 batches)
429
+ def moving_average(data, window_size):
430
+ if len(data) < window_size:
431
+ window_size = max(1, len(data) // 2)
432
+ cumsum = np.cumsum(np.insert(data, 0, 0))
433
+ return (cumsum[window_size:] - cumsum[:-window_size]) / window_size
434
+
435
+ train_ma = moving_average(train_batch_losses, 50)
436
+ val_ma = moving_average(val_batch_losses, 50)
437
+
438
+ # Subplot 1: Training loss per batch
439
+ ax1.plot(train_batch_losses, alpha=0.3, color='lightblue', linewidth=0.5, label='Raw Training Loss')
440
+ ax1.plot(range(len(train_ma)), train_ma, color='blue', linewidth=2, label='Smoothed (Moving Avg, window=50)')
441
+ ax1.set_xlabel('Training Batch', fontsize=11)
442
+ ax1.set_ylabel('Loss', fontsize=11)
443
+ ax1.set_title('Training Loss per Batch (Detailed View)', fontsize=13, fontweight='bold')
444
+ ax1.legend(fontsize=9)
445
+ ax1.grid(True, alpha=0.3)
446
+
447
+ # Add vertical lines for epoch boundaries
448
+ batches_per_epoch = len(train_loader)
449
+ for epoch_idx in range(1, EPOCHS):
450
+ ax1.axvline(x=epoch_idx * batches_per_epoch, color='red', linestyle='--', linewidth=1, alpha=0.5)
451
+
452
+ # Subplot 2: Validation loss per batch
453
+ ax2.plot(val_batch_losses, alpha=0.3, color='lightcoral', linewidth=0.5, label='Raw Validation Loss')
454
+ ax2.plot(range(len(val_ma)), val_ma, color='red', linewidth=2, label='Smoothed (Moving Avg, window=50)')
455
+ ax2.set_xlabel('Validation Batch', fontsize=11)
456
+ ax2.set_ylabel('Loss', fontsize=11)
457
+ ax2.set_title('Validation Loss per Batch (Detailed View)', fontsize=13, fontweight='bold')
458
+ ax2.legend(fontsize=9)
459
+ ax2.grid(True, alpha=0.3)
460
+
461
+ # Add vertical lines for epoch boundaries
462
+ val_batches_per_epoch = len(val_loader)
463
+ for epoch_idx in range(1, EPOCHS):
464
+ ax2.axvline(x=epoch_idx * val_batches_per_epoch, color='blue', linestyle='--', linewidth=1, alpha=0.5)
465
+
466
+ plt.tight_layout()
467
+
468
+ # Save the detailed plot
469
+ detailed_plot_path = 'training_loss_per_batch_detailed_delivery.png'
470
+ plt.savefig(detailed_plot_path, dpi=300, bbox_inches='tight')
471
+ print(f"Detailed per-batch loss plot saved to: {detailed_plot_path}")
472
+
473
+ # Print batch loss statistics
474
+ print("\nBatch Loss Statistics:")
475
+ print("-" * 60)
476
+ print(f"Training batches: {len(train_batch_losses)}")
477
+ print(f" Min loss: {min(train_batch_losses):.4f}")
478
+ print(f" Max loss: {max(train_batch_losses):.4f}")
479
+ print(f" Mean loss: {np.mean(train_batch_losses):.4f}")
480
+ print(f" Std dev: {np.std(train_batch_losses):.4f}")
481
+ print(f"\nValidation batches: {len(val_batch_losses)}")
482
+ print(f" Min loss: {min(val_batch_losses):.4f}")
483
+ print(f" Max loss: {max(val_batch_losses):.4f}")
484
+ print(f" Mean loss: {np.mean(val_batch_losses):.4f}")
485
+ print(f" Std dev: {np.std(val_batch_losses):.4f}")
486
+ print("-" * 60)
487
+
488
+ # VALIDATION SET EVALUATION (WITH OPTIMIZED THRESHOLDS)
489
+ print("\n" + "="*60)
490
+ print("VALIDATION SET EVALUATION (WITH OPTIMIZED THRESHOLDS)")
491
+ print("="*60)
492
+
493
+ val_preds, val_labels_eval = predict_with_thresholds(model, val_loader, optimal_thresholds, device)
494
+
495
+ # Also get predictions with default threshold for comparison
496
+ model.eval()
497
+ val_preds_default = []
498
+ with torch.no_grad():
499
+ for input_ids, attention_mask, labels in val_loader:
500
+ input_ids = input_ids.to(device)
501
+ attention_mask = attention_mask.to(device)
502
+ logits = model(input_ids, attention_mask)
503
+ probs = torch.sigmoid(logits).cpu().numpy()
504
+ preds = (probs > 0.5).astype(int)
505
+ val_preds_default.append(preds)
506
+
507
+ val_preds_default = np.vstack(val_preds_default)
508
+
509
+ print(f"\nPredicted data shape: {val_preds.shape}")
510
+ print(f"Ground truth data shape: {val_labels_eval.shape}")
511
+
512
+ # Comparison: Default vs Optimized Thresholds
513
+ print("\n" + "="*60)
514
+ print("COMPARISON: Default vs Optimized Thresholds")
515
+ print("="*60)
516
+
517
+ print("\nDefault Threshold (0.5):")
518
+ for i, label in enumerate(label_cols):
519
+ f1_default = f1_score(val_labels_eval[:, i], val_preds_default[:, i], zero_division=0)
520
+ print(f" {label}: F1 = {f1_default:.4f}")
521
+
522
+ print("\nOptimized Thresholds:")
523
+ for i, label in enumerate(label_cols):
524
+ f1_optimized = f1_score(val_labels_eval[:, i], val_preds[:, i], zero_division=0)
525
+ print(f" {label}: F1 = {f1_optimized:.4f} (threshold = {optimal_thresholds[i]:.2f})")
526
+ print("="*60 + "\n")
527
+
528
+ # Classification Report
529
+ print('\n' + '='*60)
530
+ print('CLASSIFICATION REPORT (VALIDATION)')
531
+ print('='*60)
532
+ print(classification_report(val_labels_eval, val_preds, target_names=label_cols))
533
+
534
+ # Hamming Loss
535
+ val_hamming_loss = hamming_loss(val_labels_eval, val_preds)
536
+ print("="*60)
537
+ print("HAMMING LOSS (Multi-label Error Rate)")
538
+ print("="*60)
539
+ print(f"Hamming Loss: {val_hamming_loss:.4f}")
540
+ print(f"(Fraction of incorrectly predicted labels: {val_hamming_loss:.2%})")
541
+
542
+ # Per-aspect metrics
543
+ print("\n" + "="*60)
544
+ print("PER-ASPECT METRICS (VALIDATION)")
545
+ print("="*60)
546
+
547
+ for i, aspect in enumerate(label_cols):
548
+ y_true = val_labels_eval[:, i]
549
+ y_pred = val_preds[:, i]
550
+
551
+ acc = accuracy_score(y_true, y_pred)
552
+ prec = precision_score(y_true, y_pred, zero_division=0)
553
+ rec = recall_score(y_true, y_pred, zero_division=0)
554
+ f1 = f1_score(y_true, y_pred, zero_division=0)
555
+
556
+ print(f"\n=== {aspect.upper()} ===")
557
+ print(f"Accuracy: {acc:.4f}")
558
+ print(f"Precision: {prec:.4f}")
559
+ print(f"Recall: {rec:.4f}")
560
+ print(f"F1 Score: {f1:.4f}")
561
+
562
+ tp = np.sum((y_true == 1) & (y_pred == 1))
563
+ tn = np.sum((y_true == 0) & (y_pred == 0))
564
+ fp = np.sum((y_true == 0) & (y_pred == 1))
565
+ fn = np.sum((y_true == 1) & (y_pred == 0))
566
+
567
+ print(f" TP: {tp}, TN: {tn}, FP: {fp}, FN: {fn}")
568
+
569
+ # Exact match accuracy
570
+ val_exact_matches = np.all(val_preds == val_labels_eval, axis=1)
571
+ val_exact_match_acc = np.mean(val_exact_matches)
572
+
573
+ print("\n" + "="*60)
574
+ print("EXACT MATCH (ALL ASPECTS)")
575
+ print("="*60)
576
+ print(f"Samples with ALL aspects correct: {np.sum(val_exact_matches)}/{len(val_exact_matches)}")
577
+ print(f"Exact Match Accuracy: {val_exact_match_acc:.4f}")
578
+
579
+ # Partial match accuracy (per sample)
580
+ partial_match_scores = []
581
+ for i in range(len(val_labels_eval)):
582
+ correct_labels = np.sum(val_preds[i] == val_labels_eval[i])
583
+ partial_match_scores.append(correct_labels / len(label_cols))
584
+
585
+ partial_match_scores = np.array(partial_match_scores)
586
+ avg_partial_match = np.mean(partial_match_scores)
587
+
588
+ print("\n" + "="*60)
589
+ print("PARTIAL MATCH (PER-SAMPLE LABEL ACCURACY)")
590
+ print("="*60)
591
+ print(f"Average Partial Match: {avg_partial_match:.4f} ({avg_partial_match:.2%})")
592
+ print(f"(Average fraction of labels correctly predicted per sample)")
593
+
594
+ # Sample predictions with match/mismatch
595
+ print("\n" + "="*60)
596
+ print("SAMPLE PREDICTIONS VS GROUND TRUTH (VALIDATION)")
597
+ print("="*60)
598
+
599
+ num_samples = min(10, len(val_X))
600
+ print(f"\nShowing {num_samples} validation samples:\n")
601
+
602
+ for idx in range(num_samples):
603
+ review = val_X[idx]
604
+ true_labels = [label_cols[i] for i, v in enumerate(val_labels_eval[idx]) if v == 1]
605
+ pred_labels = [label_cols[i] for i, v in enumerate(val_preds[idx]) if v == 1]
606
+
607
+ # Calculate partial match for this sample
608
+ # Count how many true labels were correctly predicted
609
+ matching_labels = len(set(true_labels) & set(pred_labels))
610
+ total_true_labels = len(true_labels) if len(true_labels) > 0 else 1
611
+ partial_match = matching_labels / total_true_labels
612
+
613
+ review_display = review[:150] + "..." if len(review) > 150 else review
614
+ print(f"Sample {idx + 1}:")
615
+ print(f"Review: {review_display}")
616
+ print(f"✓ True Labels: {true_labels if true_labels else ['None']}")
617
+ print(f"→ Predicted Labels: {pred_labels if pred_labels else ['None']}")
618
+ print(f"Match: {'✓ Exact' if set(true_labels) == set(pred_labels) else '✗ Mismatch'}")
619
+ print(f"Partial Match: {matching_labels}/{total_true_labels} labels correct ({partial_match:.2%})")
620
+ print("-" * 40)
621
+
622
+ # Final Evaluation on Test Set (WITH OPTIMIZED THRESHOLDS)
623
+ print("\n" + "="*60)
624
+ print("FINAL EVALUATION ON TEST SET (WITH OPTIMIZED THRESHOLDS)")
625
+ print("="*60)
626
+
627
+ all_preds, all_labels = predict_with_thresholds(model, test_loader, optimal_thresholds, device)
628
+
629
+ print(f"\nPredicted data shape: {all_preds.shape}")
630
+ print(f"Ground truth data shape: {all_labels.shape}")
631
+
632
+ # Classification Report
633
+ print('\n' + '='*60)
634
+ print('CLASSIFICATION REPORT')
635
+ print('='*60)
636
+ print(classification_report(all_labels, all_preds, target_names=label_cols))
637
+
638
+ # Hamming Loss
639
+ hamming_loss_value = hamming_loss(all_labels, all_preds)
640
+ print("="*60)
641
+ print("HAMMING LOSS (Multi-label Error Rate)")
642
+ print("="*60)
643
+ print(f"Hamming Loss: {hamming_loss_value:.4f}")
644
+ print(f"(Fraction of incorrectly predicted labels: {hamming_loss_value:.2%})")
645
+
646
+ # Per-aspect metrics
647
+ print("\n" + "="*60)
648
+ print("PER-ASPECT METRICS")
649
+ print("="*60)
650
+
651
+ for i, aspect in enumerate(label_cols):
652
+ y_true = all_labels[:, i]
653
+ y_pred = all_preds[:, i]
654
+
655
+ acc = accuracy_score(y_true, y_pred)
656
+ prec = precision_score(y_true, y_pred, zero_division=0)
657
+ rec = recall_score(y_true, y_pred, zero_division=0)
658
+ f1 = f1_score(y_true, y_pred, zero_division=0)
659
+
660
+ print(f"\n=== {aspect.upper()} ===")
661
+ print(f"Accuracy: {acc:.4f}")
662
+ print(f"Precision: {prec:.4f}")
663
+ print(f"Recall: {rec:.4f}")
664
+ print(f"F1 Score: {f1:.4f}")
665
+
666
+ tp = np.sum((y_true == 1) & (y_pred == 1))
667
+ tn = np.sum((y_true == 0) & (y_pred == 0))
668
+ fp = np.sum((y_true == 0) & (y_pred == 1))
669
+ fn = np.sum((y_true == 1) & (y_pred == 0))
670
+
671
+ print(f" TP: {tp}, TN: {tn}, FP: {fp}, FN: {fn}")
672
+
673
+ # Exact match accuracy
674
+ exact_matches = np.all(all_preds == all_labels, axis=1)
675
+ exact_match_acc = np.mean(exact_matches)
676
+
677
+ print("\n" + "="*60)
678
+ print("EXACT MATCH (ALL ASPECTS)")
679
+ print("="*60)
680
+ print(f"Samples with ALL aspects correct: {np.sum(exact_matches)}/{len(exact_matches)}")
681
+ print(f"Exact Match Accuracy: {exact_match_acc:.4f}")
682
+
683
+ # Partial match accuracy (per sample)
684
+ test_partial_match_scores = []
685
+ for i in range(len(all_labels)):
686
+ correct_labels = np.sum(all_preds[i] == all_labels[i])
687
+ test_partial_match_scores.append(correct_labels / len(label_cols))
688
+
689
+ test_partial_match_scores = np.array(test_partial_match_scores)
690
+ avg_test_partial_match = np.mean(test_partial_match_scores)
691
+
692
+ print("\n" + "="*60)
693
+ print("PARTIAL MATCH (PER-SAMPLE LABEL ACCURACY)")
694
+ print("="*60)
695
+ print(f"Average Partial Match: {avg_test_partial_match:.4f} ({avg_test_partial_match:.2%})")
696
+ print(f"(Average fraction of labels correctly predicted per sample)")
697
+
698
+ # Sample predictions
699
+ print("\n" + "="*60)
700
+ print("SAMPLE PREDICTIONS VS GROUND TRUTH")
701
+ print("="*60)
702
+
703
+ num_samples = min(10, len(test_X))
704
+ print(f"\nShowing {num_samples} test samples:\n")
705
+
706
+ for idx in range(num_samples):
707
+ review = test_X[idx]
708
+ true_labels = [label_cols[i] for i, v in enumerate(all_labels[idx]) if v == 1]
709
+ pred_labels = [label_cols[i] for i, v in enumerate(all_preds[idx]) if v == 1]
710
+
711
+ # Calculate partial match for this sample
712
+ # Count how many true labels were correctly predicted
713
+ matching_labels = len(set(true_labels) & set(pred_labels))
714
+ total_true_labels = len(true_labels) if len(true_labels) > 0 else 1
715
+ partial_match = matching_labels / total_true_labels
716
+
717
+ review_display = review[:150] + "..." if len(review) > 150 else review
718
+ print(f"Sample {idx + 1}:")
719
+ print(f"Review: {review_display}")
720
+ print(f"✓ True Labels: {true_labels if true_labels else ['None']}")
721
+ print(f"→ Predicted Labels: {pred_labels if pred_labels else ['None']}")
722
+ print(f"Match: {'✓ Exact' if set(true_labels) == set(pred_labels) else '✗ Mismatch'}")
723
+ print(f"Partial Match: {matching_labels}/{total_true_labels} labels correct ({partial_match:.2%})")
724
+ print("-" * 40)
725
+
726
+ # Save model interactively (optional)
727
+ # model_save_path = 'gemma_delivery_classifier.pth'
728
+ # torch.save({
729
+ # 'epoch': EPOCHS,
730
+ # 'model_state_dict': model.state_dict(),
731
+ # 'optimizer_state_dict': optimizer.state_dict(),
732
+ # 'train_loss': avg_train_loss,
733
+ # 'test_loss': avg_test_loss,
734
+ # }, model_save_path)
735
+ # print(f"Model saved to {model_save_path}")
736
+ model_save_path = os.path.join(SAVE_DIR, 'gemma_delivery_classifier.pth')
737
+ torch.save({
738
+ 'epoch': best_epoch if best_model_state is not None else EPOCHS,
739
+ 'model_state_dict': model.state_dict(),
740
+ 'optimizer_state_dict': optimizer.state_dict(),
741
+ 'train_loss': train_losses[best_epoch - 1] if best_model_state is not None else train_losses[-1] if train_losses else 0,
742
+ 'val_loss': best_val_loss if best_model_state is not None else (val_losses[-1] if val_losses else 0),
743
+ 'best_epoch': best_epoch,
744
+ 'best_val_loss': best_val_loss,
745
+ 'optimal_thresholds': optimal_thresholds,
746
+ }, model_save_path)
747
+ print(f"Model saved to {model_save_path}")
748
+
749
+ print("\n" + "="*60)
750
+ print("TRAINING COMPLETE")
751
+ print("="*60)
6 _ Fine-Tuning (Gemma)/Specific Models/LLM trained Gemma Model/gemini_price_model.py ADDED
@@ -0,0 +1,751 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import torch
3
+ from torch.utils.data import Dataset, DataLoader
4
+ from torch import nn
5
+ from transformers import AutoTokenizer, GemmaModel
6
+ from peft import LoraConfig, get_peft_model, TaskType
7
+ from sklearn.model_selection import train_test_split
8
+ from sklearn.metrics import classification_report, hamming_loss, accuracy_score, precision_score, recall_score, f1_score
9
+ import numpy as np
10
+ import random
11
+ import matplotlib.pyplot as plt
12
+ import os
13
+
14
+ # For UTF-8 characters in output
15
+ import sys
16
+ sys.stdout.reconfigure(encoding='utf-8')
17
+
18
+ # Set random seeds for reproducibility
19
+ seed_value = 42
20
+ random.seed(seed_value)
21
+ np.random.seed(seed_value)
22
+ torch.manual_seed(seed_value)
23
+ if torch.cuda.is_available():
24
+ torch.cuda.manual_seed_all(seed_value)
25
+
26
+ # Parameters
27
+ MODEL_ID = 'google/gemma-3-1b-pt'
28
+ BATCH_SIZE = 8
29
+ EPOCHS = 10
30
+ LR = 5e-5
31
+
32
+ # Load data - price-specific
33
+ print("Loading training data from price_train_dataset.csv...")
34
+ train_df = pd.read_csv('datasets/gemini/price_train_dataset.csv')
35
+ print("Loading test data from Test_price_dataset.csv...")
36
+ test_df = pd.read_csv('datasets/test_price_dataset.csv')
37
+
38
+ # Define label columns (Price sub-aspects)
39
+ label_cols = [
40
+ 'Affordability_PRICE',
41
+ 'Value_for_Money_PRICE',
42
+ 'General_PRICE'
43
+ ]
44
+
45
+ # Prepare training data with 80/20 train/validation split
46
+ train_X_full = train_df['Review'].astype(str).tolist()
47
+ train_Y_full = train_df[label_cols].values.astype(np.float32)
48
+
49
+ train_X, val_X, train_Y, val_Y = train_test_split(
50
+ train_X_full, train_Y_full,
51
+ test_size=0.2,
52
+ random_state=42
53
+ )
54
+
55
+ # Prepare test data
56
+ test_X = test_df['Review'].astype(str).tolist()
57
+ test_Y = test_df[label_cols].values.astype(np.float32)
58
+
59
+ print(f"\nDataset sizes:")
60
+ print(f"Training samples: {len(train_X)}")
61
+ print(f"Validation samples: {len(val_X)}")
62
+ print(f"Test samples: {len(test_X)}")
63
+ print(f"Number of labels: {len(label_cols)}")
64
+
65
+ # Compute class weights for imbalanced dataset
66
+ def compute_class_weights(labels, label_names):
67
+ """
68
+ Compute class weights for multi-label classification
69
+ using the inverse of class frequency.
70
+
71
+ Args:
72
+ labels: numpy array of shape (n_samples, n_labels)
73
+ label_names: list of label column names
74
+
75
+ Returns:
76
+ pos_weight: torch tensor of positive class weights
77
+ """
78
+ n_samples = labels.shape[0]
79
+ n_labels = labels.shape[1]
80
+
81
+ pos_weights = []
82
+
83
+ print("\n" + "="*60)
84
+ print("CLASS IMBALANCE ANALYSIS")
85
+ print("="*60)
86
+
87
+ for i, label_name in enumerate(label_names):
88
+ pos_count = np.sum(labels[:, i] == 1)
89
+ neg_count = np.sum(labels[:, i] == 0)
90
+
91
+ # Calculate positive class weight (ratio of negative to positive)
92
+ if pos_count > 0:
93
+ raw_ratio = neg_count / pos_count
94
+ # Apply square root dampening to avoid extreme weights
95
+ pos_weight = np.sqrt(raw_ratio)
96
+ else:
97
+ pos_weight = 1.0
98
+
99
+ pos_weights.append(pos_weight)
100
+
101
+ print(f"\n{label_name}:")
102
+ print(f" Positive samples: {pos_count} ({pos_count/n_samples*100:.2f}%)")
103
+ print(f" Negative samples: {neg_count} ({neg_count/n_samples*100:.2f}%)")
104
+ print(f" Raw imbalance ratio (neg/pos): {neg_count/pos_count if pos_count > 0 else 1.0:.4f}")
105
+ print(f" Dampened weight (sqrt of ratio): {pos_weight:.4f}")
106
+
107
+ print("="*60 + "\n")
108
+
109
+ return torch.FloatTensor(pos_weights)
110
+
111
+ def find_optimal_thresholds(model, dataloader, label_cols, device):
112
+ """
113
+ Find optimal decision threshold for each class independently
114
+ by maximizing F1-score on the validation set.
115
+
116
+ Args:
117
+ model: trained model
118
+ dataloader: validation data loader
119
+ label_cols: list of label column names
120
+ device: torch device
121
+
122
+ Returns:
123
+ optimal_thresholds: numpy array of optimal thresholds for each class
124
+ """
125
+ from sklearn.metrics import f1_score
126
+
127
+ print("\n" + "="*60)
128
+ print("OPTIMIZING DECISION THRESHOLDS")
129
+ print("="*60)
130
+
131
+ # Collect all predictions and labels
132
+ model.eval()
133
+ all_probs = []
134
+ all_labels = []
135
+
136
+ with torch.no_grad():
137
+ for input_ids, attention_mask, labels in dataloader:
138
+ input_ids = input_ids.to(device)
139
+ attention_mask = attention_mask.to(device)
140
+ logits = model(input_ids, attention_mask)
141
+ probs = torch.sigmoid(logits).cpu().numpy()
142
+ all_probs.append(probs)
143
+ all_labels.append(labels.cpu().numpy())
144
+
145
+ all_probs = np.vstack(all_probs)
146
+ all_labels = np.vstack(all_labels)
147
+
148
+ # Find optimal threshold for each class
149
+ optimal_thresholds = []
150
+ threshold_range = np.arange(0.1, 0.91, 0.05) # 0.1 to 0.9 in steps of 0.05
151
+
152
+ for i, label_name in enumerate(label_cols):
153
+ best_threshold = 0.5
154
+ best_f1 = 0.0
155
+
156
+ for threshold in threshold_range:
157
+ preds = (all_probs[:, i] > threshold).astype(int)
158
+ f1 = f1_score(all_labels[:, i], preds, zero_division=0)
159
+
160
+ if f1 > best_f1:
161
+ best_f1 = f1
162
+ best_threshold = threshold
163
+
164
+ optimal_thresholds.append(best_threshold)
165
+ print(f"\n{label_name}:")
166
+ print(f" Optimal threshold: {best_threshold:.2f}")
167
+ print(f" Best F1-score: {best_f1:.4f}")
168
+ print(f" (Default 0.5 threshold F1: {f1_score(all_labels[:, i], (all_probs[:, i] > 0.5).astype(int), zero_division=0):.4f})")
169
+
170
+ print("="*60 + "\n")
171
+
172
+ return np.array(optimal_thresholds)
173
+
174
+ def predict_with_thresholds(model, dataloader, thresholds, device):
175
+ """
176
+ Make predictions using custom thresholds for each class.
177
+
178
+ Args:
179
+ model: trained model
180
+ dataloader: data loader
181
+ thresholds: numpy array of thresholds for each class
182
+ device: torch device
183
+
184
+ Returns:
185
+ predictions: numpy array of predictions
186
+ labels: numpy array of true labels
187
+ """
188
+ model.eval()
189
+ all_preds = []
190
+ all_labels = []
191
+
192
+ with torch.no_grad():
193
+ for input_ids, attention_mask, labels in dataloader:
194
+ input_ids = input_ids.to(device)
195
+ attention_mask = attention_mask.to(device)
196
+ logits = model(input_ids, attention_mask)
197
+ probs = torch.sigmoid(logits).cpu().numpy()
198
+
199
+ # Apply custom thresholds for each class
200
+ preds = np.zeros_like(probs, dtype=int)
201
+ for i in range(len(thresholds)):
202
+ preds[:, i] = (probs[:, i] > thresholds[i]).astype(int)
203
+
204
+ all_preds.append(preds)
205
+ all_labels.append(labels.cpu().numpy())
206
+
207
+ return np.vstack(all_preds), np.vstack(all_labels)
208
+
209
+ # Dataset class
210
+ class ReviewDataset(Dataset):
211
+ def __init__(self, texts, labels):
212
+ self.texts = texts
213
+ self.labels = labels
214
+
215
+ def __len__(self):
216
+ return len(self.texts)
217
+
218
+ def __getitem__(self, idx):
219
+ encoding = tokenizer(
220
+ self.texts[idx],
221
+ padding='max_length',
222
+ truncation=True,
223
+ max_length=256,
224
+ return_tensors='pt'
225
+ )
226
+ input_ids = encoding['input_ids'].squeeze()
227
+ attention_mask = encoding['attention_mask'].squeeze()
228
+ label = torch.FloatTensor(self.labels[idx])
229
+ return input_ids, attention_mask, label
230
+
231
+ # Initialize tokenizer
232
+ print("\nInitializing tokenizer...")
233
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=True)
234
+
235
+ # Create datasets
236
+ train_dataset = ReviewDataset(train_X, train_Y)
237
+ val_dataset = ReviewDataset(val_X, val_Y)
238
+ test_dataset = ReviewDataset(test_X, test_Y)
239
+
240
+ # Create data loaders
241
+ train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
242
+ val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
243
+ test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
244
+
245
+ # Compute class weights based on training data
246
+ print("Computing class weights for imbalanced dataset...")
247
+ pos_weights = compute_class_weights(train_Y, label_cols)
248
+
249
+ # Initialize model with LoRA
250
+ print("Initializing model with LoRA...")
251
+ backbone = GemmaModel.from_pretrained(MODEL_ID, token=True, dtype=torch.bfloat16)
252
+
253
+ lora_config = LoraConfig(
254
+ task_type=TaskType.FEATURE_EXTRACTION,
255
+ r=8,
256
+ lora_alpha=16,
257
+ lora_dropout=0.05,
258
+ target_modules=["q_proj", "v_proj"]
259
+ )
260
+ backbone = get_peft_model(backbone, lora_config)
261
+
262
+ # Classifier model
263
+ class GemmaClassifier(nn.Module):
264
+ def __init__(self, backbone, num_labels):
265
+ super().__init__()
266
+ self.backbone = backbone
267
+ self.pooler = nn.AdaptiveAvgPool1d(1)
268
+ self.classifier = nn.Linear(backbone.config.hidden_size, num_labels)
269
+
270
+ def forward(self, input_ids, attention_mask):
271
+ output = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
272
+ hidden = output.last_hidden_state
273
+ pooled = self.pooler(hidden.permute(0, 2, 1)).squeeze(-1)
274
+ logits = self.classifier(pooled.float())
275
+ return logits
276
+
277
+ # Initialize model, optimizer, and loss function
278
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
279
+ print(f"Using device: {device}")
280
+
281
+ model = GemmaClassifier(backbone, len(label_cols)).to(device)
282
+ optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
283
+ # Use computed pos_weight to handle class imbalance
284
+ criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights.to(device))
285
+ print(f"\nInitialized BCEWithLogitsLoss with pos_weight: {pos_weights.cpu().numpy()}")
286
+
287
+ # Initialize loss tracking
288
+ train_losses = []
289
+ val_losses = []
290
+ train_batch_losses = [] # Per-batch training losses
291
+ val_batch_losses = [] # Per-batch validation losses
292
+
293
+ # Early stopping variables
294
+ best_val_loss = float('inf')
295
+ best_epoch = 0
296
+ best_model_state = None
297
+ patience = 5 # Number of epochs to wait for improvement
298
+ patience_counter = 0
299
+
300
+ # Training loop
301
+ print("\n" + "="*60)
302
+ print("TRAINING")
303
+ print("="*60)
304
+
305
+ for epoch in range(EPOCHS):
306
+ model.train()
307
+ total_loss = 0
308
+ batch_count = 0
309
+
310
+ for input_ids, attention_mask, labels in train_loader:
311
+ input_ids = input_ids.to(device)
312
+ attention_mask = attention_mask.to(device)
313
+ labels = labels.to(device)
314
+
315
+ optimizer.zero_grad()
316
+ logits = model(input_ids, attention_mask)
317
+ loss = criterion(logits, labels)
318
+ loss.backward()
319
+ optimizer.step()
320
+
321
+ total_loss += loss.item()
322
+ batch_count += 1
323
+ train_batch_losses.append(loss.item()) # Store per-batch loss
324
+
325
+ # Print progress every 100 batches
326
+ if batch_count % 100 == 0:
327
+ print(f" Epoch {epoch+1} | Batch {batch_count}/{len(train_loader)} | Current Loss: {loss.item():.4f}")
328
+
329
+ avg_train_loss = total_loss / len(train_loader)
330
+ train_losses.append(avg_train_loss)
331
+ print(f"\nEpoch {epoch+1}/{EPOCHS} completed")
332
+ print(f"Average Training Loss: {avg_train_loss:.4f}")
333
+
334
+ # Validation on validation set
335
+ model.eval()
336
+ val_loss = 0
337
+ with torch.no_grad():
338
+ for input_ids, attention_mask, labels in val_loader:
339
+ input_ids = input_ids.to(device)
340
+ attention_mask = attention_mask.to(device)
341
+ labels = labels.to(device)
342
+
343
+ logits = model(input_ids, attention_mask)
344
+ loss = criterion(logits, labels)
345
+ val_loss += loss.item()
346
+ val_batch_losses.append(loss.item()) # Store per-batch validation loss
347
+
348
+ avg_val_loss = val_loss / len(val_loader)
349
+ val_losses.append(avg_val_loss)
350
+ print(f"Validation Loss: {avg_val_loss:.4f}")
351
+
352
+ # Early stopping check
353
+ if avg_val_loss < best_val_loss:
354
+ best_val_loss = avg_val_loss
355
+ best_epoch = epoch + 1
356
+ best_model_state = model.state_dict().copy()
357
+ patience_counter = 0
358
+ print(f"✓ New best validation loss: {best_val_loss:.4f} (Epoch {best_epoch})")
359
+ else:
360
+ patience_counter += 1
361
+ print(f" No improvement for {patience_counter} epoch(s)")
362
+ if patience_counter >= patience:
363
+ print(f"\nEarly stopping triggered! Best validation loss: {best_val_loss:.4f} at epoch {best_epoch}")
364
+ break
365
+
366
+ print("-" * 60)
367
+
368
+ # Load best model state
369
+ if best_model_state is not None:
370
+ print(f"\nLoading best model from epoch {best_epoch} with validation loss: {best_val_loss:.4f}")
371
+ model.load_state_dict(best_model_state)
372
+ else:
373
+ print("\nNo best model found, using final model state")
374
+
375
+ # Optimize decision thresholds using validation set
376
+ print("Finding optimal decision thresholds for each class...")
377
+ optimal_thresholds = find_optimal_thresholds(model, val_loader, label_cols, device)
378
+ print(f"Optimal thresholds: {optimal_thresholds}")
379
+
380
+ # SAVE MODEL AFTER TRAINING
381
+ # SAVE_PATH = "gemma_price_specific.pt"
382
+ # torch.save(model.state_dict(), SAVE_PATH)
383
+ # print(f"\nModel saved to: {SAVE_PATH}")
384
+ SAVE_DIR = r"C:\temp\new_models" # make sure this folder exists
385
+ os.makedirs(SAVE_DIR, exist_ok=True)
386
+ SAVE_PATH = os.path.join(SAVE_DIR, "gemma_price_specific.pt")
387
+ torch.save(model.to('cpu').state_dict(), SAVE_PATH)
388
+ model.to(device) # Move model back to device after saving
389
+ print(f"\nModel saved to: {SAVE_PATH}")
390
+
391
+ # Plot training and validation loss
392
+ print("\n" + "="*60)
393
+ print("PLOTTING TRAINING CURVES")
394
+ print("="*60)
395
+
396
+ plt.figure(figsize=(10, 6))
397
+ epochs_range = range(1, EPOCHS + 1)
398
+
399
+ plt.plot(epochs_range, train_losses, 'b-o', label='Training Loss', linewidth=2, markersize=8)
400
+ plt.plot(epochs_range, val_losses, 'r-s', label='Validation Loss', linewidth=2, markersize=8)
401
+
402
+ plt.xlabel('Epoch', fontsize=12)
403
+ plt.ylabel('Loss', fontsize=12)
404
+ plt.title('Training and Validation Loss Over Epochs', fontsize=14, fontweight='bold')
405
+ plt.legend(fontsize=10)
406
+ plt.grid(True, alpha=0.3)
407
+ plt.tight_layout()
408
+
409
+ # Save the plot
410
+ plot_path = 'training_loss_plot_price.png'
411
+ plt.savefig(plot_path, dpi=300, bbox_inches='tight')
412
+ print(f"Training loss plot saved to: {plot_path}")
413
+
414
+ # Display loss values
415
+ print("\nLoss values per epoch:")
416
+ print("-" * 40)
417
+ for i, (train_loss, val_loss) in enumerate(zip(train_losses, val_losses), 1):
418
+ print(f"Epoch {i}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
419
+ print("-" * 40)
420
+
421
+ # Plot detailed per-batch loss curves
422
+ print("\nGenerating detailed per-batch loss plot...")
423
+
424
+ # Create figure with two subplots
425
+ fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10))
426
+
427
+ # Calculate moving average for smoothing (window size = 50 batches)
428
+ def moving_average(data, window_size):
429
+ if len(data) < window_size:
430
+ window_size = max(1, len(data) // 2)
431
+ cumsum = np.cumsum(np.insert(data, 0, 0))
432
+ return (cumsum[window_size:] - cumsum[:-window_size]) / window_size
433
+
434
+ train_ma = moving_average(train_batch_losses, 50)
435
+ val_ma = moving_average(val_batch_losses, 50)
436
+
437
+ # Subplot 1: Training loss per batch
438
+ ax1.plot(train_batch_losses, alpha=0.3, color='lightblue', linewidth=0.5, label='Raw Training Loss')
439
+ ax1.plot(range(len(train_ma)), train_ma, color='blue', linewidth=2, label='Smoothed (Moving Avg, window=50)')
440
+ ax1.set_xlabel('Training Batch', fontsize=11)
441
+ ax1.set_ylabel('Loss', fontsize=11)
442
+ ax1.set_title('Training Loss per Batch (Detailed View)', fontsize=13, fontweight='bold')
443
+ ax1.legend(fontsize=9)
444
+ ax1.grid(True, alpha=0.3)
445
+
446
+ # Add vertical lines for epoch boundaries
447
+ batches_per_epoch = len(train_loader)
448
+ for epoch_idx in range(1, EPOCHS):
449
+ ax1.axvline(x=epoch_idx * batches_per_epoch, color='red', linestyle='--', linewidth=1, alpha=0.5)
450
+
451
+ # Subplot 2: Validation loss per batch
452
+ ax2.plot(val_batch_losses, alpha=0.3, color='lightcoral', linewidth=0.5, label='Raw Validation Loss')
453
+ ax2.plot(range(len(val_ma)), val_ma, color='red', linewidth=2, label='Smoothed (Moving Avg, window=50)')
454
+ ax2.set_xlabel('Validation Batch', fontsize=11)
455
+ ax2.set_ylabel('Loss', fontsize=11)
456
+ ax2.set_title('Validation Loss per Batch (Detailed View)', fontsize=13, fontweight='bold')
457
+ ax2.legend(fontsize=9)
458
+ ax2.grid(True, alpha=0.3)
459
+
460
+ # Add vertical lines for epoch boundaries
461
+ val_batches_per_epoch = len(val_loader)
462
+ for epoch_idx in range(1, EPOCHS):
463
+ ax2.axvline(x=epoch_idx * val_batches_per_epoch, color='blue', linestyle='--', linewidth=1, alpha=0.5)
464
+
465
+ plt.tight_layout()
466
+
467
+ # Save the detailed plot
468
+ detailed_plot_path = 'training_loss_per_batch_detailed_price.png'
469
+ plt.savefig(detailed_plot_path, dpi=300, bbox_inches='tight')
470
+ print(f"Detailed per-batch loss plot saved to: {detailed_plot_path}")
471
+
472
+ # Print batch loss statistics
473
+ print("\nBatch Loss Statistics:")
474
+ print("-" * 60)
475
+ print(f"Training batches: {len(train_batch_losses)}")
476
+ print(f" Min loss: {min(train_batch_losses):.4f}")
477
+ print(f" Max loss: {max(train_batch_losses):.4f}")
478
+ print(f" Mean loss: {np.mean(train_batch_losses):.4f}")
479
+ print(f" Std dev: {np.std(train_batch_losses):.4f}")
480
+ print(f"\nValidation batches: {len(val_batch_losses)}")
481
+ print(f" Min loss: {min(val_batch_losses):.4f}")
482
+ print(f" Max loss: {max(val_batch_losses):.4f}")
483
+ print(f" Mean loss: {np.mean(val_batch_losses):.4f}")
484
+ print(f" Std dev: {np.std(val_batch_losses):.4f}")
485
+ print("-" * 60)
486
+
487
+ # VALIDATION SET EVALUATION (WITH OPTIMIZED THRESHOLDS)
488
+ print("\n" + "="*60)
489
+ print("VALIDATION SET EVALUATION (WITH OPTIMIZED THRESHOLDS)")
490
+ print("="*60)
491
+
492
+ val_preds, val_labels_eval = predict_with_thresholds(model, val_loader, optimal_thresholds, device)
493
+
494
+ # Also get predictions with default threshold for comparison
495
+ model.eval()
496
+ val_preds_default = []
497
+ with torch.no_grad():
498
+ for input_ids, attention_mask, labels in val_loader:
499
+ input_ids = input_ids.to(device)
500
+ attention_mask = attention_mask.to(device)
501
+ logits = model(input_ids, attention_mask)
502
+ probs = torch.sigmoid(logits).cpu().numpy()
503
+ preds = (probs > 0.5).astype(int)
504
+ val_preds_default.append(preds)
505
+
506
+ val_preds_default = np.vstack(val_preds_default)
507
+
508
+ print(f"\nPredicted data shape: {val_preds.shape}")
509
+ print(f"Ground truth data shape: {val_labels_eval.shape}")
510
+
511
+ # Comparison: Default vs Optimized Thresholds
512
+ print("\n" + "="*60)
513
+ print("COMPARISON: Default vs Optimized Thresholds")
514
+ print("="*60)
515
+
516
+ print("\nDefault Threshold (0.5):")
517
+ for i, label in enumerate(label_cols):
518
+ f1_default = f1_score(val_labels_eval[:, i], val_preds_default[:, i], zero_division=0)
519
+ print(f" {label}: F1 = {f1_default:.4f}")
520
+
521
+ print("\nOptimized Thresholds:")
522
+ for i, label in enumerate(label_cols):
523
+ f1_optimized = f1_score(val_labels_eval[:, i], val_preds[:, i], zero_division=0)
524
+ print(f" {label}: F1 = {f1_optimized:.4f} (threshold = {optimal_thresholds[i]:.2f})")
525
+ print("="*60 + "\n")
526
+
527
+ # Classification Report
528
+ print('\n' + '='*60)
529
+ print('CLASSIFICATION REPORT (VALIDATION)')
530
+ print('='*60)
531
+ print(classification_report(val_labels_eval, val_preds, target_names=label_cols))
532
+
533
+ # Hamming Loss
534
+ val_hamming_loss = hamming_loss(val_labels_eval, val_preds)
535
+ print("="*60)
536
+ print("HAMMING LOSS (Multi-label Error Rate)")
537
+ print("="*60)
538
+ print(f"Hamming Loss: {val_hamming_loss:.4f}")
539
+ print(f"(Fraction of incorrectly predicted labels: {val_hamming_loss:.2%})")
540
+
541
+ # Per-aspect metrics
542
+ print("\n" + "="*60)
543
+ print("PER-ASPECT METRICS (VALIDATION)")
544
+ print("="*60)
545
+
546
+ for i, aspect in enumerate(label_cols):
547
+ y_true = val_labels_eval[:, i]
548
+ y_pred = val_preds[:, i]
549
+
550
+ acc = accuracy_score(y_true, y_pred)
551
+ prec = precision_score(y_true, y_pred, zero_division=0)
552
+ rec = recall_score(y_true, y_pred, zero_division=0)
553
+ f1 = f1_score(y_true, y_pred, zero_division=0)
554
+
555
+ print(f"\n=== {aspect.upper()} ===")
556
+ print(f"Accuracy: {acc:.4f}")
557
+ print(f"Precision: {prec:.4f}")
558
+ print(f"Recall: {rec:.4f}")
559
+ print(f"F1 Score: {f1:.4f}")
560
+
561
+ tp = np.sum((y_true == 1) & (y_pred == 1))
562
+ tn = np.sum((y_true == 0) & (y_pred == 0))
563
+ fp = np.sum((y_true == 0) & (y_pred == 1))
564
+ fn = np.sum((y_true == 1) & (y_pred == 0))
565
+
566
+ print(f" TP: {tp}, TN: {tn}, FP: {fp}, FN: {fn}")
567
+
568
+ # Exact match accuracy
569
+ val_exact_matches = np.all(val_preds == val_labels_eval, axis=1)
570
+ val_exact_match_acc = np.mean(val_exact_matches)
571
+
572
+ print("\n" + "="*60)
573
+ print("EXACT MATCH (ALL ASPECTS)")
574
+ print("="*60)
575
+ print(f"Samples with ALL aspects correct: {np.sum(val_exact_matches)}/{len(val_exact_matches)}")
576
+ print(f"Exact Match Accuracy: {val_exact_match_acc:.4f}")
577
+
578
+ # Partial match accuracy (per sample)
579
+ partial_match_scores = []
580
+ for i in range(len(val_labels_eval)):
581
+ correct_labels = np.sum(val_preds[i] == val_labels_eval[i])
582
+ partial_match_scores.append(correct_labels / len(label_cols))
583
+
584
+ partial_match_scores = np.array(partial_match_scores)
585
+ avg_partial_match = np.mean(partial_match_scores)
586
+
587
+ print("\n" + "="*60)
588
+ print("PARTIAL MATCH (PER-SAMPLE LABEL ACCURACY)")
589
+ print("="*60)
590
+ print(f"Average Partial Match: {avg_partial_match:.4f} ({avg_partial_match:.2%})")
591
+ print(f"(Average fraction of labels correctly predicted per sample)")
592
+
593
+ # Sample predictions with match/mismatch
594
+ print("\n" + "="*60)
595
+ print("SAMPLE PREDICTIONS VS GROUND TRUTH (VALIDATION)")
596
+ print("="*60)
597
+
598
+ num_samples = min(10, len(val_X))
599
+ print(f"\nShowing {num_samples} validation samples:\n")
600
+
601
+ for idx in range(num_samples):
602
+ review = val_X[idx]
603
+ true_labels = [label_cols[i] for i, v in enumerate(val_labels_eval[idx]) if v == 1]
604
+ pred_labels = [label_cols[i] for i, v in enumerate(val_preds[idx]) if v == 1]
605
+
606
+ # Calculate partial match for this sample
607
+ # Count how many true labels were correctly predicted
608
+ matching_labels = len(set(true_labels) & set(pred_labels))
609
+ total_true_labels = len(true_labels) if len(true_labels) > 0 else 1
610
+ partial_match = matching_labels / total_true_labels
611
+
612
+ review_display = review[:150] + "..." if len(review) > 150 else review
613
+ print(f"Sample {idx + 1}:")
614
+ print(f"Review: {review_display}")
615
+ print(f"✓ True Labels: {true_labels if true_labels else ['None']}")
616
+ print(f"→ Predicted Labels: {pred_labels if pred_labels else ['None']}")
617
+ print(f"Match: {'✓ Exact' if set(true_labels) == set(pred_labels) else '✗ Mismatch'}")
618
+ print(f"Partial Match: {matching_labels}/{total_true_labels} labels correct ({partial_match:.2%})")
619
+ print("-" * 40)
620
+
621
+ # Final Evaluation on Test Set (WITH OPTIMIZED THRESHOLDS)
622
+ print("\n" + "="*60)
623
+ print("FINAL EVALUATION ON TEST SET (WITH OPTIMIZED THRESHOLDS)")
624
+ print("="*60)
625
+
626
+ all_preds, all_labels = predict_with_thresholds(model, test_loader, optimal_thresholds, device)
627
+
628
+ print(f"\nPredicted data shape: {all_preds.shape}")
629
+ print(f"Ground truth data shape: {all_labels.shape}")
630
+
631
+ # Classification Report
632
+ print('\n' + '='*60)
633
+ print('CLASSIFICATION REPORT')
634
+ print('='*60)
635
+ print(classification_report(all_labels, all_preds, target_names=label_cols))
636
+
637
+ # Hamming Loss
638
+ hamming_loss_value = hamming_loss(all_labels, all_preds)
639
+ print("="*60)
640
+ print("HAMMING LOSS (Multi-label Error Rate)")
641
+ print("="*60)
642
+ print(f"Hamming Loss: {hamming_loss_value:.4f}")
643
+ print(f"(Fraction of incorrectly predicted labels: {hamming_loss_value:.2%})")
644
+
645
+ # Per-aspect metrics
646
+ print("\n" + "="*60)
647
+ print("PER-ASPECT METRICS")
648
+ print("="*60)
649
+
650
+ for i, aspect in enumerate(label_cols):
651
+ y_true = all_labels[:, i]
652
+ y_pred = all_preds[:, i]
653
+
654
+ acc = accuracy_score(y_true, y_pred)
655
+ prec = precision_score(y_true, y_pred, zero_division=0)
656
+ rec = recall_score(y_true, y_pred, zero_division=0)
657
+ f1 = f1_score(y_true, y_pred, zero_division=0)
658
+
659
+ print(f"\n=== {aspect.upper()} ===")
660
+ print(f"Accuracy: {acc:.4f}")
661
+ print(f"Precision: {prec:.4f}")
662
+ print(f"Recall: {rec:.4f}")
663
+ print(f"F1 Score: {f1:.4f}")
664
+
665
+ tp = np.sum((y_true == 1) & (y_pred == 1))
666
+ tn = np.sum((y_true == 0) & (y_pred == 0))
667
+ fp = np.sum((y_true == 0) & (y_pred == 1))
668
+ fn = np.sum((y_true == 1) & (y_pred == 0))
669
+
670
+ print(f" TP: {tp}, TN: {tn}, FP: {fp}, FN: {fn}")
671
+
672
+ # Exact match accuracy
673
+ exact_matches = np.all(all_preds == all_labels, axis=1)
674
+ exact_match_acc = np.mean(exact_matches)
675
+
676
+ print("\n" + "="*60)
677
+ print("EXACT MATCH (ALL ASPECTS)")
678
+ print("="*60)
679
+ print(f"Samples with ALL aspects correct: {np.sum(exact_matches)}/{len(exact_matches)}")
680
+ print(f"Exact Match Accuracy: {exact_match_acc:.4f}")
681
+
682
+ # Partial match accuracy (per sample)
683
+ test_partial_match_scores = []
684
+ for i in range(len(all_labels)):
685
+ correct_labels = np.sum(all_preds[i] == all_labels[i])
686
+ test_partial_match_scores.append(correct_labels / len(label_cols))
687
+
688
+ test_partial_match_scores = np.array(test_partial_match_scores)
689
+ avg_test_partial_match = np.mean(test_partial_match_scores)
690
+
691
+ print("\n" + "="*60)
692
+ print("PARTIAL MATCH (PER-SAMPLE LABEL ACCURACY)")
693
+ print("="*60)
694
+ print(f"Average Partial Match: {avg_test_partial_match:.4f} ({avg_test_partial_match:.2%})")
695
+ print(f"(Average fraction of labels correctly predicted per sample)")
696
+
697
+ # Sample predictions
698
+ print("\n" + "="*60)
699
+ print("SAMPLE PREDICTIONS VS GROUND TRUTH")
700
+ print("="*60)
701
+
702
+ num_samples = min(10, len(test_X))
703
+ print(f"\nShowing {num_samples} test samples:\n")
704
+
705
+ for idx in range(num_samples):
706
+ review = test_X[idx]
707
+ true_labels = [label_cols[i] for i, v in enumerate(all_labels[idx]) if v == 1]
708
+ pred_labels = [label_cols[i] for i, v in enumerate(all_preds[idx]) if v == 1]
709
+
710
+ # Calculate partial match for this sample
711
+ # Count how many true labels were correctly predicted
712
+ matching_labels = len(set(true_labels) & set(pred_labels))
713
+ total_true_labels = len(true_labels) if len(true_labels) > 0 else 1
714
+ partial_match = matching_labels / total_true_labels
715
+
716
+ review_display = review[:150] + "..." if len(review) > 150 else review
717
+ print(f"Sample {idx + 1}:")
718
+ print(f"Review: {review_display}")
719
+ print(f"✓ True Labels: {true_labels if true_labels else ['None']}")
720
+ print(f"→ Predicted Labels: {pred_labels if pred_labels else ['None']}")
721
+ print(f"Match: {'✓ Exact' if set(true_labels) == set(pred_labels) else '✗ Mismatch'}")
722
+ print(f"Partial Match: {matching_labels}/{total_true_labels} labels correct ({partial_match:.2%})")
723
+ print("-" * 40)
724
+
725
+ # Save model interactively (optional)
726
+ # model_save_path = 'gemma_price_classifier.pth'
727
+ # torch.save({
728
+ # 'epoch': EPOCHS,
729
+ # 'model_state_dict': model.state_dict(),
730
+ # 'optimizer_state_dict': optimizer.state_dict(),
731
+ # 'train_loss': avg_train_loss,
732
+ # 'test_loss': avg_test_loss,
733
+ # }, model_save_path)
734
+ # print(f"Model saved to {model_save_path}")
735
+
736
+ model_save_path = os.path.join(SAVE_DIR, 'gemma_price_classifier.pth')
737
+ torch.save({
738
+ 'epoch': best_epoch if best_model_state is not None else EPOCHS,
739
+ 'model_state_dict': model.state_dict(),
740
+ 'optimizer_state_dict': optimizer.state_dict(),
741
+ 'train_loss': train_losses[best_epoch - 1] if best_model_state is not None else train_losses[-1] if train_losses else 0,
742
+ 'val_loss': best_val_loss if best_model_state is not None else (val_losses[-1] if val_losses else 0),
743
+ 'best_epoch': best_epoch,
744
+ 'best_val_loss': best_val_loss,
745
+ 'optimal_thresholds': optimal_thresholds,
746
+ }, model_save_path)
747
+ print(f"Model saved to {model_save_path}")
748
+
749
+ print("\n" + "="*60)
750
+ print("TRAINING COMPLETE")
751
+ print("="*60)
6 _ Fine-Tuning (Gemma)/Specific Models/LLM trained Gemma Model/gemini_product_model.py ADDED
@@ -0,0 +1,741 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import torch
3
+ from torch.utils.data import Dataset, DataLoader
4
+ from torch import nn
5
+ from transformers import AutoTokenizer, GemmaModel
6
+ from peft import LoraConfig, get_peft_model, TaskType
7
+ from sklearn.model_selection import train_test_split
8
+ from sklearn.metrics import classification_report, hamming_loss, accuracy_score, precision_score, recall_score, f1_score
9
+ import numpy as np
10
+ import random
11
+ import matplotlib.pyplot as plt
12
+
13
+ # For UTF-8 characters in output
14
+ import sys
15
+ sys.stdout.reconfigure(encoding='utf-8')
16
+
17
+ # Set random seeds for reproducibility
18
+ seed_value = 42
19
+ random.seed(seed_value)
20
+ np.random.seed(seed_value)
21
+ torch.manual_seed(seed_value)
22
+ if torch.cuda.is_available():
23
+ torch.cuda.manual_seed_all(seed_value)
24
+
25
+ # Parameters
26
+ MODEL_ID = 'google/gemma-3-1b-pt'
27
+ BATCH_SIZE = 8
28
+ EPOCHS = 10
29
+ LR = 5e-5
30
+
31
+ # Load data - product-specific
32
+ print("Loading training data from product_specific_aspects.csv...")
33
+ train_df = pd.read_csv('datasets/gemini/product_train_dataset.csv')
34
+ print("Loading test data from Test_product_dataset.csv...")
35
+ test_df = pd.read_csv('datasets/test_product_dataset.csv')
36
+
37
+ # Define label columns (Product sub-aspects)
38
+ label_cols = [
39
+ 'Color_PRO',
40
+ 'Condition_PRO',
41
+ 'Correctness_PRO',
42
+ 'Durability_PRO',
43
+ 'Effectiveness_PRO',
44
+ 'Functionality_PRO',
45
+ 'Material_PRO',
46
+ 'Sensory_PRO',
47
+ 'Size_PRO',
48
+ 'General_PRO'
49
+ ]
50
+
51
+ # Prepare training data with 80/20 train/validation split
52
+ train_X_full = train_df['Review'].astype(str).tolist()
53
+ train_Y_full = train_df[label_cols].values.astype(np.float32)
54
+
55
+ train_X, val_X, train_Y, val_Y = train_test_split(
56
+ train_X_full, train_Y_full,
57
+ test_size=0.2,
58
+ random_state=42
59
+ )
60
+
61
+ # Prepare test data
62
+ test_X = test_df['Review'].astype(str).tolist()
63
+ test_Y = test_df[label_cols].values.astype(np.float32)
64
+
65
+ print(f"\nDataset sizes:")
66
+ print(f"Training samples: {len(train_X)}")
67
+ print(f"Validation samples: {len(val_X)}")
68
+ print(f"Test samples: {len(test_X)}")
69
+ print(f"Number of labels: {len(label_cols)}")
70
+
71
+ # Compute class weights for imbalanced dataset
72
+ def compute_class_weights(labels, label_names):
73
+ """
74
+ Compute class weights for multi-label classification
75
+ using the inverse of class frequency.
76
+
77
+ Args:
78
+ labels: numpy array of shape (n_samples, n_labels)
79
+ label_names: list of label column names
80
+
81
+ Returns:
82
+ pos_weight: torch tensor of positive class weights
83
+ """
84
+ n_samples = labels.shape[0]
85
+ n_labels = labels.shape[1]
86
+
87
+ pos_weights = []
88
+
89
+ print("\n" + "="*60)
90
+ print("CLASS IMBALANCE ANALYSIS")
91
+ print("="*60)
92
+
93
+ for i, label_name in enumerate(label_names):
94
+ pos_count = np.sum(labels[:, i] == 1)
95
+ neg_count = np.sum(labels[:, i] == 0)
96
+
97
+ # Calculate positive class weight (ratio of negative to positive)
98
+ if pos_count > 0:
99
+ raw_ratio = neg_count / pos_count
100
+ # Apply square root dampening to avoid extreme weights
101
+ pos_weight = np.sqrt(raw_ratio)
102
+ else:
103
+ pos_weight = 1.0
104
+
105
+ pos_weights.append(pos_weight)
106
+
107
+ print(f"\n{label_name}:")
108
+ print(f" Positive samples: {pos_count} ({pos_count/n_samples*100:.2f}%)")
109
+ print(f" Negative samples: {neg_count} ({neg_count/n_samples*100:.2f}%)")
110
+ print(f" Raw imbalance ratio (neg/pos): {neg_count/pos_count if pos_count > 0 else 1.0:.4f}")
111
+ print(f" Dampened weight (sqrt of ratio): {pos_weight:.4f}")
112
+
113
+ print("="*60 + "\n")
114
+
115
+ return torch.FloatTensor(pos_weights)
116
+
117
+ def find_optimal_thresholds(model, dataloader, label_cols, device):
118
+ """
119
+ Find optimal decision threshold for each class independently
120
+ by maximizing F1-score on the validation set.
121
+
122
+ Args:
123
+ model: trained model
124
+ dataloader: validation data loader
125
+ label_cols: list of label column names
126
+ device: torch device
127
+
128
+ Returns:
129
+ optimal_thresholds: numpy array of optimal thresholds for each class
130
+ """
131
+ from sklearn.metrics import f1_score
132
+
133
+ print("\n" + "="*60)
134
+ print("OPTIMIZING DECISION THRESHOLDS")
135
+ print("="*60)
136
+
137
+ # Collect all predictions and labels
138
+ model.eval()
139
+ all_probs = []
140
+ all_labels = []
141
+
142
+ with torch.no_grad():
143
+ for input_ids, attention_mask, labels in dataloader:
144
+ input_ids = input_ids.to(device)
145
+ attention_mask = attention_mask.to(device)
146
+ logits = model(input_ids, attention_mask)
147
+ probs = torch.sigmoid(logits).cpu().numpy()
148
+ all_probs.append(probs)
149
+ all_labels.append(labels.cpu().numpy())
150
+
151
+ all_probs = np.vstack(all_probs)
152
+ all_labels = np.vstack(all_labels)
153
+
154
+ # Find optimal threshold for each class
155
+ optimal_thresholds = []
156
+ threshold_range = np.arange(0.1, 0.91, 0.05) # 0.1 to 0.9 in steps of 0.05
157
+
158
+ for i, label_name in enumerate(label_cols):
159
+ best_threshold = 0.5
160
+ best_f1 = 0.0
161
+
162
+ for threshold in threshold_range:
163
+ preds = (all_probs[:, i] > threshold).astype(int)
164
+ f1 = f1_score(all_labels[:, i], preds, zero_division=0)
165
+
166
+ if f1 > best_f1:
167
+ best_f1 = f1
168
+ best_threshold = threshold
169
+
170
+ optimal_thresholds.append(best_threshold)
171
+ print(f"\n{label_name}:")
172
+ print(f" Optimal threshold: {best_threshold:.2f}")
173
+ print(f" Best F1-score: {best_f1:.4f}")
174
+ print(f" (Default 0.5 threshold F1: {f1_score(all_labels[:, i], (all_probs[:, i] > 0.5).astype(int), zero_division=0):.4f})")
175
+
176
+ print("="*60 + "\n")
177
+
178
+ return np.array(optimal_thresholds)
179
+
180
+ def predict_with_thresholds(model, dataloader, thresholds, device):
181
+ """
182
+ Make predictions using custom thresholds for each class.
183
+
184
+ Args:
185
+ model: trained model
186
+ dataloader: data loader
187
+ thresholds: numpy array of thresholds for each class
188
+ device: torch device
189
+
190
+ Returns:
191
+ predictions: numpy array of predictions
192
+ labels: numpy array of true labels
193
+ """
194
+ model.eval()
195
+ all_preds = []
196
+ all_labels = []
197
+
198
+ with torch.no_grad():
199
+ for input_ids, attention_mask, labels in dataloader:
200
+ input_ids = input_ids.to(device)
201
+ attention_mask = attention_mask.to(device)
202
+ logits = model(input_ids, attention_mask)
203
+ probs = torch.sigmoid(logits).cpu().numpy()
204
+
205
+ # Apply custom thresholds for each class
206
+ preds = np.zeros_like(probs, dtype=int)
207
+ for i in range(len(thresholds)):
208
+ preds[:, i] = (probs[:, i] > thresholds[i]).astype(int)
209
+
210
+ all_preds.append(preds)
211
+ all_labels.append(labels.cpu().numpy())
212
+
213
+ return np.vstack(all_preds), np.vstack(all_labels)
214
+
215
+ # Dataset class
216
+ class ReviewDataset(Dataset):
217
+ def __init__(self, texts, labels):
218
+ self.texts = texts
219
+ self.labels = labels
220
+
221
+ def __len__(self):
222
+ return len(self.texts)
223
+
224
+ def __getitem__(self, idx):
225
+ encoding = tokenizer(
226
+ self.texts[idx],
227
+ padding='max_length',
228
+ truncation=True,
229
+ max_length=256,
230
+ return_tensors='pt'
231
+ )
232
+ input_ids = encoding['input_ids'].squeeze()
233
+ attention_mask = encoding['attention_mask'].squeeze()
234
+ label = torch.FloatTensor(self.labels[idx])
235
+ return input_ids, attention_mask, label
236
+
237
+ # Initialize tokenizer
238
+ print("\nInitializing tokenizer...")
239
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=True)
240
+
241
+ # Create datasets
242
+ train_dataset = ReviewDataset(train_X, train_Y)
243
+ val_dataset = ReviewDataset(val_X, val_Y)
244
+ test_dataset = ReviewDataset(test_X, test_Y)
245
+
246
+ # Create data loaders
247
+ train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
248
+ val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
249
+ test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
250
+
251
+ # Compute class weights based on training data
252
+ print("Computing class weights for imbalanced dataset...")
253
+ pos_weights = compute_class_weights(train_Y, label_cols)
254
+
255
+ # Initialize model with LoRA
256
+ print("Initializing model with LoRA...")
257
+ backbone = GemmaModel.from_pretrained(MODEL_ID, token=True, dtype=torch.bfloat16)
258
+
259
+ lora_config = LoraConfig(
260
+ task_type=TaskType.FEATURE_EXTRACTION,
261
+ r=8,
262
+ lora_alpha=16,
263
+ lora_dropout=0.05,
264
+ target_modules=["q_proj", "v_proj"]
265
+ )
266
+ backbone = get_peft_model(backbone, lora_config)
267
+
268
+ # Classifier model
269
+ class GemmaClassifier(nn.Module):
270
+ def __init__(self, backbone, num_labels):
271
+ super().__init__()
272
+ self.backbone = backbone
273
+ self.pooler = nn.AdaptiveAvgPool1d(1)
274
+ self.classifier = nn.Linear(backbone.config.hidden_size, num_labels)
275
+
276
+ def forward(self, input_ids, attention_mask):
277
+ output = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
278
+ hidden = output.last_hidden_state
279
+ pooled = self.pooler(hidden.permute(0, 2, 1)).squeeze(-1)
280
+ logits = self.classifier(pooled.float())
281
+ return logits
282
+
283
+ # Initialize model, optimizer, and loss function
284
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
285
+ print(f"Using device: {device}")
286
+
287
+ model = GemmaClassifier(backbone, len(label_cols)).to(device)
288
+ optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
289
+ # Use computed pos_weight to handle class imbalance
290
+ criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights.to(device))
291
+ print(f"\nInitialized BCEWithLogitsLoss with pos_weight: {pos_weights.cpu().numpy()}")
292
+
293
+ # Initialize loss tracking
294
+ train_losses = []
295
+ val_losses = []
296
+ train_batch_losses = [] # Per-batch training losses
297
+ val_batch_losses = [] # Per-batch validation losses
298
+
299
+ # Early stopping variables
300
+ best_val_loss = float('inf')
301
+ best_epoch = 0
302
+ best_model_state = None
303
+ patience = 5 # Number of epochs to wait for improvement
304
+ patience_counter = 0
305
+
306
+ # Training loop
307
+ print("\n" + "="*60)
308
+ print("TRAINING")
309
+ print("="*60)
310
+
311
+ for epoch in range(EPOCHS):
312
+ model.train()
313
+ total_loss = 0
314
+ batch_count = 0
315
+
316
+ for input_ids, attention_mask, labels in train_loader:
317
+ input_ids = input_ids.to(device)
318
+ attention_mask = attention_mask.to(device)
319
+ labels = labels.to(device)
320
+
321
+ optimizer.zero_grad()
322
+ logits = model(input_ids, attention_mask)
323
+ loss = criterion(logits, labels)
324
+ loss.backward()
325
+ optimizer.step()
326
+
327
+ total_loss += loss.item()
328
+ batch_count += 1
329
+ train_batch_losses.append(loss.item()) # Store per-batch loss
330
+
331
+ # Print progress every 100 batches
332
+ if batch_count % 100 == 0:
333
+ print(f" Epoch {epoch+1} | Batch {batch_count}/{len(train_loader)} | Current Loss: {loss.item():.4f}")
334
+
335
+ avg_train_loss = total_loss / len(train_loader)
336
+ train_losses.append(avg_train_loss)
337
+ print(f"\nEpoch {epoch+1}/{EPOCHS} completed")
338
+ print(f"Average Training Loss: {avg_train_loss:.4f}")
339
+
340
+ # Validation on validation set
341
+ model.eval()
342
+ val_loss = 0
343
+ with torch.no_grad():
344
+ for input_ids, attention_mask, labels in val_loader:
345
+ input_ids = input_ids.to(device)
346
+ attention_mask = attention_mask.to(device)
347
+ labels = labels.to(device)
348
+
349
+ logits = model(input_ids, attention_mask)
350
+ loss = criterion(logits, labels)
351
+ val_loss += loss.item()
352
+ val_batch_losses.append(loss.item()) # Store per-batch validation loss
353
+
354
+ avg_val_loss = val_loss / len(val_loader)
355
+ val_losses.append(avg_val_loss)
356
+ print(f"Validation Loss: {avg_val_loss:.4f}")
357
+
358
+ # Early stopping check
359
+ if avg_val_loss < best_val_loss:
360
+ best_val_loss = avg_val_loss
361
+ best_epoch = epoch + 1
362
+ best_model_state = model.state_dict().copy()
363
+ patience_counter = 0
364
+ print(f"✓ New best validation loss: {best_val_loss:.4f} (Epoch {best_epoch})")
365
+ else:
366
+ patience_counter += 1
367
+ print(f" No improvement for {patience_counter} epoch(s)")
368
+ if patience_counter >= patience:
369
+ print(f"\nEarly stopping triggered! Best validation loss: {best_val_loss:.4f} at epoch {best_epoch}")
370
+ break
371
+
372
+ print("-" * 60)
373
+
374
+ # Load best model state
375
+ if best_model_state is not None:
376
+ print(f"\nLoading best model from epoch {best_epoch} with validation loss: {best_val_loss:.4f}")
377
+ model.load_state_dict(best_model_state)
378
+ else:
379
+ print("\nNo best model found, using final model state")
380
+
381
+ # Optimize decision thresholds using validation set
382
+ print("Finding optimal decision thresholds for each class...")
383
+ optimal_thresholds = find_optimal_thresholds(model, val_loader, label_cols, device)
384
+ print(f"Optimal thresholds: {optimal_thresholds}")
385
+
386
+ # SAVE MODEL AFTER TRAINING
387
+ SAVE_PATH = "gemma_product_specific.pt"
388
+ torch.save(model.state_dict(), SAVE_PATH)
389
+ print(f"\nModel saved to: {SAVE_PATH}")
390
+
391
+ # Plot training and validation loss
392
+ print("\n" + "="*60)
393
+ print("PLOTTING TRAINING CURVES")
394
+ print("="*60)
395
+
396
+ plt.figure(figsize=(10, 6))
397
+ epochs_range = range(1, EPOCHS + 1)
398
+
399
+ plt.plot(epochs_range, train_losses, 'b-o', label='Training Loss', linewidth=2, markersize=8)
400
+ plt.plot(epochs_range, val_losses, 'r-s', label='Validation Loss', linewidth=2, markersize=8)
401
+
402
+ plt.xlabel('Epoch', fontsize=12)
403
+ plt.ylabel('Loss', fontsize=12)
404
+ plt.title('Training and Validation Loss Over Epochs', fontsize=14, fontweight='bold')
405
+ plt.legend(fontsize=10)
406
+ plt.grid(True, alpha=0.3)
407
+ plt.tight_layout()
408
+
409
+ # Save the plot
410
+ plot_path = 'training_loss_plot.png'
411
+ plt.savefig(plot_path, dpi=300, bbox_inches='tight')
412
+ print(f"Training loss plot saved to: {plot_path}")
413
+
414
+ # Display loss values
415
+ print("\nLoss values per epoch:")
416
+ print("-" * 40)
417
+ for i, (train_loss, val_loss) in enumerate(zip(train_losses, val_losses), 1):
418
+ print(f"Epoch {i}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
419
+ print("-" * 40)
420
+
421
+ # Plot detailed per-batch loss curves
422
+ print("\nGenerating detailed per-batch loss plot...")
423
+
424
+ # Create figure with two subplots
425
+ fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10))
426
+
427
+ # Calculate moving average for smoothing (window size = 50 batches)
428
+ def moving_average(data, window_size):
429
+ if len(data) < window_size:
430
+ window_size = max(1, len(data) // 2)
431
+ cumsum = np.cumsum(np.insert(data, 0, 0))
432
+ return (cumsum[window_size:] - cumsum[:-window_size]) / window_size
433
+
434
+ train_ma = moving_average(train_batch_losses, 50)
435
+ val_ma = moving_average(val_batch_losses, 50)
436
+
437
+ # Subplot 1: Training loss per batch
438
+ ax1.plot(train_batch_losses, alpha=0.3, color='lightblue', linewidth=0.5, label='Raw Training Loss')
439
+ ax1.plot(range(len(train_ma)), train_ma, color='blue', linewidth=2, label='Smoothed (Moving Avg, window=50)')
440
+ ax1.set_xlabel('Training Batch', fontsize=11)
441
+ ax1.set_ylabel('Loss', fontsize=11)
442
+ ax1.set_title('Training Loss per Batch (Detailed View)', fontsize=13, fontweight='bold')
443
+ ax1.legend(fontsize=9)
444
+ ax1.grid(True, alpha=0.3)
445
+
446
+ # Add vertical lines for epoch boundaries
447
+ batches_per_epoch = len(train_loader)
448
+ for epoch_idx in range(1, EPOCHS):
449
+ ax1.axvline(x=epoch_idx * batches_per_epoch, color='red', linestyle='--', linewidth=1, alpha=0.5)
450
+
451
+ # Subplot 2: Validation loss per batch
452
+ ax2.plot(val_batch_losses, alpha=0.3, color='lightcoral', linewidth=0.5, label='Raw Validation Loss')
453
+ ax2.plot(range(len(val_ma)), val_ma, color='red', linewidth=2, label='Smoothed (Moving Avg, window=50)')
454
+ ax2.set_xlabel('Validation Batch', fontsize=11)
455
+ ax2.set_ylabel('Loss', fontsize=11)
456
+ ax2.set_title('Validation Loss per Batch (Detailed View)', fontsize=13, fontweight='bold')
457
+ ax2.legend(fontsize=9)
458
+ ax2.grid(True, alpha=0.3)
459
+
460
+ # Add vertical lines for epoch boundaries
461
+ val_batches_per_epoch = len(val_loader)
462
+ for epoch_idx in range(1, EPOCHS):
463
+ ax2.axvline(x=epoch_idx * val_batches_per_epoch, color='blue', linestyle='--', linewidth=1, alpha=0.5)
464
+
465
+ plt.tight_layout()
466
+
467
+ # Save the detailed plot
468
+ detailed_plot_path = 'training_loss_per_batch_detailed.png'
469
+ plt.savefig(detailed_plot_path, dpi=300, bbox_inches='tight')
470
+ print(f"Detailed per-batch loss plot saved to: {detailed_plot_path}")
471
+
472
+ # Print batch loss statistics
473
+ print("\nBatch Loss Statistics:")
474
+ print("-" * 60)
475
+ print(f"Training batches: {len(train_batch_losses)}")
476
+ print(f" Min loss: {min(train_batch_losses):.4f}")
477
+ print(f" Max loss: {max(train_batch_losses):.4f}")
478
+ print(f" Mean loss: {np.mean(train_batch_losses):.4f}")
479
+ print(f" Std dev: {np.std(train_batch_losses):.4f}")
480
+ print(f"\nValidation batches: {len(val_batch_losses)}")
481
+ print(f" Min loss: {min(val_batch_losses):.4f}")
482
+ print(f" Max loss: {max(val_batch_losses):.4f}")
483
+ print(f" Mean loss: {np.mean(val_batch_losses):.4f}")
484
+ print(f" Std dev: {np.std(val_batch_losses):.4f}")
485
+ print("-" * 60)
486
+
487
+ # VALIDATION SET EVALUATION (WITH OPTIMIZED THRESHOLDS)
488
+ print("\n" + "="*60)
489
+ print("VALIDATION SET EVALUATION (WITH OPTIMIZED THRESHOLDS)")
490
+ print("="*60)
491
+
492
+ val_preds, val_labels_eval = predict_with_thresholds(model, val_loader, optimal_thresholds, device)
493
+
494
+ # Also get predictions with default threshold for comparison
495
+ model.eval()
496
+ val_preds_default = []
497
+ with torch.no_grad():
498
+ for input_ids, attention_mask, labels in val_loader:
499
+ input_ids = input_ids.to(device)
500
+ attention_mask = attention_mask.to(device)
501
+ logits = model(input_ids, attention_mask)
502
+ probs = torch.sigmoid(logits).cpu().numpy()
503
+ preds = (probs > 0.5).astype(int)
504
+ val_preds_default.append(preds)
505
+
506
+ val_preds_default = np.vstack(val_preds_default)
507
+
508
+ print(f"\nPredicted data shape: {val_preds.shape}")
509
+ print(f"Ground truth data shape: {val_labels_eval.shape}")
510
+
511
+ # Comparison: Default vs Optimized Thresholds
512
+ print("\n" + "="*60)
513
+ print("COMPARISON: Default vs Optimized Thresholds")
514
+ print("="*60)
515
+
516
+ print("\nDefault Threshold (0.5):")
517
+ for i, label in enumerate(label_cols):
518
+ f1_default = f1_score(val_labels_eval[:, i], val_preds_default[:, i], zero_division=0)
519
+ print(f" {label}: F1 = {f1_default:.4f}")
520
+
521
+ print("\nOptimized Thresholds:")
522
+ for i, label in enumerate(label_cols):
523
+ f1_optimized = f1_score(val_labels_eval[:, i], val_preds[:, i], zero_division=0)
524
+ print(f" {label}: F1 = {f1_optimized:.4f} (threshold = {optimal_thresholds[i]:.2f})")
525
+ print("="*60 + "\n")
526
+
527
+ # Classification Report
528
+ print('\n' + '='*60)
529
+ print('CLASSIFICATION REPORT (VALIDATION)')
530
+ print('='*60)
531
+ print(classification_report(val_labels_eval, val_preds, target_names=label_cols))
532
+
533
+ # Hamming Loss
534
+ val_hamming_loss = hamming_loss(val_labels_eval, val_preds)
535
+ print("="*60)
536
+ print("HAMMING LOSS (Multi-label Error Rate)")
537
+ print("="*60)
538
+ print(f"Hamming Loss: {val_hamming_loss:.4f}")
539
+ print(f"(Fraction of incorrectly predicted labels: {val_hamming_loss:.2%})")
540
+
541
+ # Per-aspect metrics
542
+ print("\n" + "="*60)
543
+ print("PER-ASPECT METRICS (VALIDATION)")
544
+ print("="*60)
545
+
546
+ for i, aspect in enumerate(label_cols):
547
+ y_true = val_labels_eval[:, i]
548
+ y_pred = val_preds[:, i]
549
+
550
+ acc = accuracy_score(y_true, y_pred)
551
+ prec = precision_score(y_true, y_pred, zero_division=0)
552
+ rec = recall_score(y_true, y_pred, zero_division=0)
553
+ f1 = f1_score(y_true, y_pred, zero_division=0)
554
+
555
+ print(f"\n=== {aspect.upper()} ===")
556
+ print(f"Accuracy: {acc:.4f}")
557
+ print(f"Precision: {prec:.4f}")
558
+ print(f"Recall: {rec:.4f}")
559
+ print(f"F1 Score: {f1:.4f}")
560
+
561
+ tp = np.sum((y_true == 1) & (y_pred == 1))
562
+ tn = np.sum((y_true == 0) & (y_pred == 0))
563
+ fp = np.sum((y_true == 0) & (y_pred == 1))
564
+ fn = np.sum((y_true == 1) & (y_pred == 0))
565
+
566
+ print(f" TP: {tp}, TN: {tn}, FP: {fp}, FN: {fn}")
567
+
568
+ # Exact match accuracy
569
+ val_exact_matches = np.all(val_preds == val_labels_eval, axis=1)
570
+ val_exact_match_acc = np.mean(val_exact_matches)
571
+
572
+ print("\n" + "="*60)
573
+ print("EXACT MATCH (ALL ASPECTS)")
574
+ print("="*60)
575
+ print(f"Samples with ALL aspects correct: {np.sum(val_exact_matches)}/{len(val_exact_matches)}")
576
+ print(f"Exact Match Accuracy: {val_exact_match_acc:.4f}")
577
+
578
+ # Partial match accuracy (per sample)
579
+ partial_match_scores = []
580
+ for i in range(len(val_labels_eval)):
581
+ correct_labels = np.sum(val_preds[i] == val_labels_eval[i])
582
+ partial_match_scores.append(correct_labels / len(label_cols))
583
+
584
+ partial_match_scores = np.array(partial_match_scores)
585
+ avg_partial_match = np.mean(partial_match_scores)
586
+
587
+ print("\n" + "="*60)
588
+ print("PARTIAL MATCH (PER-SAMPLE LABEL ACCURACY)")
589
+ print("="*60)
590
+ print(f"Average Partial Match: {avg_partial_match:.4f} ({avg_partial_match:.2%})")
591
+ print(f"(Average fraction of labels correctly predicted per sample)")
592
+
593
+ # Sample predictions with match/mismatch
594
+ print("\n" + "="*60)
595
+ print("SAMPLE PREDICTIONS VS GROUND TRUTH (VALIDATION)")
596
+ print("="*60)
597
+
598
+ num_samples = min(10, len(val_X))
599
+ print(f"\nShowing {num_samples} validation samples:\n")
600
+
601
+ for idx in range(num_samples):
602
+ review = val_X[idx]
603
+ true_labels = [label_cols[i] for i, v in enumerate(val_labels_eval[idx]) if v == 1]
604
+ pred_labels = [label_cols[i] for i, v in enumerate(val_preds[idx]) if v == 1]
605
+
606
+ # Calculate partial match for this sample
607
+ # Count how many true labels were correctly predicted
608
+ matching_labels = len(set(true_labels) & set(pred_labels))
609
+ total_true_labels = len(true_labels) if len(true_labels) > 0 else 1
610
+ partial_match = matching_labels / total_true_labels
611
+
612
+ review_display = review[:150] + "..." if len(review) > 150 else review
613
+ print(f"Sample {idx + 1}:")
614
+ print(f"Review: {review_display}")
615
+ print(f"✓ True Labels: {true_labels if true_labels else ['None']}")
616
+ print(f"→ Predicted Labels: {pred_labels if pred_labels else ['None']}")
617
+ print(f"Match: {'✓ Exact' if set(true_labels) == set(pred_labels) else '✗ Mismatch'}")
618
+ print(f"Partial Match: {matching_labels}/{total_true_labels} labels correct ({partial_match:.2%})")
619
+ print("-" * 40)
620
+
621
+ # Final Evaluation on Test Set (WITH OPTIMIZED THRESHOLDS)
622
+ print("\n" + "="*60)
623
+ print("FINAL EVALUATION ON TEST SET (WITH OPTIMIZED THRESHOLDS)")
624
+ print("="*60)
625
+
626
+ all_preds, all_labels = predict_with_thresholds(model, test_loader, optimal_thresholds, device)
627
+
628
+ print(f"\nPredicted data shape: {all_preds.shape}")
629
+ print(f"Ground truth data shape: {all_labels.shape}")
630
+
631
+ # Classification Report
632
+ print('\n' + '='*60)
633
+ print('CLASSIFICATION REPORT')
634
+ print('='*60)
635
+ print(classification_report(all_labels, all_preds, target_names=label_cols))
636
+
637
+ # Hamming Loss
638
+ hamming_loss_value = hamming_loss(all_labels, all_preds)
639
+ print("="*60)
640
+ print("HAMMING LOSS (Multi-label Error Rate)")
641
+ print("="*60)
642
+ print(f"Hamming Loss: {hamming_loss_value:.4f}")
643
+ print(f"(Fraction of incorrectly predicted labels: {hamming_loss_value:.2%})")
644
+
645
+ # Per-aspect metrics
646
+ print("\n" + "="*60)
647
+ print("PER-ASPECT METRICS")
648
+ print("="*60)
649
+
650
+ for i, aspect in enumerate(label_cols):
651
+ y_true = all_labels[:, i]
652
+ y_pred = all_preds[:, i]
653
+
654
+ acc = accuracy_score(y_true, y_pred)
655
+ prec = precision_score(y_true, y_pred, zero_division=0)
656
+ rec = recall_score(y_true, y_pred, zero_division=0)
657
+ f1 = f1_score(y_true, y_pred, zero_division=0)
658
+
659
+ print(f"\n=== {aspect.upper()} ===")
660
+ print(f"Accuracy: {acc:.4f}")
661
+ print(f"Precision: {prec:.4f}")
662
+ print(f"Recall: {rec:.4f}")
663
+ print(f"F1 Score: {f1:.4f}")
664
+
665
+ tp = np.sum((y_true == 1) & (y_pred == 1))
666
+ tn = np.sum((y_true == 0) & (y_pred == 0))
667
+ fp = np.sum((y_true == 0) & (y_pred == 1))
668
+ fn = np.sum((y_true == 1) & (y_pred == 0))
669
+
670
+ print(f" TP: {tp}, TN: {tn}, FP: {fp}, FN: {fn}")
671
+
672
+ # Exact match accuracy
673
+ exact_matches = np.all(all_preds == all_labels, axis=1)
674
+ exact_match_acc = np.mean(exact_matches)
675
+
676
+ print("\n" + "="*60)
677
+ print("EXACT MATCH (ALL ASPECTS)")
678
+ print("="*60)
679
+ print(f"Samples with ALL aspects correct: {np.sum(exact_matches)}/{len(exact_matches)}")
680
+ print(f"Exact Match Accuracy: {exact_match_acc:.4f}")
681
+
682
+ # Partial match accuracy (per sample)
683
+ test_partial_match_scores = []
684
+ for i in range(len(all_labels)):
685
+ correct_labels = np.sum(all_preds[i] == all_labels[i])
686
+ test_partial_match_scores.append(correct_labels / len(label_cols))
687
+
688
+ test_partial_match_scores = np.array(test_partial_match_scores)
689
+ avg_test_partial_match = np.mean(test_partial_match_scores)
690
+
691
+ print("\n" + "="*60)
692
+ print("PARTIAL MATCH (PER-SAMPLE LABEL ACCURACY)")
693
+ print("="*60)
694
+ print(f"Average Partial Match: {avg_test_partial_match:.4f} ({avg_test_partial_match:.2%})")
695
+ print(f"(Average fraction of labels correctly predicted per sample)")
696
+
697
+ # Sample predictions
698
+ print("\n" + "="*60)
699
+ print("SAMPLE PREDICTIONS VS GROUND TRUTH")
700
+ print("="*60)
701
+
702
+ num_samples = min(10, len(test_X))
703
+ print(f"\nShowing {num_samples} test samples:\n")
704
+
705
+ for idx in range(num_samples):
706
+ review = test_X[idx]
707
+ true_labels = [label_cols[i] for i, v in enumerate(all_labels[idx]) if v == 1]
708
+ pred_labels = [label_cols[i] for i, v in enumerate(all_preds[idx]) if v == 1]
709
+
710
+ # Calculate partial match for this sample
711
+ # Count how many true labels were correctly predicted
712
+ matching_labels = len(set(true_labels) & set(pred_labels))
713
+ total_true_labels = len(true_labels) if len(true_labels) > 0 else 1
714
+ partial_match = matching_labels / total_true_labels
715
+
716
+ review_display = review[:150] + "..." if len(review) > 150 else review
717
+ print(f"Sample {idx + 1}:")
718
+ print(f"Review: {review_display}")
719
+ print(f"✓ True Labels: {true_labels if true_labels else ['None']}")
720
+ print(f"→ Predicted Labels: {pred_labels if pred_labels else ['None']}")
721
+ print(f"Match: {'✓ Exact' if set(true_labels) == set(pred_labels) else '✗ Mismatch'}")
722
+ print(f"Partial Match: {matching_labels}/{total_true_labels} labels correct ({partial_match:.2%})")
723
+ print("-" * 40)
724
+
725
+ # Save model interactively (optional)
726
+ model_save_path = 'gemma_product_classifier.pth'
727
+ torch.save({
728
+ 'epoch': best_epoch if best_model_state is not None else EPOCHS,
729
+ 'model_state_dict': model.state_dict(),
730
+ 'optimizer_state_dict': optimizer.state_dict(),
731
+ 'train_loss': train_losses[best_epoch - 1] if best_model_state is not None else train_losses[-1] if train_losses else 0,
732
+ 'val_loss': best_val_loss if best_model_state is not None else (val_losses[-1] if val_losses else 0),
733
+ 'best_epoch': best_epoch,
734
+ 'best_val_loss': best_val_loss,
735
+ 'optimal_thresholds': optimal_thresholds,
736
+ }, model_save_path)
737
+ print(f"Model saved to {model_save_path}")
738
+
739
+ print("\n" + "="*60)
740
+ print("TRAINING COMPLETE")
741
+ print("="*60)
6 _ Fine-Tuning (Gemma)/Specific Models/LLM trained Gemma Model/gemini_service_model.py ADDED
@@ -0,0 +1,748 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import torch
3
+ from torch.utils.data import Dataset, DataLoader
4
+ from torch import nn
5
+ from transformers import AutoTokenizer, GemmaModel
6
+ from peft import LoraConfig, get_peft_model, TaskType
7
+ from sklearn.model_selection import train_test_split
8
+ from sklearn.metrics import classification_report, hamming_loss, accuracy_score, precision_score, recall_score, f1_score
9
+ import numpy as np
10
+ import random
11
+ import matplotlib.pyplot as plt
12
+ import os
13
+
14
+ # For UTF-8 characters in output
15
+ import sys
16
+ sys.stdout.reconfigure(encoding='utf-8')
17
+
18
+ # Set random seeds for reproducibility
19
+ seed_value = 42
20
+ random.seed(seed_value)
21
+ np.random.seed(seed_value)
22
+ torch.manual_seed(seed_value)
23
+ if torch.cuda.is_available():
24
+ torch.cuda.manual_seed_all(seed_value)
25
+
26
+ # Parameters
27
+ MODEL_ID = 'google/gemma-3-1b-pt'
28
+ BATCH_SIZE = 8
29
+ EPOCHS = 10
30
+ LR = 5e-5
31
+
32
+ # Load data - service-specific
33
+ print("Loading training data from service_train_dataset.csv...")
34
+ train_df = pd.read_csv('datasets/gemini/service_train_dataset.csv')
35
+ print("Loading test data from test_service_dataset.csv...")
36
+ test_df = pd.read_csv('datasets/test_service_dataset.csv')
37
+
38
+ # Define label columns (Service sub-aspects)
39
+ label_cols = [
40
+ 'Handling_SER',
41
+ 'Responsiveness_SER',
42
+ 'Trustworthiness_SER',
43
+ 'General_SER'
44
+ ]
45
+
46
+ # Prepare training data with 80/20 train/validation split
47
+ train_X_full = train_df['Review'].astype(str).tolist()
48
+ train_Y_full = train_df[label_cols].values.astype(np.float32)
49
+
50
+ train_X, val_X, train_Y, val_Y = train_test_split(
51
+ train_X_full, train_Y_full,
52
+ test_size=0.2,
53
+ random_state=42
54
+ )
55
+
56
+ # Prepare test data
57
+ test_X = test_df['Review'].astype(str).tolist()
58
+ test_Y = test_df[label_cols].values.astype(np.float32)
59
+
60
+ print(f"\nDataset sizes:")
61
+ print(f"Training samples: {len(train_X)}")
62
+ print(f"Validation samples: {len(val_X)}")
63
+ print(f"Test samples: {len(test_X)}")
64
+ print(f"Number of labels: {len(label_cols)}")
65
+
66
+ # Compute class weights for imbalanced dataset
67
+ def compute_class_weights(labels, label_names):
68
+ """
69
+ Compute class weights for multi-label classification
70
+ using the inverse of class frequency.
71
+
72
+ Args:
73
+ labels: numpy array of shape (n_samples, n_labels)
74
+ label_names: list of label column names
75
+
76
+ Returns:
77
+ pos_weight: torch tensor of positive class weights
78
+ """
79
+ n_samples = labels.shape[0]
80
+ n_labels = labels.shape[1]
81
+
82
+ pos_weights = []
83
+
84
+ print("\n" + "="*60)
85
+ print("CLASS IMBALANCE ANALYSIS")
86
+ print("="*60)
87
+
88
+ for i, label_name in enumerate(label_names):
89
+ pos_count = np.sum(labels[:, i] == 1)
90
+ neg_count = np.sum(labels[:, i] == 0)
91
+
92
+ # Calculate positive class weight (ratio of negative to positive)
93
+ if pos_count > 0:
94
+ raw_ratio = neg_count / pos_count
95
+ # Apply square root dampening to avoid extreme weights
96
+ pos_weight = np.sqrt(raw_ratio)
97
+ else:
98
+ pos_weight = 1.0
99
+
100
+ pos_weights.append(pos_weight)
101
+
102
+ print(f"\n{label_name}:")
103
+ print(f" Positive samples: {pos_count} ({pos_count/n_samples*100:.2f}%)")
104
+ print(f" Negative samples: {neg_count} ({neg_count/n_samples*100:.2f}%)")
105
+ print(f" Raw imbalance ratio (neg/pos): {neg_count/pos_count if pos_count > 0 else 1.0:.4f}")
106
+ print(f" Dampened weight (sqrt of ratio): {pos_weight:.4f}")
107
+
108
+ print("="*60 + "\n")
109
+
110
+ return torch.FloatTensor(pos_weights)
111
+
112
+ def find_optimal_thresholds(model, dataloader, label_cols, device):
113
+ """
114
+ Find optimal decision threshold for each class independently
115
+ by maximizing F1-score on the validation set.
116
+
117
+ Args:
118
+ model: trained model
119
+ dataloader: validation data loader
120
+ label_cols: list of label column names
121
+ device: torch device
122
+
123
+ Returns:
124
+ optimal_thresholds: numpy array of optimal thresholds for each class
125
+ """
126
+ from sklearn.metrics import f1_score
127
+
128
+ print("\n" + "="*60)
129
+ print("OPTIMIZING DECISION THRESHOLDS")
130
+ print("="*60)
131
+
132
+ # Collect all predictions and labels
133
+ model.eval()
134
+ all_probs = []
135
+ all_labels = []
136
+
137
+ with torch.no_grad():
138
+ for input_ids, attention_mask, labels in dataloader:
139
+ input_ids = input_ids.to(device)
140
+ attention_mask = attention_mask.to(device)
141
+ logits = model(input_ids, attention_mask)
142
+ probs = torch.sigmoid(logits).cpu().numpy()
143
+ all_probs.append(probs)
144
+ all_labels.append(labels.cpu().numpy())
145
+
146
+ all_probs = np.vstack(all_probs)
147
+ all_labels = np.vstack(all_labels)
148
+
149
+ # Find optimal threshold for each class
150
+ optimal_thresholds = []
151
+ threshold_range = np.arange(0.1, 0.91, 0.05) # 0.1 to 0.9 in steps of 0.05
152
+
153
+ for i, label_name in enumerate(label_cols):
154
+ best_threshold = 0.5
155
+ best_f1 = 0.0
156
+
157
+ for threshold in threshold_range:
158
+ preds = (all_probs[:, i] > threshold).astype(int)
159
+ f1 = f1_score(all_labels[:, i], preds, zero_division=0)
160
+
161
+ if f1 > best_f1:
162
+ best_f1 = f1
163
+ best_threshold = threshold
164
+
165
+ optimal_thresholds.append(best_threshold)
166
+ print(f"\n{label_name}:")
167
+ print(f" Optimal threshold: {best_threshold:.2f}")
168
+ print(f" Best F1-score: {best_f1:.4f}")
169
+ print(f" (Default 0.5 threshold F1: {f1_score(all_labels[:, i], (all_probs[:, i] > 0.5).astype(int), zero_division=0):.4f})")
170
+
171
+ print("="*60 + "\n")
172
+
173
+ return np.array(optimal_thresholds)
174
+
175
+ def predict_with_thresholds(model, dataloader, thresholds, device):
176
+ """
177
+ Make predictions using custom thresholds for each class.
178
+
179
+ Args:
180
+ model: trained model
181
+ dataloader: data loader
182
+ thresholds: numpy array of thresholds for each class
183
+ device: torch device
184
+
185
+ Returns:
186
+ predictions: numpy array of predictions
187
+ labels: numpy array of true labels
188
+ """
189
+ model.eval()
190
+ all_preds = []
191
+ all_labels = []
192
+
193
+ with torch.no_grad():
194
+ for input_ids, attention_mask, labels in dataloader:
195
+ input_ids = input_ids.to(device)
196
+ attention_mask = attention_mask.to(device)
197
+ logits = model(input_ids, attention_mask)
198
+ probs = torch.sigmoid(logits).cpu().numpy()
199
+
200
+ # Apply custom thresholds for each class
201
+ preds = np.zeros_like(probs, dtype=int)
202
+ for i in range(len(thresholds)):
203
+ preds[:, i] = (probs[:, i] > thresholds[i]).astype(int)
204
+
205
+ all_preds.append(preds)
206
+ all_labels.append(labels.cpu().numpy())
207
+
208
+ return np.vstack(all_preds), np.vstack(all_labels)
209
+
210
+ # Dataset class
211
+ class ReviewDataset(Dataset):
212
+ def __init__(self, texts, labels):
213
+ self.texts = texts
214
+ self.labels = labels
215
+
216
+ def __len__(self):
217
+ return len(self.texts)
218
+
219
+ def __getitem__(self, idx):
220
+ encoding = tokenizer(
221
+ self.texts[idx],
222
+ padding='max_length',
223
+ truncation=True,
224
+ max_length=256,
225
+ return_tensors='pt'
226
+ )
227
+ input_ids = encoding['input_ids'].squeeze()
228
+ attention_mask = encoding['attention_mask'].squeeze()
229
+ label = torch.FloatTensor(self.labels[idx])
230
+ return input_ids, attention_mask, label
231
+
232
+ # Initialize tokenizer
233
+ print("\nInitializing tokenizer...")
234
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=True)
235
+
236
+ # Create datasets
237
+ train_dataset = ReviewDataset(train_X, train_Y)
238
+ val_dataset = ReviewDataset(val_X, val_Y)
239
+ test_dataset = ReviewDataset(test_X, test_Y)
240
+
241
+ # Create data loaders
242
+ train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
243
+ val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
244
+ test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
245
+
246
+ # Compute class weights based on training data
247
+ print("Computing class weights for imbalanced dataset...")
248
+ pos_weights = compute_class_weights(train_Y, label_cols)
249
+
250
+ # Initialize model with LoRA
251
+ print("Initializing model with LoRA...")
252
+ backbone = GemmaModel.from_pretrained(MODEL_ID, token=True, dtype=torch.bfloat16)
253
+
254
+ lora_config = LoraConfig(
255
+ task_type=TaskType.FEATURE_EXTRACTION,
256
+ r=8,
257
+ lora_alpha=16,
258
+ lora_dropout=0.05,
259
+ target_modules=["q_proj", "v_proj"]
260
+ )
261
+ backbone = get_peft_model(backbone, lora_config)
262
+
263
+ # Classifier model
264
+ class GemmaClassifier(nn.Module):
265
+ def __init__(self, backbone, num_labels):
266
+ super().__init__()
267
+ self.backbone = backbone
268
+ self.pooler = nn.AdaptiveAvgPool1d(1)
269
+ self.classifier = nn.Linear(backbone.config.hidden_size, num_labels)
270
+
271
+ def forward(self, input_ids, attention_mask):
272
+ output = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
273
+ hidden = output.last_hidden_state
274
+ pooled = self.pooler(hidden.permute(0, 2, 1)).squeeze(-1)
275
+ logits = self.classifier(pooled.float())
276
+ return logits
277
+
278
+ # Initialize model, optimizer, and loss function
279
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
280
+ print(f"Using device: {device}")
281
+
282
+ model = GemmaClassifier(backbone, len(label_cols)).to(device)
283
+ optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
284
+ # Use computed pos_weight to handle class imbalance
285
+ criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights.to(device))
286
+ print(f"\nInitialized BCEWithLogitsLoss with pos_weight: {pos_weights.cpu().numpy()}")
287
+
288
+ # Initialize loss tracking
289
+ train_losses = []
290
+ val_losses = []
291
+ train_batch_losses = [] # Per-batch training losses
292
+ val_batch_losses = [] # Per-batch validation losses
293
+
294
+ # Early stopping variables
295
+ best_val_loss = float('inf')
296
+ best_epoch = 0
297
+ best_model_state = None
298
+ patience = 5 # Number of epochs to wait for improvement
299
+ patience_counter = 0
300
+
301
+ # Training loop
302
+ print("\n" + "="*60)
303
+ print("TRAINING")
304
+ print("="*60)
305
+
306
+ for epoch in range(EPOCHS):
307
+ model.train()
308
+ total_loss = 0
309
+ batch_count = 0
310
+
311
+ for input_ids, attention_mask, labels in train_loader:
312
+ input_ids = input_ids.to(device)
313
+ attention_mask = attention_mask.to(device)
314
+ labels = labels.to(device)
315
+
316
+ optimizer.zero_grad()
317
+ logits = model(input_ids, attention_mask)
318
+ loss = criterion(logits, labels)
319
+ loss.backward()
320
+ optimizer.step()
321
+
322
+ total_loss += loss.item()
323
+ batch_count += 1
324
+ train_batch_losses.append(loss.item()) # Store per-batch loss
325
+
326
+ # Print progress every 100 batches
327
+ if batch_count % 100 == 0:
328
+ print(f" Epoch {epoch+1} | Batch {batch_count}/{len(train_loader)} | Current Loss: {loss.item():.4f}")
329
+
330
+ avg_train_loss = total_loss / len(train_loader)
331
+ train_losses.append(avg_train_loss)
332
+ print(f"\nEpoch {epoch+1}/{EPOCHS} completed")
333
+ print(f"Average Training Loss: {avg_train_loss:.4f}")
334
+
335
+ # Validation on validation set
336
+ model.eval()
337
+ val_loss = 0
338
+ with torch.no_grad():
339
+ for input_ids, attention_mask, labels in val_loader:
340
+ input_ids = input_ids.to(device)
341
+ attention_mask = attention_mask.to(device)
342
+ labels = labels.to(device)
343
+
344
+ logits = model(input_ids, attention_mask)
345
+ loss = criterion(logits, labels)
346
+ val_loss += loss.item()
347
+ val_batch_losses.append(loss.item()) # Store per-batch validation loss
348
+
349
+ avg_val_loss = val_loss / len(val_loader)
350
+ val_losses.append(avg_val_loss)
351
+ print(f"Validation Loss: {avg_val_loss:.4f}")
352
+
353
+ # Early stopping check
354
+ if avg_val_loss < best_val_loss:
355
+ best_val_loss = avg_val_loss
356
+ best_epoch = epoch + 1
357
+ best_model_state = model.state_dict().copy()
358
+ patience_counter = 0
359
+ print(f"✓ New best validation loss: {best_val_loss:.4f} (Epoch {best_epoch})")
360
+ else:
361
+ patience_counter += 1
362
+ print(f" No improvement for {patience_counter} epoch(s)")
363
+ if patience_counter >= patience:
364
+ print(f"\nEarly stopping triggered! Best validation loss: {best_val_loss:.4f} at epoch {best_epoch}")
365
+ break
366
+
367
+ print("-" * 60)
368
+
369
+ # Load best model state
370
+ if best_model_state is not None:
371
+ print(f"\nLoading best model from epoch {best_epoch} with validation loss: {best_val_loss:.4f}")
372
+ model.load_state_dict(best_model_state)
373
+ else:
374
+ print("\nNo best model found, using final model state")
375
+
376
+ # Optimize decision thresholds using validation set
377
+ print("Finding optimal decision thresholds for each class...")
378
+ optimal_thresholds = find_optimal_thresholds(model, val_loader, label_cols, device)
379
+ print(f"Optimal thresholds: {optimal_thresholds}")
380
+
381
+ # SAVE MODEL AFTER TRAINING
382
+ SAVE_DIR = r"C:\temp\new_models" # make sure this folder exists
383
+ os.makedirs(SAVE_DIR, exist_ok=True)
384
+ SAVE_PATH = os.path.join(SAVE_DIR, "gemma_service_specific.pt")
385
+ torch.save(model.to('cpu').state_dict(), SAVE_PATH)
386
+ model.to(device) # Move model back to device after saving
387
+ print(f"\nModel saved to: {SAVE_PATH}")
388
+
389
+ # Plot training and validation loss
390
+ print("\n" + "="*60)
391
+ print("PLOTTING TRAINING CURVES")
392
+ print("="*60)
393
+
394
+ plt.figure(figsize=(10, 6))
395
+ epochs_range = range(1, EPOCHS + 1)
396
+
397
+ plt.plot(epochs_range, train_losses, 'b-o', label='Training Loss', linewidth=2, markersize=8)
398
+ plt.plot(epochs_range, val_losses, 'r-s', label='Validation Loss', linewidth=2, markersize=8)
399
+
400
+ plt.xlabel('Epoch', fontsize=12)
401
+ plt.ylabel('Loss', fontsize=12)
402
+ plt.title('Training and Validation Loss Over Epochs', fontsize=14, fontweight='bold')
403
+ plt.legend(fontsize=10)
404
+ plt.grid(True, alpha=0.3)
405
+ plt.tight_layout()
406
+
407
+ # Save the plot
408
+ plot_path = 'training_loss_plot_service.png'
409
+ plt.savefig(plot_path, dpi=300, bbox_inches='tight')
410
+ print(f"Training loss plot saved to: {plot_path}")
411
+
412
+ # Display loss values
413
+ print("\nLoss values per epoch:")
414
+ print("-" * 40)
415
+ for i, (train_loss, val_loss) in enumerate(zip(train_losses, val_losses), 1):
416
+ print(f"Epoch {i}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
417
+ print("-" * 40)
418
+
419
+ # Plot detailed per-batch loss curves
420
+ print("\nGenerating detailed per-batch loss plot...")
421
+
422
+ # Create figure with two subplots
423
+ fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10))
424
+
425
+ # Calculate moving average for smoothing (window size = 50 batches)
426
+ def moving_average(data, window_size):
427
+ if len(data) < window_size:
428
+ window_size = max(1, len(data) // 2)
429
+ cumsum = np.cumsum(np.insert(data, 0, 0))
430
+ return (cumsum[window_size:] - cumsum[:-window_size]) / window_size
431
+
432
+ train_ma = moving_average(train_batch_losses, 50)
433
+ val_ma = moving_average(val_batch_losses, 50)
434
+
435
+ # Subplot 1: Training loss per batch
436
+ ax1.plot(train_batch_losses, alpha=0.3, color='lightblue', linewidth=0.5, label='Raw Training Loss')
437
+ ax1.plot(range(len(train_ma)), train_ma, color='blue', linewidth=2, label='Smoothed (Moving Avg, window=50)')
438
+ ax1.set_xlabel('Training Batch', fontsize=11)
439
+ ax1.set_ylabel('Loss', fontsize=11)
440
+ ax1.set_title('Training Loss per Batch (Detailed View)', fontsize=13, fontweight='bold')
441
+ ax1.legend(fontsize=9)
442
+ ax1.grid(True, alpha=0.3)
443
+
444
+ # Add vertical lines for epoch boundaries
445
+ batches_per_epoch = len(train_loader)
446
+ for epoch_idx in range(1, EPOCHS):
447
+ ax1.axvline(x=epoch_idx * batches_per_epoch, color='red', linestyle='--', linewidth=1, alpha=0.5)
448
+
449
+ # Subplot 2: Validation loss per batch
450
+ ax2.plot(val_batch_losses, alpha=0.3, color='lightcoral', linewidth=0.5, label='Raw Validation Loss')
451
+ ax2.plot(range(len(val_ma)), val_ma, color='red', linewidth=2, label='Smoothed (Moving Avg, window=50)')
452
+ ax2.set_xlabel('Validation Batch', fontsize=11)
453
+ ax2.set_ylabel('Loss', fontsize=11)
454
+ ax2.set_title('Validation Loss per Batch (Detailed View)', fontsize=13, fontweight='bold')
455
+ ax2.legend(fontsize=9)
456
+ ax2.grid(True, alpha=0.3)
457
+
458
+ # Add vertical lines for epoch boundaries
459
+ val_batches_per_epoch = len(val_loader)
460
+ for epoch_idx in range(1, EPOCHS):
461
+ ax2.axvline(x=epoch_idx * val_batches_per_epoch, color='blue', linestyle='--', linewidth=1, alpha=0.5)
462
+
463
+ plt.tight_layout()
464
+
465
+ # Save the detailed plot
466
+ detailed_plot_path = 'training_loss_per_batch_detailed_service.png'
467
+ plt.savefig(detailed_plot_path, dpi=300, bbox_inches='tight')
468
+ print(f"Detailed per-batch loss plot saved to: {detailed_plot_path}")
469
+
470
+ # Print batch loss statistics
471
+ print("\nBatch Loss Statistics:")
472
+ print("-" * 60)
473
+ print(f"Training batches: {len(train_batch_losses)}")
474
+ print(f" Min loss: {min(train_batch_losses):.4f}")
475
+ print(f" Max loss: {max(train_batch_losses):.4f}")
476
+ print(f" Mean loss: {np.mean(train_batch_losses):.4f}")
477
+ print(f" Std dev: {np.std(train_batch_losses):.4f}")
478
+ print(f"\nValidation batches: {len(val_batch_losses)}")
479
+ print(f" Min loss: {min(val_batch_losses):.4f}")
480
+ print(f" Max loss: {max(val_batch_losses):.4f}")
481
+ print(f" Mean loss: {np.mean(val_batch_losses):.4f}")
482
+ print(f" Std dev: {np.std(val_batch_losses):.4f}")
483
+ print("-" * 60)
484
+
485
+ # VALIDATION SET EVALUATION (WITH OPTIMIZED THRESHOLDS)
486
+ print("\n" + "="*60)
487
+ print("VALIDATION SET EVALUATION (WITH OPTIMIZED THRESHOLDS)")
488
+ print("="*60)
489
+
490
+ val_preds, val_labels_eval = predict_with_thresholds(model, val_loader, optimal_thresholds, device)
491
+
492
+ # Also get predictions with default threshold for comparison
493
+ model.eval()
494
+ val_preds_default = []
495
+ with torch.no_grad():
496
+ for input_ids, attention_mask, labels in val_loader:
497
+ input_ids = input_ids.to(device)
498
+ attention_mask = attention_mask.to(device)
499
+ logits = model(input_ids, attention_mask)
500
+ probs = torch.sigmoid(logits).cpu().numpy()
501
+ preds = (probs > 0.5).astype(int)
502
+ val_preds_default.append(preds)
503
+
504
+ val_preds_default = np.vstack(val_preds_default)
505
+
506
+ print(f"\nPredicted data shape: {val_preds.shape}")
507
+ print(f"Ground truth data shape: {val_labels_eval.shape}")
508
+
509
+ # Comparison: Default vs Optimized Thresholds
510
+ print("\n" + "="*60)
511
+ print("COMPARISON: Default vs Optimized Thresholds")
512
+ print("="*60)
513
+
514
+ print("\nDefault Threshold (0.5):")
515
+ for i, label in enumerate(label_cols):
516
+ f1_default = f1_score(val_labels_eval[:, i], val_preds_default[:, i], zero_division=0)
517
+ print(f" {label}: F1 = {f1_default:.4f}")
518
+
519
+ print("\nOptimized Thresholds:")
520
+ for i, label in enumerate(label_cols):
521
+ f1_optimized = f1_score(val_labels_eval[:, i], val_preds[:, i], zero_division=0)
522
+ print(f" {label}: F1 = {f1_optimized:.4f} (threshold = {optimal_thresholds[i]:.2f})")
523
+ print("="*60 + "\n")
524
+
525
+ # Classification Report
526
+ print('\n' + '='*60)
527
+ print('CLASSIFICATION REPORT (VALIDATION)')
528
+ print('='*60)
529
+ print(classification_report(val_labels_eval, val_preds, target_names=label_cols))
530
+
531
+ # Hamming Loss
532
+ val_hamming_loss = hamming_loss(val_labels_eval, val_preds)
533
+ print("="*60)
534
+ print("HAMMING LOSS (Multi-label Error Rate)")
535
+ print("="*60)
536
+ print(f"Hamming Loss: {val_hamming_loss:.4f}")
537
+ print(f"(Fraction of incorrectly predicted labels: {val_hamming_loss:.2%})")
538
+
539
+ # Per-aspect metrics
540
+ print("\n" + "="*60)
541
+ print("PER-ASPECT METRICS (VALIDATION)")
542
+ print("="*60)
543
+
544
+ for i, aspect in enumerate(label_cols):
545
+ y_true = val_labels_eval[:, i]
546
+ y_pred = val_preds[:, i]
547
+
548
+ acc = accuracy_score(y_true, y_pred)
549
+ prec = precision_score(y_true, y_pred, zero_division=0)
550
+ rec = recall_score(y_true, y_pred, zero_division=0)
551
+ f1 = f1_score(y_true, y_pred, zero_division=0)
552
+
553
+ print(f"\n=== {aspect.upper()} ===")
554
+ print(f"Accuracy: {acc:.4f}")
555
+ print(f"Precision: {prec:.4f}")
556
+ print(f"Recall: {rec:.4f}")
557
+ print(f"F1 Score: {f1:.4f}")
558
+
559
+ tp = np.sum((y_true == 1) & (y_pred == 1))
560
+ tn = np.sum((y_true == 0) & (y_pred == 0))
561
+ fp = np.sum((y_true == 0) & (y_pred == 1))
562
+ fn = np.sum((y_true == 1) & (y_pred == 0))
563
+
564
+ print(f" TP: {tp}, TN: {tn}, FP: {fp}, FN: {fn}")
565
+
566
+ # Exact match accuracy
567
+ val_exact_matches = np.all(val_preds == val_labels_eval, axis=1)
568
+ val_exact_match_acc = np.mean(val_exact_matches)
569
+
570
+ print("\n" + "="*60)
571
+ print("EXACT MATCH (ALL ASPECTS)")
572
+ print("="*60)
573
+ print(f"Samples with ALL aspects correct: {np.sum(val_exact_matches)}/{len(val_exact_matches)}")
574
+ print(f"Exact Match Accuracy: {val_exact_match_acc:.4f}")
575
+
576
+ # Partial match accuracy (per sample)
577
+ partial_match_scores = []
578
+ for i in range(len(val_labels_eval)):
579
+ correct_labels = np.sum(val_preds[i] == val_labels_eval[i])
580
+ partial_match_scores.append(correct_labels / len(label_cols))
581
+
582
+ partial_match_scores = np.array(partial_match_scores)
583
+ avg_partial_match = np.mean(partial_match_scores)
584
+
585
+ print("\n" + "="*60)
586
+ print("PARTIAL MATCH (PER-SAMPLE LABEL ACCURACY)")
587
+ print("="*60)
588
+ print(f"Average Partial Match: {avg_partial_match:.4f} ({avg_partial_match:.2%})")
589
+ print(f"(Average fraction of labels correctly predicted per sample)")
590
+
591
+ # Sample predictions with match/mismatch
592
+ print("\n" + "="*60)
593
+ print("SAMPLE PREDICTIONS VS GROUND TRUTH (VALIDATION)")
594
+ print("="*60)
595
+
596
+ num_samples = min(10, len(val_X))
597
+ print(f"\nShowing {num_samples} validation samples:\n")
598
+
599
+ for idx in range(num_samples):
600
+ review = val_X[idx]
601
+ true_labels = [label_cols[i] for i, v in enumerate(val_labels_eval[idx]) if v == 1]
602
+ pred_labels = [label_cols[i] for i, v in enumerate(val_preds[idx]) if v == 1]
603
+
604
+ # Calculate partial match for this sample
605
+ # Count how many true labels were correctly predicted
606
+ matching_labels = len(set(true_labels) & set(pred_labels))
607
+ total_true_labels = len(true_labels) if len(true_labels) > 0 else 1
608
+ partial_match = matching_labels / total_true_labels
609
+
610
+ review_display = review[:150] + "..." if len(review) > 150 else review
611
+ print(f"Sample {idx + 1}:")
612
+ print(f"Review: {review_display}")
613
+ print(f"✓ True Labels: {true_labels if true_labels else ['None']}")
614
+ print(f"→ Predicted Labels: {pred_labels if pred_labels else ['None']}")
615
+ print(f"Match: {'✓ Exact' if set(true_labels) == set(pred_labels) else '✗ Mismatch'}")
616
+ print(f"Partial Match: {matching_labels}/{total_true_labels} labels correct ({partial_match:.2%})")
617
+ print("-" * 40)
618
+
619
+ # Final Evaluation on Test Set (WITH OPTIMIZED THRESHOLDS)
620
+ print("\n" + "="*60)
621
+ print("FINAL EVALUATION ON TEST SET (WITH OPTIMIZED THRESHOLDS)")
622
+ print("="*60)
623
+
624
+ all_preds, all_labels = predict_with_thresholds(model, test_loader, optimal_thresholds, device)
625
+
626
+ print(f"\nPredicted data shape: {all_preds.shape}")
627
+ print(f"Ground truth data shape: {all_labels.shape}")
628
+
629
+ # Classification Report
630
+ print('\n' + '='*60)
631
+ print('CLASSIFICATION REPORT')
632
+ print('='*60)
633
+ print(classification_report(all_labels, all_preds, target_names=label_cols))
634
+
635
+ # Hamming Loss
636
+ hamming_loss_value = hamming_loss(all_labels, all_preds)
637
+ print("="*60)
638
+ print("HAMMING LOSS (Multi-label Error Rate)")
639
+ print("="*60)
640
+ print(f"Hamming Loss: {hamming_loss_value:.4f}")
641
+ print(f"(Fraction of incorrectly predicted labels: {hamming_loss_value:.2%})")
642
+
643
+ # Per-aspect metrics
644
+ print("\n" + "="*60)
645
+ print("PER-ASPECT METRICS")
646
+ print("="*60)
647
+
648
+ for i, aspect in enumerate(label_cols):
649
+ y_true = all_labels[:, i]
650
+ y_pred = all_preds[:, i]
651
+
652
+ acc = accuracy_score(y_true, y_pred)
653
+ prec = precision_score(y_true, y_pred, zero_division=0)
654
+ rec = recall_score(y_true, y_pred, zero_division=0)
655
+ f1 = f1_score(y_true, y_pred, zero_division=0)
656
+
657
+ print(f"\n=== {aspect.upper()} ===")
658
+ print(f"Accuracy: {acc:.4f}")
659
+ print(f"Precision: {prec:.4f}")
660
+ print(f"Recall: {rec:.4f}")
661
+ print(f"F1 Score: {f1:.4f}")
662
+
663
+ tp = np.sum((y_true == 1) & (y_pred == 1))
664
+ tn = np.sum((y_true == 0) & (y_pred == 0))
665
+ fp = np.sum((y_true == 0) & (y_pred == 1))
666
+ fn = np.sum((y_true == 1) & (y_pred == 0))
667
+
668
+ print(f" TP: {tp}, TN: {tn}, FP: {fp}, FN: {fn}")
669
+
670
+ # Exact match accuracy
671
+ exact_matches = np.all(all_preds == all_labels, axis=1)
672
+ exact_match_acc = np.mean(exact_matches)
673
+
674
+ print("\n" + "="*60)
675
+ print("EXACT MATCH (ALL ASPECTS)")
676
+ print("="*60)
677
+ print(f"Samples with ALL aspects correct: {np.sum(exact_matches)}/{len(exact_matches)}")
678
+ print(f"Exact Match Accuracy: {exact_match_acc:.4f}")
679
+
680
+ # Partial match accuracy (per sample)
681
+ test_partial_match_scores = []
682
+ for i in range(len(all_labels)):
683
+ correct_labels = np.sum(all_preds[i] == all_labels[i])
684
+ test_partial_match_scores.append(correct_labels / len(label_cols))
685
+
686
+ test_partial_match_scores = np.array(test_partial_match_scores)
687
+ avg_test_partial_match = np.mean(test_partial_match_scores)
688
+
689
+ print("\n" + "="*60)
690
+ print("PARTIAL MATCH (PER-SAMPLE LABEL ACCURACY)")
691
+ print("="*60)
692
+ print(f"Average Partial Match: {avg_test_partial_match:.4f} ({avg_test_partial_match:.2%})")
693
+ print(f"(Average fraction of labels correctly predicted per sample)")
694
+
695
+ # Sample predictions
696
+ print("\n" + "="*60)
697
+ print("SAMPLE PREDICTIONS VS GROUND TRUTH")
698
+ print("="*60)
699
+
700
+ num_samples = min(10, len(test_X))
701
+ print(f"\nShowing {num_samples} test samples:\n")
702
+
703
+ for idx in range(num_samples):
704
+ review = test_X[idx]
705
+ true_labels = [label_cols[i] for i, v in enumerate(all_labels[idx]) if v == 1]
706
+ pred_labels = [label_cols[i] for i, v in enumerate(all_preds[idx]) if v == 1]
707
+
708
+ # Calculate partial match for this sample
709
+ # Count how many true labels were correctly predicted
710
+ matching_labels = len(set(true_labels) & set(pred_labels))
711
+ total_true_labels = len(true_labels) if len(true_labels) > 0 else 1
712
+ partial_match = matching_labels / total_true_labels
713
+
714
+ review_display = review[:150] + "..." if len(review) > 150 else review
715
+ print(f"Sample {idx + 1}:")
716
+ print(f"Review: {review_display}")
717
+ print(f"✓ True Labels: {true_labels if true_labels else ['None']}")
718
+ print(f"→ Predicted Labels: {pred_labels if pred_labels else ['None']}")
719
+ print(f"Match: {'✓ Exact' if set(true_labels) == set(pred_labels) else '✗ Mismatch'}")
720
+ print(f"Partial Match: {matching_labels}/{total_true_labels} labels correct ({partial_match:.2%})")
721
+ print("-" * 40)
722
+
723
+ # Save model interactively (optional)
724
+ # model_save_path = 'gemma_service_classifier.pth'
725
+ # torch.save({
726
+ # 'epoch': EPOCHS,
727
+ # 'model_state_dict': model.state_dict(),
728
+ # 'optimizer_state_dict': optimizer.state_dict(),
729
+ # 'train_loss': avg_train_loss,
730
+ # 'test_loss': avg_test_loss,
731
+ # }, model_save_path)
732
+ # print(f"Model saved to {model_save_path}")
733
+ model_save_path = os.path.join(SAVE_DIR, 'gemma_service_classifier.pth')
734
+ torch.save({
735
+ 'epoch': best_epoch if best_model_state is not None else EPOCHS,
736
+ 'model_state_dict': model.state_dict(),
737
+ 'optimizer_state_dict': optimizer.state_dict(),
738
+ 'train_loss': train_losses[best_epoch - 1] if best_model_state is not None else train_losses[-1] if train_losses else 0,
739
+ 'val_loss': best_val_loss if best_model_state is not None else (val_losses[-1] if val_losses else 0),
740
+ 'best_epoch': best_epoch,
741
+ 'best_val_loss': best_val_loss,
742
+ 'optimal_thresholds': optimal_thresholds,
743
+ }, model_save_path)
744
+ print(f"Model saved to {model_save_path}")
745
+
746
+ print("\n" + "="*60)
747
+ print("TRAINING COMPLETE")
748
+ print("="*60)