Diwanshuydv commited on
Commit
f5c51a4
·
verified ·
1 Parent(s): dc18e48

Upload 2 files

Browse files
Files changed (2) hide show
  1. Hugging_FaceA.py +276 -0
  2. best_resnet18_stl10.pth +3 -0
Hugging_FaceA.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import random
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from torchvision import models, transforms
9
+ from datasets import load_dataset
10
+ import wandb
11
+ from huggingface_hub import HfApi, hf_hub_download
12
+ from sklearn.metrics import confusion_matrix, classification_report
13
+ import matplotlib.pyplot as plt
14
+ import numpy as np
15
+ from PIL import Image
16
+
17
+ # 1. Custom Dataset implementation
18
class STL10SubsetDataset(Dataset):
    """Adapt a Hugging Face dataset split to the PyTorch Dataset protocol.

    Each item is returned as an ``(image, label)`` pair. Images are
    normalized to RGB mode and, when a transform is supplied, passed
    through it before being returned.
    """

    def __init__(self, hf_dataset, transform=None):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        img = record['image']
        target = record['label']

        # Some source images may be grayscale/palette; force 3-channel RGB.
        img = img if img.mode == 'RGB' else img.convert('RGB')

        if self.transform:
            img = self.transform(img)

        return img, target
39
+
40
def get_transforms():
    """Build the (train, eval) torchvision pipelines for ResNet-18 input.

    Both pipelines resize to 224x224 and normalize with ImageNet channel
    statistics; the training pipeline additionally applies a random
    horizontal flip for augmentation.
    """
    resize = transforms.Resize((224, 224))
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        resize,
        transforms.RandomHorizontalFlip(),
        to_tensor,
        normalize,
    ])

    val_transform = transforms.Compose([
        resize,
        to_tensor,
        normalize,
    ])

    return train_transform, val_transform
