# Generated from FineTuning.ipynb on 2025-11-04T05:32:04.267416Z
# NOTE: This is a direct export of code cells only. No code has been modified.
"""
Vision Transformer (ViT) Fine-tuning for Food-101 Classification
=================================================================
This notebook demonstrates:
1. Loading Food-101 dataset directly from Hugging Face
2. Data preprocessing and augmentation
3. Fine-tuning ViT (vit-base-patch16-224) using Hugging Face Transformers
4. Evaluation with metrics and visualizations
5. Gradio demo for inference
Author: AI Engineer
Model: vit-base-patch16-224
Dataset: Food-101 (101 food categories, 101,000 images)
"""
# ============================================================================
# CELL 1: Install Required Packages
# ============================================================================
!pip install -q transformers datasets accelerate evaluate torch torchvision pillow gradio scikit-learn matplotlib seaborn
print("✓ All packages installed successfully")
# ============================================================================
# CELL 2: Import Libraries
# ============================================================================
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import transforms
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from tqdm.auto import tqdm
import gradio as gr
import warnings
warnings.filterwarnings('ignore')
# Hugging Face imports
from transformers import (
ViTImageProcessor,
ViTForImageClassification,
TrainingArguments,
Trainer,
)
from datasets import load_dataset
import evaluate
# Set random seeds for reproducibility
import random
def set_seed(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
set_seed(42)
print("✓ All libraries imported successfully")
print(f"✓ PyTorch version: {torch.__version__}")
print(f"✓ CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"✓ GPU: {torch.cuda.get_device_name(0)}")
print(f"✓ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
# ===== Cell Separator =====
# ============================================================================
# CELL 3: Load Food-101 Dataset from Hugging Face
# ============================================================================
"""
Load Food-101 dataset directly from Hugging Face Datasets
- 101 food categories
- 101,000 images total
- 750 training images per class
- 250 test images per class
"""
print("📥 Loading Food-101 dataset from Hugging Face...")
print("This may take a few minutes on first run...")
# Load the dataset
dataset = load_dataset("food101", split=["train", "validation"])
train_dataset = dataset[0]
val_dataset = dataset[1]
print("\n✓ Dataset loaded successfully!")
print(f" Training samples: {len(train_dataset)}")
print(f" Validation samples: {len(val_dataset)}")
# Get class names
classes = train_dataset.features['label'].names
num_classes = len(classes)
print(f" Number of classes: {num_classes}")
print(f"\n Sample classes: {classes[:10]}")
# Create label mappings
id2label = {idx: cls for idx, cls in enumerate(classes)}
label2id = {cls: idx for idx, cls in enumerate(classes)}
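# Sanity check (a minimal sketch): the two mappings should be exact inverses
assert all(label2id[id2label[i]] == i for i in range(num_classes))
print(f" Example mapping: 0 -> '{id2label[0]}'")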
# ===== Cell Separator =====
# ============================================================================
# CELL 4: Explore and Visualize Dataset
# ============================================================================
"""
Visualize sample images and analyze class distribution
"""
# Analyze class distribution
from collections import Counter
train_labels = train_dataset['label']
val_labels = val_dataset['label']
train_dist = Counter(train_labels)
val_dist = Counter(val_labels)
print("\n📊 Dataset Distribution:")
print(f" Images per class (train): {len(train_labels) / num_classes:.0f}")
print(f" Images per class (val): {len(val_labels) / num_classes:.0f}")
# Visualize class distribution
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
# Training distribution
axes[0].bar(range(num_classes), [train_dist[i] for i in range(num_classes)], alpha=0.7)
axes[0].set_xlabel('Class Index')
axes[0].set_ylabel('Number of Images')
axes[0].set_title('Training Set Distribution')
axes[0].grid(alpha=0.3)
# Validation distribution
axes[1].bar(range(num_classes), [val_dist[i] for i in range(num_classes)], alpha=0.7, color='orange')
axes[1].set_xlabel('Class Index')
axes[1].set_ylabel('Number of Images')
axes[1].set_title('Validation Set Distribution')
axes[1].grid(alpha=0.3)
plt.tight_layout()
plt.show()
# Display sample images
def show_sample_images(dataset, classes, n_samples=10):
"""Display random sample images from dataset"""
fig, axes = plt.subplots(2, 5, figsize=(18, 8))
axes = axes.ravel()
# Get random indices
indices = random.sample(range(len(dataset)), n_samples)
for idx, sample_idx in enumerate(indices):
sample = dataset[sample_idx]
image = sample['image']
label = sample['label']
class_name = classes[label]
axes[idx].imshow(image)
axes[idx].axis('off')
axes[idx].set_title(f"{class_name}", fontsize=10, wrap=True)
plt.suptitle('Sample Images from Food-101 Dataset', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()
print("\n🖼️ Displaying sample images...")
show_sample_images(train_dataset, classes)
# ===== Cell Separator =====
# ============================================================================
# CELL 5: Initialize ViT Model and Image Processor
# ============================================================================
"""
Load pre-trained ViT model and image processor from Hugging Face
Model: google/vit-base-patch16-224
- Base model: 86M parameters
- Patch size: 16x16
- Input size: 224x224
"""
MODEL_NAME = "google/vit-base-patch16-224"
print(f"\n🤖 Loading ViT model: {MODEL_NAME}")
# Load image processor (handles image preprocessing automatically)
image_processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
# Load model with custom number of labels for Food-101
model = ViTForImageClassification.from_pretrained(
MODEL_NAME,
num_labels=num_classes,
id2label=id2label,
label2id=label2id,
ignore_mismatched_sizes=True # Important: allows loading with different head
)
# Model statistics
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\n✓ Model loaded successfully")
print(f" Total parameters: {total_params / 1e6:.1f}M")
print(f" Trainable parameters: {trainable_params / 1e6:.1f}M")
print(f" Image size: {image_processor.size}")
print(f" Number of output classes: {num_classes}")
# ===== Cell Separator =====
# ============================================================================
# CELL 6: Data Preprocessing and Augmentation
# ============================================================================
"""
Define data preprocessing and augmentation pipelines:
- Training: Heavy augmentation (RandomResizedCrop, Flip, ColorJitter, etc.)
- Validation: Simple resize and normalization
"""
def transform_train(examples):
"""
Apply training augmentations and preprocessing
"""
# Data augmentation transforms
augmentation = transforms.Compose([
transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(degrees=15),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
])
# Apply augmentation to each image
images = [augmentation(img.convert("RGB")) for img in examples['image']]
# Apply ViT preprocessing
inputs = image_processor(images, return_tensors='pt')
inputs['labels'] = examples['label']
return inputs
def transform_val(examples):
"""
Apply validation preprocessing (no augmentation)
"""
# Simple resize for validation
images = [img.convert("RGB") for img in examples['image']]
# Apply ViT preprocessing
inputs = image_processor(images, return_tensors='pt')
inputs['labels'] = examples['label']
return inputs
print("🔄 Applying preprocessing transforms...")
# Apply transforms to datasets
train_dataset = train_dataset.with_transform(transform_train)
val_dataset = val_dataset.with_transform(transform_val)
print("✓ Transforms applied successfully")
print(f" Training augmentations: RandomResizedCrop, HorizontalFlip, Rotation, ColorJitter")
print(f" Validation preprocessing: Resize + Normalize")
# ===== Cell Separator =====
# ============================================================================
# CELL 7: Data Collator
# ============================================================================
"""
Create custom data collator to properly batch the preprocessed data
"""
def collate_fn(examples):
"""
Custom collator to handle batching of preprocessed images
"""
pixel_values = torch.stack([example["pixel_values"] for example in examples])
labels = torch.tensor([example["labels"] for example in examples])
return {"pixel_values": pixel_values, "labels": labels}
print("✓ Data collator configured")
# ===== Cell Separator =====
# ============================================================================
# CELL 8: Define Training Configuration
# ============================================================================
"""
Configure training parameters with best practices:
- Mixed precision (FP16) for memory efficiency
- Gradient accumulation for larger effective batch size
- Learning rate scheduling with warmup
- Early stopping based on accuracy
- Checkpointing best models only
"""
# Output directory for model checkpoints
OUTPUT_DIR = "./vit-food101-finetuned"
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Training arguments optimized for Food-101
training_args = TrainingArguments(
output_dir=OUTPUT_DIR,
# Training hyperparameters
num_train_epochs=5, # 5 epochs typically sufficient
per_device_train_batch_size=32, # Adjust based on GPU memory (16 for 12GB, 32 for 24GB)
per_device_eval_batch_size=64, # Larger batch for eval (no gradients)
gradient_accumulation_steps=2, # Effective batch = 32 * 2 = 64
# Optimization
learning_rate=3e-4, # Relatively high for full fine-tuning; tempered by the warmup and cosine schedule below
weight_decay=0.01, # L2 regularization
warmup_ratio=0.1, # 10% warmup steps
lr_scheduler_type="cosine", # Cosine decay schedule
# Mixed precision training (crucial for memory efficiency)
fp16=torch.cuda.is_available(), # Enable on GPU
fp16_full_eval=True, # Also use FP16 for evaluation
# Evaluation and logging
eval_strategy="epoch", # Evaluate after each epoch
save_strategy="epoch", # Save after each epoch
logging_steps=100, # Log every 100 steps
logging_dir=f"{OUTPUT_DIR}/logs",
# Model checkpointing
load_best_model_at_end=True, # Load best model at end
metric_for_best_model="accuracy", # Use accuracy for model selection
greater_is_better=True,
save_total_limit=2, # Keep only best 2 checkpoints
# Performance optimizations
dataloader_num_workers=4, # Parallel data loading
dataloader_pin_memory=True, # Pin memory for faster GPU transfer
remove_unused_columns=False, # Keep all columns
# Reproducibility
seed=42,
data_seed=42,
# Optional: Push to Hugging Face Hub
push_to_hub=False,
# Report to tensorboard
report_to=["tensorboard"],
)
# Print configuration summary
print("\n⚙️ Training Configuration:")
print("=" * 70)
print(f" Epochs: {training_args.num_train_epochs}")
print(f" Batch size (train): {training_args.per_device_train_batch_size}")
print(f" Batch size (eval): {training_args.per_device_eval_batch_size}")
print(f" Gradient accumulation: {training_args.gradient_accumulation_steps}")
print(f" Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f" Learning rate: {training_args.learning_rate}")
print(f" Weight decay: {training_args.weight_decay}")
print(f" Warmup ratio: {training_args.warmup_ratio}")
print(f" LR scheduler: {training_args.lr_scheduler_type}")
print(f" FP16 training: {training_args.fp16}")
print(f" Total optimization steps: {(len(train_dataset) // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps)) * training_args.num_train_epochs}")
print("=" * 70)
# ===== Cell Separator =====
# ============================================================================
# CELL 9: Define Evaluation Metrics
# ============================================================================
"""
Setup comprehensive evaluation metrics:
- Accuracy (primary metric)
- Precision, Recall, F1 (computed separately)
"""
# Load accuracy metric from Hugging Face
accuracy_metric = evaluate.load("accuracy")
def compute_metrics(eval_pred):
"""
Compute metrics during evaluation
Args:
eval_pred: EvalPrediction object containing predictions and labels
Returns:
dict: Dictionary of computed metrics
"""
predictions, labels = eval_pred
predictions = np.argmax(predictions, axis=1)
# Compute accuracy
accuracy = accuracy_metric.compute(predictions=predictions, references=labels)
return accuracy
print("✓ Evaluation metrics configured")
print(" Primary metric: Accuracy")
# ============================================================================
# CELL 10: Initialize Trainer and Start Training
# ============================================================================
"""
Create Trainer object and start fine-tuning:
- Handles training loop automatically
- Performs validation after each epoch
- Saves best model based on accuracy
- Uses mixed precision for efficiency
"""
# Initialize Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
data_collator=collate_fn,
compute_metrics=compute_metrics,
tokenizer=image_processor, # Save image processor with model (newer Transformers versions call this argument processing_class)
)
print("\n🚀 Starting training...")
print("=" * 70)
print("Training Progress:")
print("=" * 70)
# Train the model
train_result = trainer.train()
print("\n" + "=" * 70)
print("✓ Training completed successfully!")
print("=" * 70)
print(f" Training runtime: {train_result.metrics['train_runtime']:.2f} seconds ({train_result.metrics['train_runtime']/60:.2f} minutes)")
print(f" Training samples/second: {train_result.metrics['train_samples_per_second']:.2f}")
print(f" Training loss: {train_result.metrics['train_loss']:.4f}")
# Save the final model
trainer.save_model(OUTPUT_DIR)
image_processor.save_pretrained(OUTPUT_DIR)
print(f"\n✓ Model and processor saved to: {OUTPUT_DIR}")
# ===== Cell Separator =====
# ============================================================================
# CELL 11: Comprehensive Model Evaluation
# ============================================================================
"""
Evaluate model on validation set:
- Calculate accuracy
- Generate predictions for detailed analysis
- Prepare for confusion matrix and classification report
"""
print("\n📊 Evaluating model on validation set...")
print("=" * 70)
# Evaluate on validation set
eval_results = trainer.evaluate()
print("\nVALIDATION RESULTS:")
print("=" * 70)
for key, value in eval_results.items():
if isinstance(value, float):
print(f" {key}: {value:.4f}")
else:
print(f" {key}: {value}")
print("=" * 70)
# Get detailed predictions
print("\n🔮 Generating predictions for detailed analysis...")
predictions_output = trainer.predict(val_dataset)
pred_logits = predictions_output.predictions
true_labels = predictions_output.label_ids
pred_labels = np.argmax(pred_logits, axis=1)
# Calculate final accuracy
final_accuracy = accuracy_score(true_labels, pred_labels)
print(f"\n✓ Final Validation Accuracy: {final_accuracy * 100:.2f}%")
# Calculate confidence scores with a numerically stable softmax
# (subtract each row's max logit before exponentiating to avoid overflow)
shifted_logits = pred_logits - pred_logits.max(axis=1, keepdims=True)
pred_probs = np.exp(shifted_logits) / np.exp(shifted_logits).sum(axis=1, keepdims=True)
pred_confidences = pred_probs.max(axis=1)
avg_confidence = pred_confidences.mean()
print(f"✓ Average Prediction Confidence: {avg_confidence * 100:.2f}%")
# ===== Cell Separator =====
# ============================================================================
# CELL 12: Generate and Visualize Confusion Matrix
# ============================================================================
"""
Create confusion matrix to understand model performance across all 101 classes
"""
print("\n📊 Generating confusion matrix...")
# Calculate confusion matrix
cm = confusion_matrix(true_labels, pred_labels)
# Plot confusion matrix (heatmap without annotations due to 101 classes)
plt.figure(figsize=(22, 20))
sns.heatmap(
cm,
annot=False, # Too many classes for annotations
fmt='d',
cmap='Blues',
square=True,
cbar_kws={'label': 'Number of Predictions', 'shrink': 0.8},
linewidths=0
)
plt.title('Confusion Matrix - Food-101 Classification\n(101 Classes)', fontsize=16, pad=20)
plt.ylabel('True Label', fontsize=14)
plt.xlabel('Predicted Label', fontsize=14)
plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/confusion_matrix.png", dpi=150, bbox_inches='tight')
plt.show()
print(f"✓ Confusion matrix saved to {OUTPUT_DIR}/confusion_matrix.png")
# Calculate per-class accuracy
class_accuracy = cm.diagonal() / cm.sum(axis=1)
class_acc_df = pd.DataFrame({
'Class': classes,
'Accuracy': class_accuracy * 100,
'Correct': cm.diagonal(),
'Total': cm.sum(axis=1)
}).sort_values('Accuracy', ascending=False)
# Display statistics
print("\n📈 Per-Class Performance Statistics:")
print(f" Best class accuracy: {class_accuracy.max() * 100:.2f}%")
print(f" Worst class accuracy: {class_accuracy.min() * 100:.2f}%")
print(f" Mean class accuracy: {class_accuracy.mean() * 100:.2f}%")
print(f" Median class accuracy: {np.median(class_accuracy) * 100:.2f}%")
# Display top and bottom performers
print("\n🏆 Top 10 Best Performing Classes:")
print(class_acc_df.head(10).to_string(index=False))
print("\n⚠️ Top 10 Worst Performing Classes:")
print(class_acc_df.tail(10).to_string(index=False))
# Save per-class accuracy
class_acc_df.to_csv(f"{OUTPUT_DIR}/per_class_accuracy.csv", index=False)
print(f"\n✓ Per-class accuracy saved to {OUTPUT_DIR}/per_class_accuracy.csv")
# Visualize per-class accuracy distribution
plt.figure(figsize=(15, 6))
plt.subplot(1, 2, 1)
plt.bar(range(len(class_accuracy)), sorted(class_accuracy * 100)) # per-class accuracy in %, sorted ascending
plt.xlabel('Class Rank')
plt.ylabel('Accuracy (%)')
plt.title('Per-Class Accuracy Distribution (Sorted)')
plt.grid(alpha=0.3)
plt.subplot(1, 2, 2)
plt.hist(class_accuracy * 100, bins=30, edgecolor='black', alpha=0.7)
plt.xlabel('Accuracy (%)')
plt.ylabel('Number of Classes')
plt.title('Accuracy Distribution Histogram')
plt.grid(alpha=0.3)
plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/accuracy_distribution.png", dpi=150, bbox_inches='tight')
plt.show()
# ===== Cell Separator =====
# ============================================================================
# CELL 13: Detailed Classification Report
# ============================================================================
"""
Generate comprehensive classification report with:
- Precision, Recall, F1-score per class
- Support (number of samples) per class
- Macro and weighted averages
"""
print("\n📋 Generating classification report...")
# Generate classification report
report = classification_report(
true_labels,
pred_labels,
target_names=classes,
digits=4,
output_dict=False
)
print("\n" + "=" * 70)
print("CLASSIFICATION REPORT")
print("=" * 70)
print(report)
# Save report to file
with open(f"{OUTPUT_DIR}/classification_report.txt", 'w') as f:
f.write("Food-101 Classification Report\n")
f.write("=" * 70 + "\n\n")
f.write(f"Model: {MODEL_NAME}\n")
f.write(f"Dataset: Food-101 (101 classes)\n")
f.write(f"Validation Samples: {len(val_dataset)}\n")
f.write(f"Overall Accuracy: {final_accuracy * 100:.2f}%\n\n")
f.write("=" * 70 + "\n")
f.write(report)
print(f"\n✓ Classification report saved to {OUTPUT_DIR}/classification_report.txt")
# Get report as dictionary for analysis
report_dict = classification_report(
true_labels,
pred_labels,
target_names=classes,
digits=4,
output_dict=True
)
# Extract macro and weighted averages
print("\n📊 Summary Metrics:")
print(f" Macro Avg Precision: {report_dict['macro avg']['precision'] * 100:.2f}%")
print(f" Macro Avg Recall: {report_dict['macro avg']['recall'] * 100:.2f}%")
print(f" Macro Avg F1-Score: {report_dict['macro avg']['f1-score'] * 100:.2f}%")
print(f" Weighted Avg Precision: {report_dict['weighted avg']['precision'] * 100:.2f}%")
print(f" Weighted Avg Recall: {report_dict['weighted avg']['recall'] * 100:.2f}%")
print(f" Weighted Avg F1-Score: {report_dict['weighted avg']['f1-score'] * 100:.2f}%")
# ============================================================================
# CELL 14: Visualize Sample Predictions
# ============================================================================
"""
Display sample predictions with:
- Original images
- True labels
- Predicted labels
- Confidence scores
- Color coding (green=correct, red=incorrect)
"""
def visualize_predictions(dataset, pred_labels, true_labels, pred_probs, classes, n_samples=20):
"""
Visualize random sample predictions
"""
# Random sample indices
indices = random.sample(range(len(dataset)), n_samples)
# Load the raw (untransformed) validation split once, outside the loop,
# so the original images are displayed rather than normalized tensors
original_dataset = load_dataset("food101", split="validation")
fig, axes = plt.subplots(4, 5, figsize=(20, 16))
axes = axes.ravel()
for idx, sample_idx in enumerate(indices):
    # Get the original image for display
    image = original_dataset[sample_idx]['image']
true_label_idx = true_labels[sample_idx]
pred_label_idx = pred_labels[sample_idx]
confidence = pred_probs[sample_idx, pred_label_idx]
true_class = classes[true_label_idx]
pred_class = classes[pred_label_idx]
# Plot image
axes[idx].imshow(image)
axes[idx].axis('off')
# Color code: green for correct, red for incorrect
is_correct = (true_label_idx == pred_label_idx)
color = 'green' if is_correct else 'red'
marker = '✓' if is_correct else '✗'
title = f"{marker} True: {true_class}\nPred: {pred_class}\nConf: {confidence:.3f}"
axes[idx].set_title(title, fontsize=9, color=color, weight='bold')
plt.suptitle('Sample Predictions (Green=Correct ✓, Red=Incorrect ✗)',
fontsize=16, y=0.995, weight='bold')
plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/sample_predictions.png", dpi=150, bbox_inches='tight')
plt.show()
print("\n🖼️ Visualizing sample predictions...")
visualize_predictions(val_dataset, pred_labels, true_labels, pred_probs, classes)
print(f"✓ Sample predictions saved to {OUTPUT_DIR}/sample_predictions.png")
# Analyze correct vs incorrect predictions
correct_mask = (pred_labels == true_labels)
correct_confidences = pred_confidences[correct_mask]
incorrect_confidences = pred_confidences[~correct_mask]
print(f"\n📊 Confidence Analysis:")
print(f" Correct predictions: {correct_mask.sum()} ({correct_mask.sum()/len(correct_mask)*100:.2f}%)")
print(f" Incorrect predictions: {(~correct_mask).sum()} ({(~correct_mask).sum()/len(correct_mask)*100:.2f}%)")
print(f" Avg confidence (correct): {correct_confidences.mean()*100:.2f}%")
print(f" Avg confidence (incorrect): {incorrect_confidences.mean()*100:.2f}%")
# ===== Cell Separator =====
# ============================================================================
# CELL 15: Create Gradio Demo Interface
# ============================================================================
"""
Interactive Gradio demo for food classification:
- Upload food image
- Get real-time classification
- Display top-5 predictions with confidence scores
- Beautiful UI with examples
"""
print("\n🎨 Creating Gradio demo interface...")
# Load the fine-tuned model for inference
inference_model = ViTForImageClassification.from_pretrained(OUTPUT_DIR)
inference_processor = ViTImageProcessor.from_pretrained(OUTPUT_DIR)
# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inference_model = inference_model.to(device)
inference_model.eval()
print(f"✓ Model loaded on: {device}")
def classify_food_image(image):
"""
Classify uploaded food image
Args:
image: PIL Image or numpy array
Returns:
dict: Top-5 predictions with confidence scores
"""
# Ensure image is PIL Image
if not isinstance(image, Image.Image):
image = Image.fromarray(image).convert('RGB')
else:
image = image.convert('RGB')
# Preprocess image
inputs = inference_processor(images=image, return_tensors="pt").to(device)
# Get predictions
with torch.no_grad():
outputs = inference_model(**inputs)
logits = outputs.logits
probs = torch.nn.functional.softmax(logits, dim=-1)[0]
# Get top-5 predictions
top5_probs, top5_indices = torch.topk(probs, k=5)
# Format results for Gradio
results = {}
for prob, idx in zip(top5_probs, top5_indices):
class_name = id2label[idx.item()]
# Format class name (replace underscores with spaces, capitalize)
formatted_name = class_name.replace('_', ' ').title()
confidence = prob.item()
results[formatted_name] = float(confidence)
return results
# Create example images from validation set
example_images = []
original_val_dataset = load_dataset("food101", split="validation")
example_indices = random.sample(range(len(original_val_dataset)), 5)
for idx in example_indices:
img = original_val_dataset[idx]['image']
example_images.append(img)
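# Quick smoke test (a minimal sketch): classify one example image before wiring
# the function into Gradio, confirming the inference path works end to end
_smoke_result = classify_food_image(example_images[0])
print(f" Smoke test top prediction: {max(_smoke_result, key=_smoke_result.get)}")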
# Create Gradio interface with custom styling
demo = gr.Interface(
fn=classify_food_image,
inputs=gr.Image(type="pil", label="📸 Upload Food Image"),
outputs=gr.Label(num_top_classes=5, label="🎯 Predictions"),
title="🍔 Food-101 Classifier with Vision Transformer",
description="""
### AI-Powered Food Recognition System
Upload any food image and get instant classification using state-of-the-art Vision Transformer (ViT) technology!
**🔬 Model Details:**
- Architecture: Vision Transformer (ViT-base-patch16-224)
- Dataset: Food-101 (101 food categories)
- Training: Fine-tuned on 75,750 images
- Validation Accuracy: {:.2f}%
**📊 Supported Categories:** 101 different food types including pizza, sushi, hamburger, ice cream, and more!
**✨ Features:**
- Real-time classification
- Top-5 predictions with confidence scores
- High accuracy transformer-based model
- Supports various image formats
""".format(final_accuracy * 100),
article="""
### 📚 About This Model
This model uses **Vision Transformer (ViT)** architecture, which applies the transformer architecture
(originally designed for NLP) to computer vision tasks. The model divides images into patches and
processes them similar to how transformers process word tokens in text.
**Training Details:**
- Pre-trained weights: Google's ViT-base-patch16-224
- Fine-tuning: 5 epochs on Food-101 dataset
- Optimization: AdamW with cosine learning rate schedule
- Augmentation: Random crops, flips, rotations, and color jitter
**Performance Metrics:**
- Validation Accuracy: {:.2f}%
- Average Confidence: {:.2f}%
- Total Parameters: 86M
**💡 Tips for Best Results:**
- Use clear, well-lit images
- Center the food item in the frame
- Avoid heavily filtered or edited images
- Single food item works best
---
*Developed by AI Engineer | Powered by Hugging Face Transformers*
""".format(final_accuracy * 100, avg_confidence * 100),
examples=example_images,
theme=gr.themes.Soft(
primary_hue="orange",
secondary_hue="blue",
),
allow_flagging="never",
css="""
.gradio-container {
font-family: 'Arial', sans-serif;
}
footer {
display: none !important;
}
"""
)
# Launch the demo
print("\n🚀 Launching Gradio interface...")
print("=" * 70)
demo.launch(
share=True, # Create public link
debug=True, # Show detailed errors
show_error=True # Display errors in UI
)
print("\n✓ Gradio demo launched successfully!")
print(" A public URL will be displayed above (valid for 72 hours)")
# ============================================================================
# CELL 16: Save Complete Model Summary and Metrics
# ============================================================================
"""
Save comprehensive training summary including:
- Model configuration
- Training hyperparameters
- All evaluation metrics
- Class mappings
- Training history
"""
print("\n💾 Saving model summary and metrics...")
# Create comprehensive summary
model_summary = {
"model_info": {
"model_name": MODEL_NAME,
"model_type": "Vision Transformer (ViT)",
"architecture": "vit-base-patch16-224",
"total_parameters": f"{total_params / 1e6:.1f}M",
"trainable_parameters": f"{trainable_params / 1e6:.1f}M",
"patch_size": "16x16",
"image_size": "224x224",
"hidden_size": 768,
"num_attention_heads": 12,
"num_hidden_layers": 12
},
"dataset_info": {
"dataset_name": "Food-101",
"num_classes": num_classes,
"train_samples": len(train_dataset),
"val_samples": len(val_dataset),
"total_samples": len(train_dataset) + len(val_dataset),
"train_images_per_class": len(train_dataset) // num_classes,
"val_images_per_class": len(val_dataset) // num_classes
},
"training_config": {
"num_epochs": training_args.num_train_epochs,
"batch_size_train": training_args.per_device_train_batch_size,
"batch_size_eval": training_args.per_device_eval_batch_size,
"gradient_accumulation_steps": training_args.gradient_accumulation_steps,
"effective_batch_size": training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
"learning_rate": training_args.learning_rate,
"weight_decay": training_args.weight_decay,
"warmup_ratio": training_args.warmup_ratio,
"lr_scheduler": training_args.lr_scheduler_type,
"optimizer": "AdamW",
"fp16_training": training_args.fp16,
"seed": training_args.seed
},
"data_augmentation": {
"train": [
"RandomResizedCrop(224, scale=(0.8, 1.0))",
"RandomHorizontalFlip(p=0.5)",
"RandomRotation(15)",
"ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)",
"RandomAffine(translate=(0.1, 0.1))"
],
"val": [
"Resize(224)",
"Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])"
]
},
"performance_metrics": {
"validation_accuracy": float(final_accuracy),
"validation_accuracy_percent": f"{final_accuracy * 100:.2f}%",
"average_confidence": float(avg_confidence),
"average_confidence_percent": f"{avg_confidence * 100:.2f}%",
"correct_predictions": int(correct_mask.sum()),
"incorrect_predictions": int((~correct_mask).sum()),
"macro_avg_precision": float(report_dict['macro avg']['precision']),
"macro_avg_recall": float(report_dict['macro avg']['recall']),
"macro_avg_f1": float(report_dict['macro avg']['f1-score']),
"weighted_avg_precision": float(report_dict['weighted avg']['precision']),
"weighted_avg_recall": float(report_dict['weighted avg']['recall']),
"weighted_avg_f1": float(report_dict['weighted avg']['f1-score']),
"best_class_accuracy": float(class_accuracy.max()),
"worst_class_accuracy": float(class_accuracy.min()),
"mean_class_accuracy": float(class_accuracy.mean()),
"median_class_accuracy": float(np.median(class_accuracy))
},
"training_results": {
"training_runtime_seconds": float(train_result.metrics['train_runtime']),
"training_runtime_minutes": float(train_result.metrics['train_runtime'] / 60),
"training_samples_per_second": float(train_result.metrics['train_samples_per_second']),
"final_train_loss": float(train_result.metrics['train_loss'])
},
"class_mappings": {
"id2label": id2label,
"label2id": label2id,
"class_list": classes
},
"top_performing_classes": [
{
"class": row['Class'],
"accuracy": f"{row['Accuracy']:.2f}%",
"correct": int(row['Correct']),
"total": int(row['Total'])
}
for _, row in class_acc_df.head(10).iterrows()
],
"bottom_performing_classes": [
{
"class": row['Class'],
"accuracy": f"{row['Accuracy']:.2f}%",
"correct": int(row['Correct']),
"total": int(row['Total'])
}
for _, row in class_acc_df.tail(10).iterrows()
],
"output_files": {
"model_checkpoint": f"{OUTPUT_DIR}/",
"confusion_matrix": f"{OUTPUT_DIR}/confusion_matrix.png",
"accuracy_distribution": f"{OUTPUT_DIR}/accuracy_distribution.png",
"sample_predictions": f"{OUTPUT_DIR}/sample_predictions.png",
"classification_report": f"{OUTPUT_DIR}/classification_report.txt",
"per_class_accuracy": f"{OUTPUT_DIR}/per_class_accuracy.csv",
"model_summary": f"{OUTPUT_DIR}/model_summary.json"
}
}
# Save summary as JSON
with open(f"{OUTPUT_DIR}/model_summary.json", 'w') as f:
json.dump(model_summary, f, indent=2)
print(f"✓ Model summary saved to {OUTPUT_DIR}/model_summary.json")
# Also save as readable text file
with open(f"{OUTPUT_DIR}/model_summary.txt", 'w') as f:
f.write("=" * 80 + "\n")
f.write("FOOD-101 VISION TRANSFORMER (ViT) CLASSIFICATION MODEL SUMMARY\n")
f.write("=" * 80 + "\n\n")
f.write("MODEL INFORMATION\n")
f.write("-" * 80 + "\n")
for key, value in model_summary["model_info"].items():
f.write(f"{key.replace('_', ' ').title()}: {value}\n")
f.write("\n" + "=" * 80 + "\n")
f.write("DATASET INFORMATION\n")
f.write("-" * 80 + "\n")
for key, value in model_summary["dataset_info"].items():
f.write(f"{key.replace('_', ' ').title()}: {value}\n")
f.write("\n" + "=" * 80 + "\n")
f.write("TRAINING CONFIGURATION\n")
f.write("-" * 80 + "\n")
for key, value in model_summary["training_config"].items():
f.write(f"{key.replace('_', ' ').title()}: {value}\n")
f.write("\n" + "=" * 80 + "\n")
f.write("PERFORMANCE METRICS\n")
f.write("-" * 80 + "\n")
for key, value in model_summary["performance_metrics"].items():
f.write(f"{key.replace('_', ' ').title()}: {value}\n")
f.write("\n" + "=" * 80 + "\n")
f.write("TRAINING RESULTS\n")
f.write("-" * 80 + "\n")
for key, value in model_summary["training_results"].items():
f.write(f"{key.replace('_', ' ').title()}: {value}\n")
f.write("\n" + "=" * 80 + "\n")
f.write("TOP 10 PERFORMING CLASSES\n")
f.write("-" * 80 + "\n")
for i, cls_info in enumerate(model_summary["top_performing_classes"], 1):
f.write(f"{i}. {cls_info['class']}: {cls_info['accuracy']} ({cls_info['correct']}/{cls_info['total']})\n")
f.write("\n" + "=" * 80 + "\n")
f.write("BOTTOM 10 PERFORMING CLASSES\n")
f.write("-" * 80 + "\n")
for i, cls_info in enumerate(model_summary["bottom_performing_classes"], 1):
f.write(f"{i}. {cls_info['class']}: {cls_info['accuracy']} ({cls_info['correct']}/{cls_info['total']})\n")
print(f"✓ Readable summary saved to {OUTPUT_DIR}/model_summary.txt")
# ===== Cell Separator =====
# ============================================================================
# CELL 17: Inference Helper Functions
# ============================================================================
"""
Utility functions for making predictions on new images
"""
def predict_single_image(image_path, model_path=None, top_k=5):
"""
Predict food class for a single image from file path
Args:
image_path: Path to image file
model_path: Path to fine-tuned model directory
top_k: Number of top predictions to return
Returns:
dict: Top-k predictions with class names and confidence scores
"""
if model_path is None: # Use the globally defined OUTPUT_DIR if not specified
model_path = OUTPUT_DIR
# Load model and processor
model = ViTForImageClassification.from_pretrained(model_path)
processor = ViTImageProcessor.from_pretrained(model_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
# Load and preprocess image
image = Image.open(image_path).convert('RGB')
inputs = processor(images=image, return_tensors="pt").to(device)
# Get predictions
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
probs = torch.nn.functional.softmax(logits, dim=-1)[0]
# Get top-k predictions
topk_probs, topk_indices = torch.topk(probs, k=top_k)
# Format results
results = []
for prob, idx in zip(topk_probs, topk_indices):
class_name = model.config.id2label[idx.item()]
confidence = prob.item()
results.append({
'class': class_name,
'confidence': confidence,
'confidence_percent': f"{confidence * 100:.2f}%"
})
return results
def batch_predict(image_paths, model_path=None, batch_size=32):
"""
Predict food classes for multiple images efficiently
Args:
image_paths: List of image file paths
model_path: Path to fine-tuned model directory
batch_size: Batch size for processing
Returns:
list: Predictions for each image
"""
if model_path is None: # Use the globally defined OUTPUT_DIR if not specified
model_path = OUTPUT_DIR
# Load model and processor
model = ViTForImageClassification.from_pretrained(model_path)
processor = ViTImageProcessor.from_pretrained(model_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
all_predictions = []
# Process in batches
for i in tqdm(range(0, len(image_paths), batch_size), desc="Processing batches"):
batch_paths = image_paths[i:i + batch_size]
# Load images
images = [Image.open(path).convert('RGB') for path in batch_paths]
# Preprocess
inputs = processor(images=images, return_tensors="pt").to(device)
# Predict
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
pred_indices = logits.argmax(dim=-1)
# Store results
for pred_idx in pred_indices:
class_name = model.config.id2label[pred_idx.item()]
all_predictions.append(class_name)
return all_predictions
# Example usage
print("\n📝 Inference Helper Functions:")
print(" 1. predict_single_image(image_path) - Predict single image")
print(" 2. batch_predict(image_paths) - Batch prediction for multiple images")
print("\n Example:")
print(" >>> results = predict_single_image('path/to/food.jpg')")
print(" >>> print(results)")
# ============================================================================
# CELL 18: Final Report and Summary
# ============================================================================
"""
Generate final comprehensive report
"""
print("\n" + "=" * 80)
print("🎉 TRAINING PIPELINE COMPLETED SUCCESSFULLY!")
print("=" * 80)
print("\n📊 FINAL RESULTS SUMMARY:")
print("-" * 80)
print(f"✓ Model: {MODEL_NAME}")
print(f"✓ Dataset: Food-101 (101 classes)")
print(f"✓ Training Samples: {len(train_dataset):,}")
print(f"✓ Validation Samples: {len(val_dataset):,}")
print(f"✓ Training Epochs: {training_args.num_train_epochs}")
print(f"✓ Training Time: {train_result.metrics['train_runtime']/60:.2f} minutes")
print(f"✓ Final Validation Accuracy: {final_accuracy * 100:.2f}%")
print(f"✓ Average Prediction Confidence: {avg_confidence * 100:.2f}%")
print(f"✓ Best Class Accuracy: {class_accuracy.max() * 100:.2f}%")
print(f"✓ Worst Class Accuracy: {class_accuracy.min() * 100:.2f}%")
print(f"✓ Macro Avg F1-Score: {report_dict['macro avg']['f1-score'] * 100:.2f}%")
print("-" * 80)
print("\n📁 OUTPUT FILES GENERATED:")
print("-" * 80)
print(f"✓ Fine-tuned Model: {OUTPUT_DIR}/")
print(f"✓ Confusion Matrix: {OUTPUT_DIR}/confusion_matrix.png")
print(f"✓ Accuracy Distribution: {OUTPUT_DIR}/accuracy_distribution.png")
print(f"✓ Sample Predictions: {OUTPUT_DIR}/sample_predictions.png")
print(f"✓ Classification Report: {OUTPUT_DIR}/classification_report.txt")
print(f"✓ Per-Class Accuracy: {OUTPUT_DIR}/per_class_accuracy.csv")
print(f"✓ Model Summary (JSON): {OUTPUT_DIR}/model_summary.json")
print(f"✓ Model Summary (Text): {OUTPUT_DIR}/model_summary.txt")
print(f"✓ Training Logs: {OUTPUT_DIR}/logs/")
print("-" * 80)
print("\n✅ DELIVERABLES CHECKLIST:")
print("-" * 80)
print("✓ Data preprocessing and augmentation pipeline")
print("✓ Fine-tuning pipeline with training and validation")
print("✓ Evaluation using accuracy, confusion matrix, and classification report")
print("✓ Sample predictions with visualizations")
print("✓ Gradio demo for image upload and prediction")
print("✓ Comprehensive documentation and metrics")
print("-" * 80)
print("\n🚀 NEXT STEPS:")
print("-" * 80)
print("1. Use the Gradio interface above to test with your own images")
print("2. Download the model from: ./vit-food101-finetuned/")
print("3. Review classification report for detailed metrics")
print("4. Check per-class accuracy to identify improvement areas")
print("5. Use inference functions for batch predictions")
print("-" * 80)
print("\n💡 OPTIMIZATION TIPS:")
print("-" * 80)
print("✓ For better accuracy: Train for more epochs (7-10)")
print("✓ For faster training: Increase batch size if GPU memory allows")
print("✓ For memory issues: Reduce batch size or enable gradient checkpointing")
print("✓ For production: Use model quantization or distillation")
print("✓ For deployment: Export to ONNX format for inference optimization")
print("-" * 80)
print("\n📚 MODEL USAGE:")
print("-" * 80)
print("Load the model in your own code:")
print("""
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
# Load model
model = ViTForImageClassification.from_pretrained('./vit-food101-finetuned')
processor = ViTImageProcessor.from_pretrained('./vit-food101-finetuned')
# Predict
image = Image.open('food.jpg')
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
predicted_class = outputs.logits.argmax(-1).item()
print(f"Predicted: {model.config.id2label[predicted_class]}")
""")
print("-" * 80)
print("\n🎓 KEY LEARNINGS:")
print("-" * 80)
print("✓ Vision Transformers can achieve excellent performance on image classification")
print("✓ Transfer learning from pre-trained models is highly effective")
print("✓ Data augmentation is crucial for preventing overfitting")
print("✓ Mixed precision training (FP16) significantly reduces memory usage")
print("✓ Proper evaluation metrics reveal per-class performance insights")
print("-" * 80)
print("\n" + "=" * 80)
print("Thank you for using this Vision Transformer Food Classification Pipeline!")
print("=" * 80 + "\n")
# ============================================================================
# END OF NOTEBOOK
# ============================================================================