Spaces:
Sleeping
Sleeping
Upload 5 files
Browse files- README.md +210 -14
- best_model.pth +3 -0
- config.json +14 -0
- main.py +193 -0
- requirements.txt +27 -0
README.md
CHANGED
|
@@ -1,14 +1,210 @@
|
|
| 1 |
-
---
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
tags:
|
| 4 |
+
- image-classification
|
| 5 |
+
- pytorch
|
| 6 |
+
- resnet
|
| 7 |
+
- lora
|
| 8 |
+
- computer-vision
|
| 9 |
+
- smoking-detection
|
| 10 |
+
datasets:
|
| 11 |
+
- sujaykapadnis/smoking
|
| 12 |
+
metrics:
|
| 13 |
+
- accuracy
|
| 14 |
+
- f1
|
| 15 |
+
library_name: pytorch
|
| 16 |
+
pipeline_tag: image-classification
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
# Smoker Detection with LoRA Fine-Tuning
|
| 20 |
+
|
| 21 |
+
Fine-tuned ResNet34 model using LoRA (Low-Rank Adaptation) for binary smoking detection in images.
|
| 22 |
+
|
| 23 |
+
## Model Description
|
| 24 |
+
|
| 25 |
+
This model uses parameter-efficient fine-tuning with LoRA on a pretrained ResNet34 to classify images as "Smoker" or "Non-Smoker". By training only 2.14% of parameters, it achieves 89.73% test accuracy while preserving ImageNet knowledge.
|
| 26 |
+
|
| 27 |
+
- **Model Type:** ResNet34 + LoRA adapters
|
| 28 |
+
- **Task:** Binary Image Classification
|
| 29 |
+
- **Framework:** PyTorch
|
| 30 |
+
- **License:** MIT
|
| 31 |
+
|
| 32 |
+
## Performance
|
| 33 |
+
|
| 34 |
+
| Split | Accuracy | F1-Score (Smoking) |
|
| 35 |
+
|-------|----------|-------------------|
|
| 36 |
+
| Validation | 94.44% | - |
|
| 37 |
+
| Test | 89.73% | 89.96% |
|
| 38 |
+
|
| 39 |
+
**Efficiency:**
|
| 40 |
+
- Trainable parameters: 465K (2.14% of model)
|
| 41 |
+
- Training time: ~15 minutes on Kaggle T4 GPU
|
| 42 |
+
|
| 43 |
+
## Usage
|
| 44 |
+
|
| 45 |
+
### Installation
|
| 46 |
+
```bash
|
| 47 |
+
pip install torch torchvision pillow
|
| 48 |
+
```

### Load Model