56
+
57
+ def train_one_epoch(model, loader, criterion, optimizer, device):
58
+ model.train()
59
+ running_loss = 0.0
60
+ correct = 0
61
+ total = 0
62
+
63
+ for inputs, labels in loader:
64
+ inputs, labels = inputs.to(device), labels.to(device)
65
+
66
+ optimizer.zero_grad()
67
+ outputs = model(inputs)
68
+ loss = criterion(outputs, labels)
69
+ loss.backward()
70
+ optimizer.step()
71
+
72
+ running_loss += loss.item() * inputs.size(0)
73
+ _, predicted = outputs.max(1)
74
+ total += labels.size(0)
75
+ correct += predicted.eq(labels).sum().item()
76
+
77
+ epoch_loss = running_loss / total
78
+ epoch_acc = correct / total
79
+ return epoch_loss, epoch_acc
80
+
81
+ def evaluate(model, loader, criterion, device):
82
+ model.eval()
83
+ running_loss = 0.0
84
+ correct = 0
85
+ total = 0
86
+ all_preds = []
87
+ all_labels = []
88
+
89
+ with torch.no_grad():
90
+ for inputs, labels in loader:
91
+ inputs, labels = inputs.to(device), labels.to(device)
92
+ outputs = model(inputs)
93
+ loss = criterion(outputs, labels)
94
+
95
+ running_loss += loss.item() * inputs.size(0)
96
+ _, predicted = outputs.max(1)
97
+ total += labels.size(0)
98
+ correct += predicted.eq(labels).sum().item()
99
+
100
+ all_preds.extend(predicted.cpu().numpy())
101
+ all_labels.extend(labels.cpu().numpy())
102
+
103
+ epoch_loss = running_loss / total
104
+ epoch_acc = correct / total
105
+ return epoch_loss, epoch_acc, all_preds, all_labels
106
+
107
def main():
    """End-to-end STL-10 fine-tuning pipeline.

    Loads the STL-10 subset from the Hugging Face Hub, fine-tunes an
    ImageNet-pretrained ResNet-18, logs metrics and plots to Weights &
    Biases, pushes the best checkpoint to the Hub, then re-downloads that
    checkpoint (falling back to the local copy) for final evaluation.
    """
    parser = argparse.ArgumentParser(description="STL-10 ResNet-18 Training Pipeline")
    parser.add_argument("--hf_repo_id", type=str, default="diwanshuydv/mlops_minor", help="Hugging Face model repo ID")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--epochs", type=int, default=5, help="Number of training epochs")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
    args = parser.parse_args()

    # Initialize weights and biases
    wandb.init(project="stl10-resnet18-assignment", config=vars(args))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 1 & 2. Load dataset and create DataLoaders
    print("Loading dataset...")
    # Using 'train' and 'test' splits if available. We will split train into train/val if needed,
    # or just use test as val for simplicity if it's a small subset.
    dataset = load_dataset("Chiranjeev007/STL-10_Subset")

    # Check what splits are available
    print("Available splits:", dataset.keys())

    # Assuming 'train' and 'test' exist. Let's create datasets.
    train_transform, val_transform = get_transforms()

    # Extract labels to know number of classes. STL-10 has 10 classes.
    num_classes = 10
    class_names = [f"Class_{i}" for i in range(num_classes)]  # Fallback names if not in dataset
    if 'train' in dataset and hasattr(dataset['train'].features['label'], 'names'):
        class_names = dataset['train'].features['label'].names

    # NOTE(review): val and test share the 'test' split, so the final test
    # accuracy is measured on data used for model selection — acceptable for
    # an assignment, but not an unbiased estimate.
    train_dataset = STL10SubsetDataset(dataset['train'], transform=train_transform)
    val_dataset = STL10SubsetDataset(dataset['test'], transform=val_transform)  # Using test as val during training
    test_dataset = STL10SubsetDataset(dataset['test'], transform=val_transform)  # Same for test

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)

    # 3. Load ResNet-18 and adapt for num_classes by replacing the final
    # fully-connected layer (ImageNet head has 1000 outputs).
    print("Initializing ResNet-18...")
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_classes)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # 4. Training Loop and WandB Logging
    best_val_acc = 0.0
    best_model_path = "best_resnet18_stl10.pth"

    print("Starting training...")
    for epoch in range(args.epochs):
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, _, _ = evaluate(model, val_loader, criterion, device)

        print(f"Epoch [{epoch+1}/{args.epochs}] Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

        wandb.log({
            "epoch": epoch + 1,
            "train/loss": train_loss,
            "train/accuracy": train_acc,
            "val/loss": val_loss,
            "val/accuracy": val_acc
        })

        # Checkpoint only when validation accuracy improves.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), best_model_path)
            print(f"--> Saved new best model with Val Acc: {best_val_acc:.4f}")

    # 5. Push best model to Hugging Face. Failure here is non-fatal: we can
    # still evaluate the locally saved checkpoint.
    print(f"Pushing model to Hugging Face Hub: {args.hf_repo_id}")
    try:
        api = HfApi()
        # Create repo if it doesn't exist
        api.create_repo(repo_id=args.hf_repo_id, exist_ok=True)
        api.upload_file(
            path_or_fileobj=best_model_path,
            path_in_repo="pytorch_model.bin",
            repo_id=args.hf_repo_id
        )
        print("Successfully pushed to HF.")
    except Exception as e:
        print(f"Failed to push to huggingface: {e}")
        print("Continuing with local evaluation...")

    # 6. Load model from Hugging Face for evaluation steps. Rebuild the
    # architecture from scratch so the round-trip exercises the published
    # artifact rather than the in-memory model.
    print("Downloading model from Hugging Face Hub for evaluation...")
    eval_model = models.resnet18(weights=None)
    eval_model.fc = nn.Linear(num_ftrs, num_classes)

    # NOTE(review): torch.load on a downloaded file deserializes pickled
    # data; for a plain state_dict, weights_only=True would be safer —
    # confirm the installed torch version supports it before adding.
    try:
        downloaded_model_path = hf_hub_download(repo_id=args.hf_repo_id, filename="pytorch_model.bin")
        eval_model.load_state_dict(torch.load(downloaded_model_path, map_location=device))
        print("Loaded model from HF Hub.")
    except Exception as e:
        print(f"Could not download from HF: {e}. Falling back to local best model.")
        eval_model.load_state_dict(torch.load(best_model_path, map_location=device))

    eval_model = eval_model.to(device)

    # Run evaluation on test set
    print("Running final evaluation on test set...")
    _, test_acc, test_preds, test_labels = evaluate(eval_model, test_loader, criterion, device)
    print(f"Test Accuracy: {test_acc:.4f}")

    # 7. Confusion Matrix
    print("Generating Confusion Matrix...")
    wandb.log({
        "confusion_matrix": wandb.plot.confusion_matrix(
            probs=None,
            y_true=test_labels,
            preds=test_preds,
            class_names=class_names
        )
    })

    # 8. Class-wise accuracy bar plot. Per-class accuracy is the diagonal of
    # the confusion matrix over each row's total (true count per class).
    # (Removed an unused classification_report() call that discarded its result.)
    print("Generating Class-wise accuracy plot...")
    cm = confusion_matrix(test_labels, test_preds)
    class_accuracies = cm.diagonal() / cm.sum(axis=1)

    data = [[class_names[i], acc] for i, acc in enumerate(class_accuracies)]
    table = wandb.Table(data=data, columns=["Class", "Accuracy"])
    wandb.log({"class_accuracy": wandb.plot.bar(table, "Class", "Accuracy", title="Class-wise Accuracy")})

    # 9. Log 20 examples with image, predicted, and actual
    print("Logging 20 examples to WandB...")
    # We need the raw images, not normalized tensors natively, so let's get them from dataset
    indices = random.sample(range(len(dataset['test'])), min(20, len(dataset['test'])))

    example_data = []

    eval_model.eval()
    with torch.no_grad():
        for idx in indices:
            item = dataset['test'][idx]
            raw_image = item['image']
            if raw_image.mode != 'RGB':
                raw_image = raw_image.convert('RGB')
            actual_label_idx = item['label']
            actual_label_str = class_names[actual_label_idx]

            # transform for model
            tensor_img = val_transform(raw_image).unsqueeze(0).to(device)
            out = eval_model(tensor_img)
            _, pred_idx = out.max(1)
            pred_idx = pred_idx.item()
            pred_label_str = class_names[pred_idx]

            example_data.append([
                wandb.Image(raw_image),
                pred_label_str,
                actual_label_str
            ])

    examples_table = wandb.Table(data=example_data, columns=["Image", "Predicted", "Actual"])
    wandb.log({"test_examples": examples_table})

    print("Done!")
    wandb.finish()

if __name__ == "__main__":
    main()
best_resnet18_stl10.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ed84664e12672d3a22f61d272b608e332c2be08ed4c95090162af7890af2743
3
+ size 44807307