```python
|
| 49 |
+
import torch
|
| 50 |
+
import torch.nn as nn
|
| 51 |
+
from torchvision import models
|
| 52 |
+
from torchvision.models import ResNet34_Weights
|
| 53 |
+
from PIL import Image
|
| 54 |
+
import torchvision.transforms as transforms
|
| 55 |
+
|
| 56 |
+
# Define LoRA Layer
|
| 57 |
+
class LoRALayer(nn.Module):
    """Conv2d wrapper that adds a trainable low-rank (LoRA) update.

    The wrapped convolution is frozen; only the two low-rank factors
    ``lora_A`` (rank x C_in x kH x kW) and ``lora_B`` (C_out x rank x 1 x 1)
    are trainable.  ``lora_B`` is zero-initialised, so at step 0 the wrapper
    reproduces the original layer's output exactly.
    """

    def __init__(self, original_layer, rank=8):
        super().__init__()
        # Keep these attribute names: checkpoint state-dict keys depend on them.
        self.original_layer = original_layer
        self.rank = rank

        c_out = original_layer.out_channels
        c_in = original_layer.in_channels
        kernel = original_layer.kernel_size

        # Down-projection A: small random init keeps the initial update tiny.
        self.lora_A = nn.Parameter(torch.randn(rank, c_in, *kernel) * 0.01)
        # Up-projection B: zero init => no change to the output at step 0.
        self.lora_B = nn.Parameter(torch.zeros(c_out, rank, 1, 1))

        # Freeze the pretrained convolution (weights and optional bias).
        self.original_layer.weight.requires_grad = False
        if self.original_layer.bias is not None:
            self.original_layer.bias.requires_grad = False

    def forward(self, x):
        base = self.original_layer(x)
        # A mirrors the wrapped conv's stride/padding so spatial dims match;
        # B is a 1x1 conv that mixes the rank channels back up to C_out.
        delta = nn.functional.conv2d(
            x,
            self.lora_A,
            stride=self.original_layer.stride,
            padding=self.original_layer.padding,
        )
        delta = nn.functional.conv2d(delta, self.lora_B)
        return base + delta
|
| 87 |
+
|
| 88 |
+
def apply_lora_to_model(model, rank=8):
    """Freeze a ResNet and attach LoRA adapters to layer3/layer4 convs.

    Every parameter is frozen except the classification head ``model.fc``
    and the LoRA factors injected into each residual block's conv1/conv2.
    The model is modified in place and also returned for convenience.
    """
    # Freeze the entire backbone first.
    for param in model.parameters():
        param.requires_grad = False

    # The classification head stays fully trainable.
    for param in model.fc.parameters():
        param.requires_grad = True

    # Wrap conv1/conv2 of every block in the two deepest stages.
    for stage in (model.layer3, model.layer4):
        for block in stage:
            for conv_name in ('conv1', 'conv2'):
                if hasattr(block, conv_name):
                    setattr(
                        block,
                        conv_name,
                        LoRALayer(getattr(block, conv_name), rank=rank),
                    )

    return model
|
| 108 |
+
|
| 109 |
+
# --- Rebuild the architecture exactly as it was trained ----------------
model = models.resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, 2)  # binary head: Non-Smoker / Smoker
model = apply_lora_to_model(model, rank=8)

# --- Restore the fine-tuned weights and switch to inference mode -------
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
model.eval()

# --- Preprocessing: must match training (224x224, ImageNet statistics) -
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


def predict(image_path):
    """Classify one image file; return (label, confidence in percent)."""
    image = Image.open(image_path).convert('RGB')
    batch = transform(image).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)
        confidence, predicted = torch.max(probs, 1)

    classes = ['Non-Smoker', 'Smoker']
    return classes[predicted.item()], confidence.item() * 100


# Example
prediction, confidence = predict('image.jpg')
print(f"{prediction} ({confidence:.1f}% confidence)")
|
| 144 |
+
## Training Details
|
| 145 |
+
Dataset: 1,120 images from Kaggle Smoking Detection Dataset
|
| 146 |
+
|
| 147 |
+
Training: 716 images (64%)
|
| 148 |
+
Validation: 180 images (16%)
|
| 149 |
+
Test: 224 images (20%)
|
| 150 |
+
|
| 151 |
+
Hyperparameters:
|
| 152 |
+
|
| 153 |
+
Learning Rate: 1e-4
|
| 154 |
+
Optimizer: AdamW (weight decay: 1e-4)
|
| 155 |
+
Batch Size: 32
|
| 156 |
+
Epochs: 15
|
| 157 |
+
LoRA Rank: 8
|
| 158 |
+
|
| 159 |
+
Data Augmentation:
|
| 160 |
+
|
| 161 |
+
Random horizontal flip (p=0.5)
|
| 162 |
+
Random rotation (Β±10Β°)
|
| 163 |
+
Color jitter (brightness, contrast, saturation)
|
| 164 |
+
|
| 165 |
+
## What is LoRA?
|
| 166 |
+
LoRA (Low-Rank Adaptation) adds small trainable matrices to frozen pretrained weights:
|
| 167 |
+
Output = W_frozen Γ input + (B Γ A) Γ input
|
| 168 |
+
Where A and B are low-rank matrices (rank=8), adding only 2.14% trainable parameters while maintaining model capacity.
|
| 169 |
+
Benefits:
|
| 170 |
+
|
| 171 |
+
Prevents overfitting on small datasets
|
| 172 |
+
Preserves pretrained ImageNet features
|
| 173 |
+
Faster training and lower memory usage
|
| 174 |
+
Easier deployment (smaller checkpoint files)
|
| 175 |
+
|
| 176 |
+
## Model Architecture
|
| 177 |
+
ResNet34 (21.7M parameters)
|
| 178 |
+
βββ Frozen Layers (21.3M - 97.86%)
|
| 179 |
+
β βββ conv1, layer1, layer2
|
| 180 |
+
β βββ Pretrained ImageNet weights
|
| 181 |
+
βββ Trainable Layers (465K - 2.14%)
|
| 182 |
+
βββ LoRA adapters on layer3 (6 blocks)
|
| 183 |
+
βββ LoRA adapters on layer4 (3 blocks)
|
| 184 |
+
βββ Classification head fc (512 β 2)
|
| 185 |
+
## Limitations
|
| 186 |
+
|
| 187 |
+
Trained on limited dataset (1,120 images)
|
| 188 |
+
Low resolution images (250Γ250)
|
| 189 |
+
May not generalize to all smoking scenarios
|
| 190 |
+
Best for frontal/profile views with visible cigarettes
|
| 191 |
+
|
| 192 |
+
Citation
|
| 193 |
+
```bibtex
@misc{smoker-detection-lora,
|
| 194 |
+
author = {Noel Triguero},
|
| 195 |
+
title = {Smoker Detection with LoRA Fine-Tuning},
|
| 196 |
+
year = {2025},
|
| 197 |
+
publisher = {Hugging Face},
|
| 198 |
+
howpublished = {\url{https://huggingface.co/notrito/smoker-detection}}
|
| 199 |
+
}
|
| 200 |
+
## References
|
| 201 |
+
|
| 202 |
+
LoRA Paper - Hu et al., 2021
|
| 203 |
+
Dataset - Sujay Kapadnis
|
| 204 |
+
Training Notebook
|
| 205 |
+
|
| 206 |
+
## Contact
|
| 207 |
+
|
| 208 |
+
Author: Noel Triguero
|
| 209 |
+
Email: noel.triguero@gmail.com
|
| 210 |
+
Kaggle: notrito
|
best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6b1c77abdd212805b39faf1ec66bdb382a1c4bdf141ccf7d621cf1e811baba54
|
| 3 |
+
size 87148476
|
config.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "resnet34-lora",
|
| 3 |
+
"architecture": "ResNet34 with LoRA adapters",
|
| 4 |
+
"task": "image-classification",
|
| 5 |
+
"num_classes": 2,
|
| 6 |
+
"class_names": ["Non-Smoker", "Smoker"],
|
| 7 |
+
"lora_config": {
|
| 8 |
+
"rank": 8,
|
| 9 |
+
"target_layers": ["layer3", "layer4"]
|
| 10 |
+
},
|
| 11 |
+
"input_size": [224, 224],
|
| 12 |
+
"pretrained_weights": "ImageNet",
|
| 13 |
+
"framework": "PyTorch"
|
| 14 |
+
}
|
main.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Main training script for Smoker Detection with LoRA.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python train.py --data_path /path/to/data --epochs 15 --lr 1e-4 --rank 8
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from src.model import get_model, apply_lora_to_model, count_parameters
|
| 13 |
+
from src.dataset import create_dataloaders
|
| 14 |
+
from src.train import train_model, get_optimizer_and_criterion
|
| 15 |
+
from src.evaluate import (
|
| 16 |
+
evaluate_model,
|
| 17 |
+
print_classification_report,
|
| 18 |
+
plot_confusion_matrix,
|
| 19 |
+
plot_training_history
|
| 20 |
+
)
|
| 21 |
+
from src.utils import set_seed, get_device, create_directories, print_dataset_info
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def parse_args():
    """Build and parse the command-line interface for training."""
    p = argparse.ArgumentParser(description='Train Smoker Detection Model with LoRA')

    # Data arguments
    p.add_argument('--data_path', type=str, default='/kaggle/input/smoking',
                   help='Path to dataset root directory')

    # Model arguments
    p.add_argument('--rank', type=int, default=8,
                   help='LoRA rank (default: 8)')
    p.add_argument('--target_layers', nargs='+', default=['layer3', 'layer4'],
                   help='Layers to apply LoRA to (default: layer3 layer4)')

    # Training arguments
    p.add_argument('--epochs', type=int, default=15,
                   help='Number of training epochs (default: 15)')
    p.add_argument('--batch_size', type=int, default=32,
                   help='Batch size (default: 32)')
    p.add_argument('--lr', type=float, default=1e-4,
                   help='Learning rate (default: 1e-4)')
    p.add_argument('--weight_decay', type=float, default=1e-4,
                   help='Weight decay (default: 1e-4)')
    p.add_argument('--img_size', type=int, default=224,
                   help='Image size (default: 224)')
    p.add_argument('--num_workers', type=int, default=2,
                   help='Number of data loading workers (default: 2)')

    # Output arguments
    p.add_argument('--output_dir', type=str, default='results',
                   help='Directory to save outputs (default: results)')
    p.add_argument('--model_save_path', type=str, default='best_model.pth',
                   help='Path to save best model (default: best_model.pth)')

    # Other arguments
    p.add_argument('--seed', type=int, default=42,
                   help='Random seed (default: 42)')
    p.add_argument('--no_cuda', action='store_true',
                   help='Disable CUDA even if available')

    return p.parse_args()
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def main():
    """End-to-end training entry point.

    Loads the dataset, builds a LoRA-adapted ResNet34, trains it, then
    evaluates the best checkpoint on the held-out test split and saves
    plots and the model under the configured output paths.

    NOTE(review): the emoji in the banner strings were mangled by a prior
    encoding round-trip; the glyphs below are the best-effort originals —
    confirm against the deployed script.
    """
    args = parse_args()

    banner = "=" * 60

    # Setup
    print("\n" + banner)
    print("🚀 Smoker Detection Training with LoRA")
    print(banner + "\n")

    # Reproducibility
    set_seed(args.seed)

    # Output directory
    create_directories([args.output_dir])

    # Device selection (explicit CPU override wins over auto-detection)
    device = get_device()
    if args.no_cuda:
        device = torch.device('cpu')
        print("CUDA disabled by user, using CPU")

    # Dataset layout: <root>/Training/Training, <root>/Validation/Validation, ...
    root = Path(args.data_path)
    train_path = root / 'Training' / 'Training'
    val_path = root / 'Validation' / 'Validation'
    test_path = root / 'Testing' / 'Testing'

    print("\n📦 Loading data...")
    train_loader, val_loader, test_loader = create_dataloaders(
        train_path=train_path,
        val_path=val_path,
        test_path=test_path,
        batch_size=args.batch_size,
        img_size=args.img_size,
        num_workers=args.num_workers,
    )
    print_dataset_info(train_loader, val_loader, test_loader)

    # Model: pretrained backbone with a 2-class head, then LoRA adapters
    print("\n🏗️ Building model...")
    model = get_model(num_classes=2, pretrained=True).to(device)

    print(f"\n🔧 Applying LoRA (rank={args.rank})...")
    num_lora_layers = apply_lora_to_model(
        model,
        target_layers=args.target_layers,
        rank=args.rank,
    )
    print(f"✅ LoRA applied to {num_lora_layers} convolutional layers")

    total_params, trainable_params, trainable_pct = count_parameters(model)
    print(f"\n📊 Parameter Count:")
    print(f"   Total: {total_params:,}")
    print(f"   Trainable: {trainable_params:,} ({trainable_pct:.2f}%)")
    print(f"   Frozen: {total_params - trainable_params:,} ({100 - trainable_pct:.2f}%)")

    print("\n⚙️ Setting up training...")
    optimizer, criterion = get_optimizer_and_criterion(
        model,
        lr=args.lr,
        weight_decay=args.weight_decay,
    )

    print("\n" + banner)
    history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        num_epochs=args.epochs,
        save_path=args.model_save_path,
    )

    print("\n📈 Plotting training history...")
    plot_training_history(
        history,
        save_path=f'{args.output_dir}/training_curves.png',
    )

    print("\n" + banner)
    print("🧪 Testing on held-out test set...")
    print(banner)

    # Evaluate the best (not necessarily the last) checkpoint.
    model.load_state_dict(torch.load(args.model_save_path))
    predictions, labels, test_acc = evaluate_model(model, test_loader, device)
    print_classification_report(predictions, labels)

    print("\n📊 Plotting confusion matrix...")
    plot_confusion_matrix(
        predictions,
        labels,
        save_path=f'{args.output_dir}/confusion_matrix.png',
    )

    # Final summary
    print("\n" + banner)
    print("✅ Training Complete!")
    print(banner)
    print(f"\n📁 Outputs saved to: {args.output_dir}/")
    print(f"   - Training curves: {args.output_dir}/training_curves.png")
    print(f"   - Confusion matrix: {args.output_dir}/confusion_matrix.png")
    print(f"   - Best model: {args.model_save_path}")
    print(f"\n🎯 Final Test Accuracy: {test_acc:.2f}%\n")


if __name__ == '__main__':
    main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Deep Learning
|
| 2 |
+
torch>=2.0.0
|
| 3 |
+
torchvision>=0.15.0
|
| 4 |
+
|
| 5 |
+
# Data Processing
|
| 6 |
+
numpy>=1.24.0
|
| 7 |
+
pandas>=2.0.0
|
| 8 |
+
Pillow>=9.5.0
|
| 9 |
+
|
| 10 |
+
# Visualization
|
| 11 |
+
matplotlib>=3.7.0
|
| 12 |
+
seaborn>=0.12.0
|
| 13 |
+
|
| 14 |
+
# Metrics
|
| 15 |
+
scikit-learn>=1.3.0
|
| 16 |
+
|
| 17 |
+
# Progress bars
|
| 18 |
+
tqdm>=4.65.0
|
| 19 |
+
|
| 20 |
+
# Jupyter (optional, for notebooks)
|
| 21 |
+
jupyter>=1.0.0
|
| 22 |
+
ipywidgets>=8.0.0
|
| 23 |
+
|
| 24 |
+
# Configuration (optional)
|
| 25 |
+
pyyaml>=6.0
|
| 26 |
+
|
| 27 |
+
huggingface_hub
|