Spaces:
Running
Running
Commit
·
32d4a86
1
Parent(s):
400b4a4
Deploy image captioner
Browse files- README.md +35 -5
- app/__init__.py +102 -0
- app/__pycache__/__init__.cpython-314.pyc +0 -0
- app/__pycache__/config.cpython-314.pyc +0 -0
- app/__pycache__/routes.cpython-314.pyc +0 -0
- app/config.py +31 -0
- app/routes.py +194 -0
- app/utils/__init__.py +2 -0
- app/utils/__pycache__/__init__.cpython-314.pyc +0 -0
- app/utils/__pycache__/model_cache.cpython-314.pyc +0 -0
- app/utils/model_cache.py +343 -0
- hf_space_Dockerfile +50 -0
- hf_space_app.py +21 -0
- hf_space_requirements.txt +13 -0
- scripts/download_model.py +112 -0
- scripts/efficient_caption.py +82 -0
- scripts/optimize_models.py +323 -0
- scripts/resnet_caption.py +39 -0
- static/css/custom.css +144 -0
- static/js/main.js +85 -0
- templates/index.html +71 -0
- training/__pycache__/efficient_train.cpython-314.pyc +0 -0
- training/__pycache__/resnet_train.cpython-314.pyc +0 -0
- training/efficient_train.py +499 -0
- training/hyperparameter_tuning.py +197 -0
- training/resnet_train.py +497 -0
- training/train_advanced.py +306 -0
README.md
CHANGED
|
@@ -1,11 +1,41 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
colorTo: purple
|
| 6 |
sdk: docker
|
|
|
|
|
|
|
| 7 |
pinned: false
|
| 8 |
-
license:
|
| 9 |
---
|
| 10 |
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Image Caption Generator
|
| 3 |
+
emoji: 🖼️
|
| 4 |
+
colorFrom: blue
|
| 5 |
colorTo: purple
|
| 6 |
sdk: docker
|
| 7 |
+
sdk_version: latest
|
| 8 |
+
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# Image Caption Generator
|
| 14 |
+
|
| 15 |
+
Generate captions for images using an optimized EfficientNet-B3 model.
|
| 16 |
+
|
| 17 |
+
## Features
|
| 18 |
+
|
| 19 |
+
- ✅ EfficientNet-B3 model for high-quality captions
|
| 20 |
+
- ✅ Optimized quantized model (~245MB)
|
| 21 |
+
- ✅ Fast inference
|
| 22 |
+
- ✅ Simple web interface
|
| 23 |
+
|
| 24 |
+
## How to Use
|
| 25 |
+
|
| 26 |
+
1. Upload an image (PNG, JPG, JPEG)
|
| 27 |
+
2. Click "Generate Caption"
|
| 28 |
+
3. Get your caption!
|
| 29 |
+
|
| 30 |
+
## Model
|
| 31 |
+
|
| 32 |
+
- **Architecture:** EfficientNet-B3
|
| 33 |
+
- **Optimization:** INT8 Quantization
|
| 34 |
+
- **Size:** ~245MB
|
| 35 |
+
|
| 36 |
+
## Technical Details
|
| 37 |
+
|
| 38 |
+
- Built with PyTorch and Transformers
|
| 39 |
+
- Uses GPT-2 tokenizer
|
| 40 |
+
- Optimized for production deployment
|
| 41 |
+
|
app/__init__.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Image Caption Generator - Flask Application
|
| 3 |
+
Production-ready application with model caching and security.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from flask import Flask
|
| 7 |
+
import os
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
# Root-logger configuration: INFO level with timestamp/module/level prefix.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

# Module-level logger used throughout the application factory.
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def create_app(config=None):
    """
    Application factory pattern.
    Creates and configures the Flask application.

    Args:
        config: Optional configuration object (currently unused; reserved
            for future overrides).

    Returns:
        A configured :class:`flask.Flask` application instance.

    Raises:
        ValueError: if ``FLASK_ENV=production`` and ``SESSION_SECRET`` is
            missing or left at the insecure default.
    """
    from pathlib import Path

    # Get base directory (project root). `os` is already imported at module
    # level; the original re-imported it here redundantly.
    base_dir = Path(__file__).resolve().parent.parent

    app = Flask(__name__,
                template_folder=str(base_dir / 'templates'),
                static_folder=str(base_dir / 'static'))

    # Load configuration. Refuse to run in production with a missing or
    # default secret key; warn loudly everywhere else.
    app.secret_key = os.environ.get("SESSION_SECRET")
    if not app.secret_key or app.secret_key == "default-secret-key":
        if os.environ.get("FLASK_ENV") == "production":
            raise ValueError("SESSION_SECRET must be set in production environment!")
        logger.warning("Using default secret key. Set SESSION_SECRET in production!")
        app.secret_key = "default-secret-key"

    # Configuration
    app.config['UPLOAD_FOLDER'] = os.environ.get('UPLOAD_FOLDER', 'uploads')
    app.config['MAX_CONTENT_LENGTH'] = int(os.environ.get('MAX_FILE_SIZE', 10 * 1024 * 1024))
    app.config['ALLOWED_EXTENSIONS'] = {'png', 'jpg', 'jpeg'}

    # Create uploads directory. exist_ok=True avoids the TOCTOU race of the
    # original exists()-then-makedirs() pair.
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

    # Register blueprints/routes
    from app.routes import bp
    app.register_blueprint(bp)

    # The original evaluated this identical condition twice; compute once.
    load_models = (os.environ.get("FLASK_ENV") == "production"
                   or os.environ.get("LOAD_MODELS", "true").lower() == "true")

    if load_models:
        # Download model if needed (before loading): HF Hub first, then the
        # download-URL fallback.
        try:
            _download_model_if_needed(base_dir)
        except Exception as e:
            logger.warning(f"Could not download model: {e}. Will try to use existing model if available.")

        # Initialize models at startup (production).
        logger.info("Initializing models...")
        try:
            from app.utils.model_cache import model_cache
            # Only load EfficientNet model
            model_cache.load_efficientnet_model_only(use_optimized=True)
            logger.info("EfficientNet model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load models: {e}", exc_info=True)
            # Don't raise here - let the app start and handle errors gracefully

    return app


def _download_model_if_needed(base_dir):
    """
    Fetch the quantized EfficientNet checkpoint.

    Tries the Hugging Face Hub (when ``HF_MODEL_REPO`` is set) and falls
    back to the project's download-URL script. The original inlined the
    fallback twice; this helper deduplicates it.
    """
    import sys

    model_repo = os.environ.get("HF_MODEL_REPO")
    if model_repo:
        try:
            from huggingface_hub import hf_hub_download
            logger.info(f"Downloading model from HF Hub: {model_repo}")
            model_path = hf_hub_download(
                repo_id=model_repo,
                filename="efficientnet_efficient_best_model_quantized.pth",
                cache_dir=str(base_dir / "models" / "optimized_models")
            )
            logger.info(f"Model downloaded from HF Hub: {model_path}")
            return
        except Exception as e:
            logger.warning(f"Could not download from HF Hub: {e}. Trying download URL...")

    # Fallback to download URL method
    sys.path.insert(0, str(base_dir))
    from scripts.download_model import download_efficientnet_model
    download_efficientnet_model()


# For backward compatibility
app = create_app()
|
| 102 |
+
|
app/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (3.47 kB). View file
|
|
|
app/__pycache__/config.cpython-314.pyc
ADDED
|
Binary file (1.69 kB). View file
|
|
|
app/__pycache__/routes.cpython-314.pyc
ADDED
|
Binary file (9.07 kB). View file
|
|
|
app/config.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Application configuration.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
# Base directory
|
| 9 |
+
BASE_DIR = Path(__file__).resolve().parent.parent
|
| 10 |
+
|
| 11 |
+
# Flask configuration
|
| 12 |
+
SECRET_KEY = os.environ.get("SESSION_SECRET", "dev-secret-key-change-in-production")
|
| 13 |
+
FLASK_ENV = os.environ.get("FLASK_ENV", "development")
|
| 14 |
+
DEBUG = FLASK_ENV != "production"
|
| 15 |
+
|
| 16 |
+
# Upload configuration
|
| 17 |
+
UPLOAD_FOLDER = os.environ.get("UPLOAD_FOLDER", str(BASE_DIR / "uploads"))
|
| 18 |
+
MAX_FILE_SIZE = int(os.environ.get("MAX_FILE_SIZE", 10 * 1024 * 1024)) # 10MB
|
| 19 |
+
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
|
| 20 |
+
|
| 21 |
+
# Model paths
|
| 22 |
+
MODELS_DIR = BASE_DIR / "models"
|
| 23 |
+
OPTIMIZED_MODELS_DIR = MODELS_DIR / "optimized_models"
|
| 24 |
+
RESNET_MODEL_PATH = MODELS_DIR / "resnet_best_model.pth"
|
| 25 |
+
EFFICIENTNET_MODEL_PATH = MODELS_DIR / "efficient_best_model.pth"
|
| 26 |
+
VOCAB_PATH = MODELS_DIR / "vocab.pkl"
|
| 27 |
+
|
| 28 |
+
# Model configuration
|
| 29 |
+
USE_OPTIMIZED_MODELS = os.environ.get("USE_OPTIMIZED_MODELS", "true").lower() == "true"
|
| 30 |
+
LOAD_MODELS_ON_STARTUP = os.environ.get("LOAD_MODELS", "true").lower() == "true"
|
| 31 |
+
|
app/routes.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Application routes.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import logging
|
| 7 |
+
import time
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from flask import Blueprint, render_template, request, jsonify
|
| 10 |
+
from werkzeug.utils import secure_filename
|
| 11 |
+
from torchvision import transforms
|
| 12 |
+
from PIL import Image
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
from app.utils.model_cache import model_cache
|
| 16 |
+
from app.config import MAX_FILE_SIZE, ALLOWED_EXTENSIONS
|
| 17 |
+
|
| 18 |
+
# Import training functions (handle both old and new locations)
|
| 19 |
+
try:
|
| 20 |
+
from training.resnet_train import visualize_attention
|
| 21 |
+
from training.efficient_train import generate_caption
|
| 22 |
+
except ImportError:
|
| 23 |
+
# Fallback for backward compatibility (before reorganization)
|
| 24 |
+
from resnet_train import visualize_attention
|
| 25 |
+
from efficient_train import generate_caption
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
bp = Blueprint('main', __name__)
|
| 30 |
+
|
| 31 |
+
# Image transformation for EfficientNet
|
| 32 |
+
efficientnet_transform = transforms.Compose([
|
| 33 |
+
transforms.Resize(224),
|
| 34 |
+
transforms.CenterCrop(224),
|
| 35 |
+
transforms.ToTensor(),
|
| 36 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
| 37 |
+
])
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def allowed_file(filename):
    """Return True when *filename* has an extension in ALLOWED_EXTENSIONS."""
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in ALLOWED_EXTENSIONS
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def validate_file_type(file_path):
    """
    Validate that the file at *file_path* is actually an image
    (not just a file with an image extension).

    Returns:
        True when PIL can parse and verify the file, False otherwise.
    """
    try:
        # Use a context manager so the file handle is always closed.
        # The original called Image.open() without `with`, leaking the
        # handle (verify() does not close it).
        with Image.open(file_path) as img:
            img.verify()
        return True
    except Exception:
        # PIL raises several distinct exception types for corrupt/truncated
        # files, so the broad catch here is deliberate.
        return False
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@bp.before_request
def before_request():
    """Record the request start time (read back by after_request for timing)."""
    # Stored on the request-local object, so it is per-request safe.
    request.start_time = time.time()
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@bp.after_request
def after_request(response):
    """Add security headers and log request duration."""
    # Security headers applied to every response from this blueprint.
    hardening = {
        'X-Content-Type-Options': 'nosniff',
        'X-Frame-Options': 'DENY',
        'X-XSS-Protection': '1; mode=block',
    }
    for name, value in hardening.items():
        response.headers[name] = value

    # Log method/path/status plus elapsed time since before_request().
    elapsed = time.time() - request.start_time
    logger.info(f"{request.method} {request.path} - {response.status_code} - {elapsed:.3f}s")

    return response
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@bp.route('/')
def index():
    """Serve the main page."""
    page = render_template('index.html')
    return page
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@bp.route('/health')
def health_check():
    """Health check endpoint for load balancers. Always returns 200."""
    payload = {
        'status': 'healthy',
        'timestamp': datetime.utcnow().isoformat(),
        'models_loaded': {
            'resnet': model_cache.is_resnet_loaded(),
            'efficientnet': model_cache.is_efficientnet_loaded(),
        },
    }
    return jsonify(payload), 200
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@bp.route('/ready')
def readiness_check():
    """Readiness probe: 200 once at least one model is loaded, else 503."""
    any_model_loaded = model_cache.is_resnet_loaded() or model_cache.is_efficientnet_loaded()
    if any_model_loaded:
        return jsonify({'status': 'ready'}), 200
    return jsonify({'status': 'not ready', 'reason': 'models not loaded'}), 503
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@bp.route('/upload', methods=['POST'])
def upload_file():
    """
    Handle image upload and generate a caption.

    Expects multipart form data with an 'image' file field and an optional
    'model' field ('efficientnet' default, anything else selects ResNet).

    Returns:
        JSON with {success, caption, model, inference_time} on success;
        a JSON {'error': ...} with 400/503/500 otherwise.
    """
    if 'image' not in request.files:
        logger.warning("Upload request missing 'image' field")
        return jsonify({'error': 'No file part'}), 400

    file = request.files['image']
    model_choice = request.form.get('model', 'efficientnet')  # Default to EfficientNet

    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400

    if not file or not allowed_file(file.filename):
        return jsonify({'error': 'Invalid file type. Only PNG, JPG, JPEG allowed.'}), 400

    # Get upload folder from current app (set in __init__.py)
    from flask import current_app
    upload_folder = current_app.config['UPLOAD_FOLDER']

    # secure_filename() can return '' for names made entirely of unsafe
    # characters; the original would then join to the directory itself and
    # fail confusingly. Fall back to a generated name.
    filename = secure_filename(file.filename)
    if not filename:
        filename = f"upload_{int(time.time() * 1000)}.img"
    filepath = os.path.join(upload_folder, filename)

    try:
        file.save(filepath)

        # Validate file size
        file_size = os.path.getsize(filepath)
        if file_size > MAX_FILE_SIZE:
            return jsonify({'error': f'File too large. Maximum size: {MAX_FILE_SIZE / 1024 / 1024}MB'}), 400

        # Validate file is actually an image
        if not validate_file_type(filepath):
            return jsonify({'error': 'Invalid image file'}), 400

        # Generate caption based on model choice
        start_time = time.time()

        if model_choice == 'efficientnet':
            if not model_cache.is_efficientnet_loaded():
                return jsonify({'error': 'EfficientNet model not available'}), 503

            model, tokenizer = model_cache.get_efficientnet_model()

            # Load and preprocess image
            image = Image.open(filepath).convert('RGB')
            image_tensor = efficientnet_transform(image).to(model_cache._device)

            # Generate caption
            with torch.no_grad():
                caption = generate_caption(
                    model,
                    image_tensor,
                    tokenizer,
                    model_cache._device,
                    max_length=64
                )
        else:  # resnet50
            if not model_cache.is_resnet_loaded():
                return jsonify({'error': 'ResNet model not available'}), 503

            encoder, decoder, vocab = model_cache.get_resnet_models()

            # Generate caption
            with torch.no_grad():
                caption = visualize_attention(filepath, encoder, decoder, model_cache._device)

        inference_time = time.time() - start_time
        logger.info(f"Caption generated in {inference_time:.3f}s using {model_choice}")

        return jsonify({
            'success': True,
            'caption': caption,
            'model': model_choice,
            'inference_time': round(inference_time, 3)
        })

    except Exception as e:
        logger.error(f"Error generating caption: {e}", exc_info=True)
        return jsonify({'error': 'Failed to generate caption. Please try again.'}), 500
    finally:
        # Single cleanup point for the temp file. This also fixes the
        # original's leak: the 503 "model not available" early returns
        # left the uploaded file on disk.
        if os.path.exists(filepath):
            os.remove(filepath)
|
| 194 |
+
|
app/utils/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utilities package."""
|
| 2 |
+
|
app/utils/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (198 Bytes). View file
|
|
|
app/utils/__pycache__/model_cache.cpython-314.pyc
ADDED
|
Binary file (16.1 kB). View file
|
|
|
app/utils/model_cache.py
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model Caching Module for Production
|
| 3 |
+
Loads models once at startup and reuses them for all requests.
|
| 4 |
+
This eliminates the overhead of loading models per-request.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import os
|
| 9 |
+
import logging
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
# Get base directory (project root)
|
| 15 |
+
BASE_DIR = Path(__file__).resolve().parent.parent.parent
|
| 16 |
+
MODELS_DIR = BASE_DIR / "models"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class ModelCache:
|
| 20 |
+
"""Singleton class to cache loaded models in memory."""
|
| 21 |
+
|
| 22 |
+
def __init__(self):
|
| 23 |
+
self._resnet_encoder = None
|
| 24 |
+
self._resnet_decoder = None
|
| 25 |
+
self._resnet_vocab = None
|
| 26 |
+
self._efficientnet_model = None
|
| 27 |
+
self._efficientnet_tokenizer = None
|
| 28 |
+
self._device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 29 |
+
self._models_loaded = False
|
| 30 |
+
|
| 31 |
+
logger.info(f"ModelCache initialized on device: {self._device}")
|
| 32 |
+
|
| 33 |
+
    def load_all_models(self,
                        resnet_path=None,
                        efficientnet_path=None,
                        use_optimized=True):
        """
        Load all models at startup.

        Args:
            resnet_path: Path to ResNet checkpoint (default: models/resnet_best_model.pth)
            efficientnet_path: Path to EfficientNet checkpoint (default: models/efficient_best_model.pth)
            use_optimized: If True, try to load optimized models first
        """
        # Idempotence guard: a second call is a no-op.
        if self._models_loaded:
            logger.warning("Models already loaded, skipping")
            return

        # Set default paths
        if resnet_path is None:
            resnet_path = str(MODELS_DIR / "resnet_best_model.pth")
        if efficientnet_path is None:
            efficientnet_path = str(MODELS_DIR / "efficient_best_model.pth")

        # Try optimized models first if requested
        if use_optimized:
            # Check multiple possible locations for optimized models.
            # NOTE(review): the resolved resnet_path is never consumed below
            # (only EfficientNet is loaded), so the ResNet scan is currently
            # dead code — presumably kept for when ResNet loading returns.
            optimized_resnet_paths = [
                str(MODELS_DIR / "optimized_models" / "resnet_resnet_best_model_quantized.pth"),
                str(BASE_DIR / "optimized_models" / "resnet_resnet_best_model_quantized.pth"),
                resnet_path.replace('.pth', '_quantized.pth'),
                resnet_path.replace('resnet_best_model.pth', 'resnet_resnet_best_model_quantized.pth'),
            ]

            optimized_efficient_paths = [
                str(MODELS_DIR / "optimized_models" / "efficientnet_efficient_best_model_quantized.pth"),
                str(BASE_DIR / "optimized_models" / "efficientnet_efficient_best_model_quantized.pth"),
                efficientnet_path.replace('.pth', '_quantized.pth'),
                efficientnet_path.replace('efficient_best_model.pth', 'efficientnet_efficient_best_model_quantized.pth'),
            ]

            # Find optimized ResNet model (first existing candidate wins)
            for opt_path in optimized_resnet_paths:
                if os.path.exists(opt_path):
                    resnet_path = opt_path
                    logger.info(f"Using optimized ResNet model: {resnet_path}")
                    break

            # Find optimized EfficientNet model (first existing candidate wins)
            for opt_path in optimized_efficient_paths:
                if os.path.exists(opt_path):
                    efficientnet_path = opt_path
                    logger.info(f"Using optimized EfficientNet model: {efficientnet_path}")
                    break

        # Load EfficientNet only (ResNet skipped)
        try:
            self.load_efficientnet_model(efficientnet_path)
            logger.info("EfficientNet model loaded successfully")
        except Exception as e:
            # Failure is logged but not re-raised; the app is expected to
            # start anyway and surface errors per-request.
            logger.error(f"Failed to load EfficientNet model: {e}", exc_info=True)

        # NOTE(review): set even when loading failed above — confirm callers
        # check model availability via the is_*_loaded() accessors rather
        # than this flag.
        self._models_loaded = True
|
| 94 |
+
|
| 95 |
+
def load_efficientnet_model_only(self, use_optimized=True):
|
| 96 |
+
"""
|
| 97 |
+
Load only EfficientNet model (skip ResNet).
|
| 98 |
+
Useful when only EfficientNet is needed.
|
| 99 |
+
"""
|
| 100 |
+
if self._models_loaded:
|
| 101 |
+
logger.warning("Models already loaded, skipping")
|
| 102 |
+
return
|
| 103 |
+
|
| 104 |
+
efficientnet_path = str(MODELS_DIR / "efficient_best_model.pth")
|
| 105 |
+
|
| 106 |
+
# Try optimized model first if requested
|
| 107 |
+
if use_optimized:
|
| 108 |
+
optimized_efficient_paths = [
|
| 109 |
+
str(MODELS_DIR / "optimized_models" / "efficientnet_efficient_best_model_quantized.pth"),
|
| 110 |
+
str(BASE_DIR / "optimized_models" / "efficientnet_efficient_best_model_quantized.pth"),
|
| 111 |
+
efficientnet_path.replace('.pth', '_quantized.pth'),
|
| 112 |
+
efficientnet_path.replace('efficient_best_model.pth', 'efficientnet_efficient_best_model_quantized.pth'),
|
| 113 |
+
]
|
| 114 |
+
|
| 115 |
+
# Find optimized EfficientNet model
|
| 116 |
+
for opt_path in optimized_efficient_paths:
|
| 117 |
+
if os.path.exists(opt_path):
|
| 118 |
+
efficientnet_path = opt_path
|
| 119 |
+
logger.info(f"Using optimized EfficientNet model: {efficientnet_path}")
|
| 120 |
+
break
|
| 121 |
+
|
| 122 |
+
# Load EfficientNet
|
| 123 |
+
try:
|
| 124 |
+
self.load_efficientnet_model(efficientnet_path)
|
| 125 |
+
logger.info("EfficientNet model loaded successfully")
|
| 126 |
+
except Exception as e:
|
| 127 |
+
logger.error(f"Failed to load EfficientNet model: {e}", exc_info=True)
|
| 128 |
+
|
| 129 |
+
self._models_loaded = True
|
| 130 |
+
|
| 131 |
+
    def load_resnet_models(self, checkpoint_path=None):
        """Load ResNet encoder and decoder models.

        Returns the cached (encoder, decoder, vocab) triple, loading from
        *checkpoint_path* on first use. Subsequent calls return the cache.
        """
        # Fast path: already loaded.
        if self._resnet_encoder is not None:
            return self._resnet_encoder, self._resnet_decoder, self._resnet_vocab

        if checkpoint_path is None:
            checkpoint_path = str(MODELS_DIR / "resnet_best_model.pth")

        # Resolve path - try multiple locations
        # NOTE(review): _resolve_model_path is defined elsewhere in this
        # class (not visible here) — presumably probes candidate dirs.
        checkpoint_path = self._resolve_model_path(checkpoint_path)

        logger.info(f"Loading ResNet models from {checkpoint_path}")

        # Import from training module (handles both old and new locations)
        # Need to do this BEFORE loading checkpoint to avoid pickle issues:
        # the checkpoint was pickled when the classes lived in a module
        # named 'resnet_train', so unpickling looks them up by that name.
        try:
            from training.resnet_train import EncoderCNN, DecoderRNN
            # Add to sys.modules under the OLD name to help pickle loading
            import sys
            if 'resnet_train' not in sys.modules:
                sys.modules['resnet_train'] = sys.modules['training.resnet_train']
        except ImportError:
            try:
                # Fallback for backward compatibility (pre-reorganization layout)
                import sys
                sys.path.insert(0, str(BASE_DIR))
                from resnet_train import EncoderCNN, DecoderRNN
            except ImportError:
                logger.error("Could not import ResNet model classes. Make sure resnet_train.py exists in training/ or root.")
                raise

        # Load checkpoint with proper module mapping
        import sys
        import importlib.util

        # Map old module names for pickle compatibility: manufacture a
        # 'resnet_train' module from the file if it still isn't registered.
        if 'resnet_train' not in sys.modules:
            try:
                spec = importlib.util.spec_from_file_location("resnet_train", str(BASE_DIR / "training" / "resnet_train.py"))
                if spec and spec.loader:
                    resnet_module = importlib.util.module_from_spec(spec)
                    sys.modules['resnet_train'] = resnet_module
                    spec.loader.exec_module(resnet_module)
            except Exception:
                # Best-effort shim only; torch.load below will surface any
                # real failure.
                pass

        # SECURITY NOTE: weights_only=False unpickles arbitrary objects —
        # only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self._device, weights_only=False)

        # Initialize models
        self._resnet_encoder = EncoderCNN().to(self._device)
        self._resnet_decoder = DecoderRNN().to(self._device)

        # Load weights
        self._resnet_encoder.load_state_dict(checkpoint['encoder'])
        self._resnet_decoder.load_state_dict(checkpoint['decoder'])

        # Set to eval mode (disables dropout/batch-norm updates)
        self._resnet_encoder.eval()
        self._resnet_decoder.eval()

        # Store vocabulary (may be absent from older checkpoints -> None)
        self._resnet_vocab = checkpoint.get('vocab')

        # Warm up models (first inference is slower)
        logger.info("Warming up ResNet models...")
        dummy_input = torch.randn(1, 3, 224, 224).to(self._device)
        with torch.no_grad():
            _ = self._resnet_encoder(dummy_input)
        logger.info("ResNet models warmed up")

        return self._resnet_encoder, self._resnet_decoder, self._resnet_vocab
|
| 202 |
+
|
| 203 |
+
def load_efficientnet_model(self, checkpoint_path=None):
|
| 204 |
+
"""Load EfficientNet model."""
|
| 205 |
+
if self._efficientnet_model is not None:
|
| 206 |
+
return self._efficientnet_model, self._efficientnet_tokenizer
|
| 207 |
+
|
| 208 |
+
if checkpoint_path is None:
|
| 209 |
+
checkpoint_path = str(MODELS_DIR / "efficient_best_model.pth")
|
| 210 |
+
|
| 211 |
+
# Resolve path - try multiple locations
|
| 212 |
+
checkpoint_path = self._resolve_model_path(checkpoint_path)
|
| 213 |
+
|
| 214 |
+
logger.info(f"Loading EfficientNet model from {checkpoint_path}")
|
| 215 |
+
|
| 216 |
+
# Import from training module (handles both old and new locations)
|
| 217 |
+
try:
|
| 218 |
+
from training.efficient_train import Encoder, Decoder, ImageCaptioningModel
|
| 219 |
+
except ImportError:
|
| 220 |
+
try:
|
| 221 |
+
# Fallback for backward compatibility
|
| 222 |
+
import sys
|
| 223 |
+
sys.path.insert(0, str(BASE_DIR))
|
| 224 |
+
from efficient_train import Encoder, Decoder, ImageCaptioningModel
|
| 225 |
+
except ImportError:
|
| 226 |
+
logger.error("Could not import EfficientNet model classes. Make sure efficient_train.py exists in training/ or root.")
|
| 227 |
+
raise
|
| 228 |
+
|
| 229 |
+
from transformers import AutoTokenizer
|
| 230 |
+
|
| 231 |
+
# Initialize tokenizer
|
| 232 |
+
tokenizer = AutoTokenizer.from_pretrained('gpt2')
|
| 233 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 234 |
+
special_tokens = {'additional_special_tokens': ['<start>', '<end>']}
|
| 235 |
+
tokenizer.add_special_tokens(special_tokens)
|
| 236 |
+
self._efficientnet_tokenizer = tokenizer
|
| 237 |
+
|
| 238 |
+
# Initialize model
|
| 239 |
+
encoder = Encoder(model_name='efficientnet_b3', embed_dim=512)
|
| 240 |
+
decoder = Decoder(
|
| 241 |
+
vocab_size=len(tokenizer),
|
| 242 |
+
embed_dim=512,
|
| 243 |
+
num_layers=8,
|
| 244 |
+
num_heads=8,
|
| 245 |
+
max_seq_length=64
|
| 246 |
+
)
|
| 247 |
+
self._efficientnet_model = ImageCaptioningModel(encoder, decoder).to(self._device)
|
| 248 |
+
|
| 249 |
+
# Load weights
|
| 250 |
+
checkpoint = torch.load(checkpoint_path, map_location=self._device, weights_only=False)
|
| 251 |
+
|
| 252 |
+
# Check if this is a quantized model (has _packed_params keys)
|
| 253 |
+
is_quantized = any('_packed_params' in key for key in checkpoint.get('model_state', checkpoint).keys())
|
| 254 |
+
|
| 255 |
+
if is_quantized:
|
| 256 |
+
# For quantized models, we need to prepare the model for quantization first
|
| 257 |
+
logger.info("Detected quantized model, preparing model for quantization...")
|
| 258 |
+
try:
|
| 259 |
+
# Prepare model for quantization
|
| 260 |
+
import torch.quantization as quant
|
| 261 |
+
self._efficientnet_model = quant.quantize_dynamic(
|
| 262 |
+
self._efficientnet_model, {torch.nn.Linear}, dtype=torch.qint8
|
| 263 |
+
)
|
| 264 |
+
logger.info("Model prepared for quantization")
|
| 265 |
+
except Exception as e:
|
| 266 |
+
logger.warning(f"Could not prepare model for quantization: {e}. Trying to load anyway...")
|
| 267 |
+
|
| 268 |
+
if 'model_state' in checkpoint:
|
| 269 |
+
try:
|
| 270 |
+
self._efficientnet_model.load_state_dict(checkpoint['model_state'], strict=False)
|
| 271 |
+
except Exception as e:
|
| 272 |
+
logger.warning(f"Could not load quantized state dict: {e}. Trying regular model...")
|
| 273 |
+
# Try loading non-quantized model instead
|
| 274 |
+
regular_path = checkpoint_path.replace('_quantized.pth', '.pth').replace('efficientnet_efficient_best_model', 'efficient_best_model')
|
| 275 |
+
if os.path.exists(regular_path) and regular_path != checkpoint_path:
|
| 276 |
+
logger.info(f"Trying regular model: {regular_path}")
|
| 277 |
+
checkpoint = torch.load(regular_path, map_location=self._device, weights_only=False)
|
| 278 |
+
if 'model_state' in checkpoint:
|
| 279 |
+
self._efficientnet_model.load_state_dict(checkpoint['model_state'])
|
| 280 |
+
else:
|
| 281 |
+
self._efficientnet_model.load_state_dict(checkpoint)
|
| 282 |
+
else:
|
| 283 |
+
# Fallback: try loading directly
|
| 284 |
+
try:
|
| 285 |
+
self._efficientnet_model.load_state_dict(checkpoint, strict=False)
|
| 286 |
+
except Exception:
|
| 287 |
+
logger.warning("Could not load state dict. Model may not work correctly.")
|
| 288 |
+
|
| 289 |
+
self._efficientnet_model.eval()
|
| 290 |
+
|
| 291 |
+
# Warm up
|
| 292 |
+
logger.info("Warming up EfficientNet model...")
|
| 293 |
+
dummy_input = torch.randn(1, 3, 224, 224).to(self._device)
|
| 294 |
+
with torch.no_grad():
|
| 295 |
+
_ = self._efficientnet_model.encoder(dummy_input)
|
| 296 |
+
logger.info("EfficientNet model warmed up")
|
| 297 |
+
|
| 298 |
+
return self._efficientnet_model, self._efficientnet_tokenizer
|
| 299 |
+
|
| 300 |
+
def _resolve_model_path(self, checkpoint_path):
|
| 301 |
+
"""Resolve model path, trying multiple locations."""
|
| 302 |
+
# If path exists, use it
|
| 303 |
+
if os.path.exists(checkpoint_path):
|
| 304 |
+
return checkpoint_path
|
| 305 |
+
|
| 306 |
+
# Try in models directory
|
| 307 |
+
alt_path = str(MODELS_DIR / os.path.basename(checkpoint_path))
|
| 308 |
+
if os.path.exists(alt_path):
|
| 309 |
+
logger.info(f"Found model at: {alt_path}")
|
| 310 |
+
return alt_path
|
| 311 |
+
|
| 312 |
+
# Try in root directory (backward compatibility)
|
| 313 |
+
alt_path = str(BASE_DIR / os.path.basename(checkpoint_path))
|
| 314 |
+
if os.path.exists(alt_path):
|
| 315 |
+
logger.info(f"Found model at: {alt_path}")
|
| 316 |
+
return alt_path
|
| 317 |
+
|
| 318 |
+
# Return original path (will fail with clear error)
|
| 319 |
+
return checkpoint_path
|
| 320 |
+
|
| 321 |
+
def get_resnet_models(self):
    """Return the cached ResNet (encoder, decoder, vocab) triple.

    Raises:
        RuntimeError: if load_resnet_models() has not been called yet.
    """
    encoder = self._resnet_encoder
    if encoder is None:
        raise RuntimeError("ResNet models not loaded. Call load_resnet_models() first.")
    return encoder, self._resnet_decoder, self._resnet_vocab
|
| 326 |
+
|
| 327 |
+
def get_efficientnet_model(self):
    """Return the cached (model, tokenizer) pair for EfficientNet.

    Raises:
        RuntimeError: if load_efficientnet_model() has not been called yet.
    """
    model = self._efficientnet_model
    if model is None:
        raise RuntimeError("EfficientNet model not loaded. Call load_efficientnet_model() first.")
    return model, self._efficientnet_tokenizer
|
| 332 |
+
|
| 333 |
+
def is_resnet_loaded(self):
    """Return True when the ResNet encoder has been loaded into memory."""
    loaded = self._resnet_encoder is not None
    return loaded
|
| 336 |
+
|
| 337 |
+
def is_efficientnet_loaded(self):
    """Return True when the EfficientNet captioning model is in memory."""
    loaded = self._efficientnet_model is not None
    return loaded
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
# Singleton instance shared by the whole application; import this object
# rather than constructing new ModelCache instances so the heavy model
# weights are loaded into memory only once per process.
model_cache = ModelCache()
|
hf_space_Dockerfile
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dockerfile for Hugging Face Spaces
# Based on: https://huggingface.co/docs/hub/spaces-sdks-docker

FROM python:3.10-slim

# Create user (HF Spaces requirement)
RUN useradd -m -u 1000 user
USER user
ENV PATH="/home/user/.local/bin:$PATH"

WORKDIR /app

# Install system dependencies
# (switch back to root temporarily; apt requires root privileges)
USER root
RUN apt-get update && apt-get install -y \
    build-essential \
    curl \
    && rm -rf /var/lib/apt/lists/*

USER user

# Copy and install Python dependencies
# (copied first so dependency layers are cached across code-only rebuilds)
COPY --chown=user requirements.txt requirements.txt
RUN pip install --no-cache-dir --user --upgrade -r requirements.txt

# Download NLTK data
RUN python -c "import nltk; nltk.download('punkt', quiet=True)"

# Copy application files
COPY --chown=user app/ /app/app/
COPY --chown=user training/ /app/training/
COPY --chown=user scripts/ /app/scripts/
COPY --chown=user templates/ /app/templates/
COPY --chown=user static/ /app/static/
COPY --chown=user app.py /app/

# Create necessary directories
RUN mkdir -p /app/models/optimized_models /app/uploads

# HF Spaces uses port 7860
EXPOSE 7860

# Set environment variables
ENV FLASK_ENV=production
ENV PORT=7860

# Run the application on port 7860 (HF Spaces requirement)
# Use app.py as entry point (HF Spaces looks for app.py)
CMD ["gunicorn", "app:app", "--bind", "0.0.0.0:7860", "--workers", "1", "--timeout", "120", "--threads", "2"]
|
| 50 |
+
|
hf_space_app.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hugging Face Spaces - Flask Application Entry Point
|
| 3 |
+
HF Spaces expects app.py with an 'app' variable
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
# Add project root to path
|
| 11 |
+
BASE_DIR = Path(__file__).resolve().parent
|
| 12 |
+
sys.path.insert(0, str(BASE_DIR))
|
| 13 |
+
|
| 14 |
+
# Import Flask app from app package
|
| 15 |
+
# This will trigger model loading at startup
|
| 16 |
+
from app import app
|
| 17 |
+
|
| 18 |
+
# HF Spaces requires 'app' variable to be available
|
| 19 |
+
# The app is already created in app/__init__.py
|
| 20 |
+
# No need to run it here - Gunicorn will handle it
|
| 21 |
+
|
hf_space_requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
torchvision>=0.15.0
|
| 3 |
+
transformers>=4.30.0
|
| 4 |
+
Pillow>=10.0.0
|
| 5 |
+
timm>=0.9.0
|
| 6 |
+
numpy>=1.24.0
|
| 7 |
+
flask>=2.3.0
|
| 8 |
+
gunicorn>=21.2.0
|
| 9 |
+
werkzeug>=2.3.0
|
| 10 |
+
nltk>=3.8.1
|
| 11 |
+
requests>=2.31.0
|
| 12 |
+
huggingface_hub>=0.20.0
|
| 13 |
+
|
scripts/download_model.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Download EfficientNet model from cloud storage if not present.
|
| 3 |
+
This script runs at application startup to download the model if needed.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import logging
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import requests
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
def download_efficientnet_model():
    """
    Download the quantized EfficientNet model if it doesn't exist locally.

    Supports two methods, tried in order:
    1. Hugging Face Hub (set HF_MODEL_REPO environment variable)
    2. Direct URL download (set EFFICIENTNET_MODEL_URL environment variable)

    Returns:
        bool: True if the model exists or was downloaded successfully,
              False otherwise. Partial downloads are deleted on failure.
    """
    # Get base directory (project root = two levels up from this script)
    base_dir = Path(__file__).resolve().parent.parent
    models_dir = base_dir / "models" / "optimized_models"
    models_dir.mkdir(parents=True, exist_ok=True)

    model_path = models_dir / "efficientnet_efficient_best_model_quantized.pth"

    # Check if model already exists — nothing to do in that case.
    if model_path.exists():
        size_mb = model_path.stat().st_size / (1024 * 1024)
        logger.info(f"EfficientNet model already exists ({size_mb:.1f}MB)")
        return True

    # Try Hugging Face Hub first
    hf_repo = os.environ.get("HF_MODEL_REPO")
    if hf_repo:
        try:
            from huggingface_hub import hf_hub_download
            logger.info(f"Downloading model from Hugging Face Hub: {hf_repo}")
            downloaded_path = hf_hub_download(
                repo_id=hf_repo,
                filename="efficientnet_efficient_best_model_quantized.pth",
                cache_dir=str(models_dir),
                local_dir=str(models_dir),
                local_dir_use_symlinks=False
            )
            # Move to expected location if the hub placed it elsewhere.
            if downloaded_path != str(model_path):
                import shutil
                shutil.move(downloaded_path, model_path)
            size_mb = model_path.stat().st_size / (1024 * 1024)
            logger.info(f"Model downloaded from HF Hub successfully ({size_mb:.1f}MB)")
            return True
        except ImportError:
            logger.warning("huggingface_hub not installed. Install with: pip install huggingface_hub")
        except Exception as e:
            # Fall through to the direct-URL path below.
            logger.warning(f"Failed to download from HF Hub: {e}. Trying direct URL...")

    # Fallback to direct URL download
    model_url = os.environ.get("EFFICIENTNET_MODEL_URL")

    if not model_url:
        logger.warning("Neither HF_MODEL_REPO nor EFFICIENTNET_MODEL_URL is set.")
        logger.warning("Model will not be downloaded. Set one of these environment variables.")
        return False

    try:
        logger.info(f"Downloading EfficientNet model from {model_url}...")
        logger.info("This may take a few minutes (model is ~245MB)...")

        # Download with progress, streamed in chunks to bound memory use.
        response = requests.get(model_url, stream=True, timeout=300)
        response.raise_for_status()

        total_size = int(response.headers.get('content-length', 0))
        downloaded = 0

        with open(model_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total_size > 0:
                        percent = (downloaded / total_size) * 100
                        # NOTE(review): this only logs when `downloaded` lands on an
                        # exact 10MB multiple, which relies on full 8192-byte chunks;
                        # short chunks can make progress lines skip — confirm intended.
                        if downloaded % (10 * 1024 * 1024) == 0:  # Log every 10MB
                            logger.info(f"Downloaded {downloaded / (1024 * 1024):.1f}MB / {total_size / (1024 * 1024):.1f}MB ({percent:.1f}%)")

        size_mb = model_path.stat().st_size / (1024 * 1024)
        logger.info(f"EfficientNet model downloaded successfully ({size_mb:.1f}MB)")
        return True

    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to download model: {e}")
        # Clean up partial download so the next run retries from scratch.
        if model_path.exists():
            model_path.unlink()
        return False
    except Exception as e:
        logger.error(f"Error downloading model: {e}", exc_info=True)
        # Clean up partial download
        if model_path.exists():
            model_path.unlink()
        return False


if __name__ == "__main__":
    # Configure logging for standalone invocation (e.g. container startup).
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    download_efficientnet_model()
|
| 112 |
+
|
scripts/efficient_caption.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
import logging
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoTokenizer
from efficient_train import Encoder, Decoder, ImageCaptioningModel, generate_caption
import os

# Configuration
MODEL_PATH = 'efficient_best_model.pth'  # Path to your saved model
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MAX_SEQ_LENGTH = 64  # Ensure this matches the value used during training

# Image transformation (ensure it matches the preprocessing used during training)
# ImageNet normalization stats are used because the encoder backbone is
# pretrained on ImageNet.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load the tokenizer and register the caption boundary tokens.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
special_tokens = {'additional_special_tokens': ['<start>', '<end>']}
tokenizer.add_special_tokens(special_tokens)

# Initialize the model components (hyperparameters must match training).
encoder = Encoder(model_name='efficientnet_b3', embed_dim=512)
decoder = Decoder(
    vocab_size=len(tokenizer),
    embed_dim=512,
    num_layers=8,
    num_heads=8,
    max_seq_length=MAX_SEQ_LENGTH
)
model = ImageCaptioningModel(encoder, decoder).to(DEVICE)

# Load the trained model weights
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model file not found at: {MODEL_PATH}. Please ensure you have a trained model checkpoint at this location.")


# Add a check for the size of the file
if os.path.getsize(MODEL_PATH) == 0:
    raise ValueError(f"Model file at {MODEL_PATH} is empty. Please check the saved model.")


checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)

# Check if the checkpoint has the model_state key
if 'model_state' not in checkpoint:
    raise KeyError("The checkpoint file does not contain the key 'model_state'. Please ensure the model was saved correctly using 'torch.save(model.state_dict(), path)'.")


model.load_state_dict(checkpoint['model_state'])
model.eval()


def caption(image_path):
    """Generate a caption string for the image at *image_path*."""
    # Load and preprocess the image
    image = Image.open(image_path).convert('RGB')
    image = transform(image).to(DEVICE)

    # Generate caption
    caption1 = generate_caption(model, image, tokenizer, DEVICE, max_length=MAX_SEQ_LENGTH)
    return caption1

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Generate a caption for the provided image.")
    parser.add_argument('--image_dir', type=str, required=True, help="Path to the input image file")
    args = parser.parse_args()

    try:
        result = caption(args.image_dir)
        print(result)
    except Exception as e:
        logging.error(f"Error generating caption: {str(e)}")
        exit(1)
|
scripts/optimize_models.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model Optimization Script for Production Deployment
|
| 3 |
+
Reduces model size and improves inference speed through:
|
| 4 |
+
1. Quantization (INT8)
|
| 5 |
+
2. TorchScript compilation
|
| 6 |
+
3. Model pruning (optional)
|
| 7 |
+
4. State dict optimization
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import os
|
| 12 |
+
import argparse
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
# Import model classes BEFORE loading checkpoints (needed for unpickling)
|
| 16 |
+
# This ensures PyTorch can find the class definitions when loading saved objects
|
| 17 |
+
# Note: resnet_train.py has module-level code that loads COCO data, which may fail
|
| 18 |
+
# if training files aren't present. We'll handle this in the functions.
|
| 19 |
+
|
| 20 |
+
def quantize_model(checkpoint_path, output_path, model_type='resnet'):
    """
    Quantize a trained captioning model to INT8 for ~4x size reduction
    and faster CPU inference.

    Args:
        checkpoint_path: Path to the full-precision checkpoint (.pth).
        output_path: Where to write the quantized checkpoint.
        model_type: 'resnet' (separate encoder/decoder checkpoint) or
            'efficientnet' (single ImageCaptioningModel checkpoint).

    Returns:
        output_path, after the quantized checkpoint has been written.

    Note: Slight accuracy loss (usually <1%).
    """
    print(f"Quantizing {model_type} model...")

    device = torch.device('cpu')  # Quantization typically done on CPU

    # Import classes before loading (required for unpickling)
    # resnet_train.py now handles missing training data gracefully
    if model_type == 'resnet':
        # Import the module itself so we can update vocab later
        import resnet_train
        from resnet_train import EncoderCNN, DecoderRNN, Vocabulary

        # Make Vocabulary available in __main__ for unpickling.
        # This handles cases where the checkpoint was saved with Vocabulary
        # pickled from the __main__ module.
        import __main__
        if not hasattr(__main__, 'Vocabulary'):
            __main__.Vocabulary = Vocabulary
    elif model_type == 'efficientnet':
        from efficient_train import Encoder, Decoder, ImageCaptioningModel
        from transformers import AutoTokenizer

    # Load checkpoint (now all classes are available for unpickling)
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    if model_type == 'resnet':
        # For ResNet, quantize encoder and decoder separately.

        # IMPORTANT: Update vocab from checkpoint before creating DecoderRNN.
        # The decoder uses len(vocab.word2idx) in its __init__, so we need the full vocab.
        if 'vocab' in checkpoint and checkpoint['vocab'] is not None:
            # Update the vocab in resnet_train module (DecoderRNN.__init__ references resnet_train.vocab)
            resnet_train.vocab = checkpoint['vocab']
            print(f" Updated vocab size: {len(checkpoint['vocab'].word2idx)}")
        else:
            raise ValueError("Checkpoint does not contain 'vocab' key. Cannot proceed.")

        encoder = EncoderCNN()
        decoder = DecoderRNN()  # Now uses the correct vocab size from checkpoint

        encoder.load_state_dict(checkpoint['encoder'])
        decoder.load_state_dict(checkpoint['decoder'])

        # Set to eval mode (dynamic quantization requires inference mode)
        encoder.eval()
        decoder.eval()

        # Quantize encoder (only Linear and Conv2d layers)
        encoder_quantized = torch.quantization.quantize_dynamic(
            encoder, {torch.nn.Linear, torch.nn.Conv2d}, dtype=torch.qint8
        )

        # Quantize decoder (only Linear layers - Embedding requires special config)
        # Embeddings are typically small and don't benefit much from quantization
        decoder_quantized = torch.quantization.quantize_dynamic(
            decoder, {torch.nn.Linear}, dtype=torch.qint8
        )

        # Save quantized model
        quantized_checkpoint = {
            'encoder': encoder_quantized.state_dict(),
            'decoder': decoder_quantized.state_dict(),
            'vocab': checkpoint.get('vocab'),
            'quantized': True
        }

    elif model_type == 'efficientnet':
        # Classes already imported above before loading checkpoint.
        # Rebuild the tokenizer exactly as in training so vocab_size matches.
        tokenizer = AutoTokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
        special_tokens = {'additional_special_tokens': ['<start>', '<end>']}
        tokenizer.add_special_tokens(special_tokens)

        encoder = Encoder(model_name='efficientnet_b3', embed_dim=512)
        decoder = Decoder(
            vocab_size=len(tokenizer),
            embed_dim=512,
            num_layers=8,
            num_heads=8,
            max_seq_length=64
        )
        model = ImageCaptioningModel(encoder, decoder)

        # Load state dict - handle both 'model_state' key and direct state dict
        if 'model_state' in checkpoint:
            model.load_state_dict(checkpoint['model_state'])
        else:
            model.load_state_dict(checkpoint)

        model.eval()

        # Quantize the full model
        model_quantized = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear, torch.nn.Conv2d}, dtype=torch.qint8
        )

        quantized_checkpoint = {
            'model_state': model_quantized.state_dict(),
            'quantized': True
        }

    torch.save(quantized_checkpoint, output_path)

    # Compare sizes
    original_size = os.path.getsize(checkpoint_path) / (1024 * 1024)  # MB
    quantized_size = os.path.getsize(output_path) / (1024 * 1024)  # MB
    reduction = (1 - quantized_size / original_size) * 100

    print(f"✓ Quantization complete!")
    print(f" Original size: {original_size:.2f} MB")
    print(f" Quantized size: {quantized_size:.2f} MB")
    print(f" Size reduction: {reduction:.1f}%")

    return output_path
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def optimize_state_dict(checkpoint_path, output_path):
    """
    Remove training-only metadata from a checkpoint to shrink its size.

    Drops optimizer/scheduler state and bookkeeping fields, keeping only
    the data needed for inference, then re-saves with zipfile serialization.
    Returns the output path.
    """
    print(f"Optimizing state dict...")

    # Import classes before loading (required for unpickling checkpoints
    # that embed a Vocabulary object).
    try:
        from resnet_train import Vocabulary
        # Make Vocabulary available in __main__ for unpickling
        import __main__
        if not hasattr(__main__, 'Vocabulary'):
            __main__.Vocabulary = Vocabulary
    except ImportError:
        pass

    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    # Keep everything except training-only bookkeeping.
    training_only = ('optimizer', 'scheduler', 'epoch', 'loss', 'metrics')
    optimized = {k: v for k, v in checkpoint.items() if k not in training_only}

    # Save with highest compression
    torch.save(optimized, output_path, _use_new_zipfile_serialization=True)

    mb = 1024 * 1024
    original_size = os.path.getsize(checkpoint_path) / mb
    optimized_size = os.path.getsize(output_path) / mb
    reduction = (1 - optimized_size / original_size) * 100

    print(f"✓ State dict optimized!")
    print(f" Original: {original_size:.2f} MB")
    print(f" Optimized: {optimized_size:.2f} MB")
    print(f" Reduction: {reduction:.1f}%")

    return output_path
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def create_torchscript(checkpoint_path, output_path, model_type='resnet'):
    """
    Convert a trained model's encoder to TorchScript for faster loading
    and inference.

    Args:
        checkpoint_path: Path to the full-precision checkpoint (.pth).
        output_path: Base output path; the traced encoder is written to
            this path with '.pth' replaced by '_encoder.pt'.
        model_type: 'resnet' or 'efficientnet'.

    Returns:
        output_path (unchanged base path).

    Note: Requires an example input for tracing; only the encoder is
    traced — the decoder has dynamic/recurrent inputs.
    """
    print(f"Creating TorchScript model...")

    device = torch.device('cpu')

    # Import classes before loading (required for unpickling)
    if model_type == 'resnet':
        import resnet_train
        from resnet_train import EncoderCNN, DecoderRNN, Vocabulary

        # Make Vocabulary available in __main__ for unpickling
        # (handles checkpoints pickled with Vocabulary from __main__).
        import __main__
        if not hasattr(__main__, 'Vocabulary'):
            __main__.Vocabulary = Vocabulary
    elif model_type == 'efficientnet':
        from efficient_train import Encoder, Decoder, ImageCaptioningModel
        from transformers import AutoTokenizer

    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    if model_type == 'resnet':
        # Update vocab from checkpoint before creating DecoderRNN
        # (DecoderRNN.__init__ reads vocab size from resnet_train.vocab).
        if 'vocab' in checkpoint and checkpoint['vocab'] is not None:
            resnet_train.vocab = checkpoint['vocab']
            print(f" Updated vocab size: {len(checkpoint['vocab'].word2idx)}")
        else:
            raise ValueError("Checkpoint does not contain 'vocab' key. Cannot proceed.")

        encoder = EncoderCNN().eval()
        decoder = DecoderRNN().eval()  # Now uses the correct vocab size

        encoder.load_state_dict(checkpoint['encoder'])
        decoder.load_state_dict(checkpoint['decoder'])

        # Trace encoder with a dummy 224x224 RGB image
        dummy_image = torch.randn(1, 3, 224, 224)
        encoder_traced = torch.jit.trace(encoder, dummy_image)

        # For decoder, we need to trace with proper inputs.
        # This is more complex due to RNN structure, so it is skipped here.
        print(" ⚠ TorchScript for RNN decoder may require manual scripting")
        print(" ✓ Encoder traced successfully")

        torch.jit.save(encoder_traced, output_path.replace('.pth', '_encoder.pt'))

    elif model_type == 'efficientnet':
        # Classes already imported above.
        # Rebuild the tokenizer exactly as in training so vocab_size matches.
        tokenizer = AutoTokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
        special_tokens = {'additional_special_tokens': ['<start>', '<end>']}
        tokenizer.add_special_tokens(special_tokens)

        encoder = Encoder(model_name='efficientnet_b3', embed_dim=512)
        decoder = Decoder(
            vocab_size=len(tokenizer),
            embed_dim=512,
            num_layers=8,
            num_heads=8,
            max_seq_length=64
        )
        model = ImageCaptioningModel(encoder, decoder).eval()

        model.load_state_dict(checkpoint['model_state'])

        # Trace encoder only (decoder has dynamic inputs)
        dummy_image = torch.randn(1, 3, 224, 224)
        encoder_traced = torch.jit.trace(model.encoder, dummy_image)

        torch.jit.save(encoder_traced, output_path.replace('.pth', '_encoder.pt'))
        print(" ✓ Encoder traced successfully")

    print(f"✓ TorchScript saved to {output_path}")
    return output_path
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def main():
    """CLI entry point: optimize trained checkpoints for deployment.

    Parses the command line, keeps only the checkpoints that actually
    exist on disk, then applies the requested optimization passes
    (quantize / optimize / torchscript) to each selected model.
    """
    parser = argparse.ArgumentParser(description='Optimize models for production deployment')
    parser.add_argument('--model', type=str, choices=['resnet', 'efficientnet', 'both'],
                        default='both', help='Model to optimize')
    parser.add_argument('--method', type=str, choices=['quantize', 'optimize', 'torchscript', 'all'],
                        default='all', help='Optimization method')
    parser.add_argument('--resnet-path', type=str, default='resnet_best_model.pth',
                        help='Path to ResNet checkpoint')
    parser.add_argument('--efficientnet-path', type=str, default='efficient_best_model.pth',
                        help='Path to EfficientNet checkpoint')
    parser.add_argument('--output-dir', type=str, default='optimized_models',
                        help='Output directory for optimized models')

    args = parser.parse_args()

    # Make sure the destination directory exists before writing anything.
    os.makedirs(args.output_dir, exist_ok=True)

    # Collect the requested checkpoints, skipping any that are missing.
    selected = []
    if args.model in ['resnet', 'both']:
        if os.path.exists(args.resnet_path):
            selected.append(('resnet', args.resnet_path))
        else:
            print(f"⚠ Warning: {args.resnet_path} not found, skipping ResNet")
    if args.model in ['efficientnet', 'both']:
        if os.path.exists(args.efficientnet_path):
            selected.append(('efficientnet', args.efficientnet_path))
        else:
            print(f"⚠ Warning: {args.efficientnet_path} not found, skipping EfficientNet")

    if not selected:
        print("❌ No models found to optimize!")
        return

    for kind, ckpt_path in selected:
        print(f"\n{'='*60}")
        print(f"Processing {kind.upper()} model")
        print(f"{'='*60}")

        # Derive output file names from the checkpoint's base name.
        prefix = os.path.join(args.output_dir, f"{kind}_{Path(ckpt_path).stem}")

        if args.method in ['quantize', 'all']:
            quantize_model(ckpt_path, f"{prefix}_quantized.pth", kind)
        if args.method in ['optimize', 'all']:
            optimize_state_dict(ckpt_path, f"{prefix}_optimized.pth")
        if args.method in ['torchscript', 'all']:
            create_torchscript(ckpt_path, f"{prefix}_torchscript.pt", kind)

    print(f"\n{'='*60}")
    print("✓ Optimization complete!")
    print(f"Optimized models saved to: {args.output_dir}")
    print(f"{'='*60}")
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
if __name__ == '__main__':
|
| 322 |
+
main()
|
| 323 |
+
|
scripts/resnet_caption.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
import argparse
|
| 3 |
+
import torch
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import nltk
|
| 6 |
+
nltk.download('punkt', quiet=True)
|
| 7 |
+
|
| 8 |
+
# Import the necessary components from resnet_train.py
|
| 9 |
+
from resnet_train import EncoderCNN, DecoderRNN, visualize_attention, CONFIG, Vocabulary
|
| 10 |
+
import resnet_train # To update its global vocab variable
|
| 11 |
+
|
| 12 |
+
def main():
    """Generate a caption for one image from a trained ResNet checkpoint.

    Loads the checkpoint, restores the vocabulary *before* building the
    decoder, then runs attention-based caption generation and prints the
    resulting caption to stdout.
    """
    parser = argparse.ArgumentParser(description="Generate image caption from a trained model.")
    parser.add_argument("--image", type=str, required=True, help="Path to the input image")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to the trained model checkpoint")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load checkpoint.
    # NOTE(review): weights_only=False unpickles arbitrary objects (needed to
    # restore the stored Vocabulary) — only load trusted checkpoint files.
    checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)

    # Fix: restore the vocabulary BEFORE constructing the decoder. DecoderRNN
    # sizes its layers from resnet_train's module-level vocab, so building it
    # first could produce a decoder whose output layer does not match the
    # checkpoint, making load_state_dict fail on a size mismatch.
    resnet_train.vocab = checkpoint['vocab']

    # Initialize models (decoder now uses the checkpoint's vocab size)
    encoder = EncoderCNN().to(device)
    decoder = DecoderRNN().to(device)

    # Load state dictionaries
    encoder.load_state_dict(checkpoint['encoder'])
    decoder.load_state_dict(checkpoint['decoder'])

    # Generate caption using the provided image path
    caption = visualize_attention(args.image, encoder, decoder, device)
    print(caption)
|
| 37 |
+
|
| 38 |
+
if __name__ == "__main__":
|
| 39 |
+
main()
|
static/css/custom.css
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Card container: rounded corners with a soft drop shadow. */
.card {
    border-radius: 1rem;
    box-shadow: 0 0.5rem 1rem rgba(0, 0, 0, 0.15);
}

/* Match the card's rounding on the header and keep it on the dark theme color. */
.card-header {
    border-top-left-radius: 1rem !important;
    border-top-right-radius: 1rem !important;
    background-color: var(--bs-dark);
}

/* Uploaded-image preview: cap the height and preserve aspect ratio. */
#previewImage {
    max-height: 400px;
    width: auto;
    object-fit: contain;
}

.form-check {
    margin-bottom: 0.5rem;
}

.alert {
    margin-bottom: 0;
}

.btn-primary {
    padding: 0.5rem 1.5rem;
}

/* Custom upload button styling */
.upload-container {
    position: relative;
    width: 120px;
    height: 42px;
    margin: 0 auto;
}

/* Hide the native file input; the styled .upload-app label drives it. */
.upload-container input[type="file"] {
    display: none;
}

.upload-app {
    display: block;
    position: relative;
    width: 120px;
    height: 42px;
    transition: 0.3s ease width;
    cursor: pointer;
}

/* The visible button surface, stretched to fill the container. */
.upload-btn {
    position: absolute;
    top: 0;
    right: 0;
    bottom: 0;
    left: 0;
    background-color: var(--bs-dark);
    border: 2px solid var(--bs-border-color);
    border-radius: 0.375rem;
    overflow: hidden;
}

/* "Upload" caption centered inside the button; fades once a file is chosen. */
.upload-btn:before {
    content: "Upload";
    position: absolute;
    top: 50%;
    left: 45%;
    transform: translate(-50%, -50%);
    color: var(--bs-body-color);
    font-size: 14px;
    font-weight: bold;
    transition: opacity 0.3s ease;
}

/* JS toggles .file-selected on .upload-app after a file is picked. */
.file-selected .upload-btn:before {
    opacity: 0;
}

/* Down-arrow glyph drawn from two rotated bars; hidden after selection. */
.upload-arrow {
    position: absolute;
    top: 0;
    right: 0;
    width: 38px;
    height: 38px;
    background-color: var(--bs-dark);
    transition: opacity 0.3s ease;
}

.file-selected .upload-arrow {
    opacity: 0;
}

.upload-arrow:before,
.upload-arrow:after {
    content: "";
    position: absolute;
    top: 18px;
    width: 10px;
    height: 2px;
    background-color: var(--bs-body-color);
}

.upload-arrow:before {
    right: 17px;
    transform: rotateZ(-45deg);
}

.upload-arrow:after {
    right: 11px;
    transform: rotateZ(45deg);
}

/* Green check badge that pops in (scale 0 -> 1) when a file is selected. */
.upload-success {
    position: absolute;
    top: 50%;
    left: 50%;
    width: 24px;
    height: 24px;
    margin: 0;
    background-color: var(--bs-success);
    transform: translate(-50%, -50%) scale(0);
    border-radius: 50%;
    opacity: 0;
    transition: transform 0.3s ease, opacity 0.3s ease;
}

/* The check icon itself animates in slightly after the badge (0.1s delay). */
.upload-success i {
    font-size: 16px;
    color: #fff;
    position: absolute;
    top: 50%;
    left: 50%;
    transform: translate(-50%, -50%) scale(0);
    transition: transform 0.3s ease 0.1s;
}

.file-selected .upload-success {
    transform: translate(-50%, -50%) scale(1);
    opacity: 1;
}

.file-selected .upload-success i {
    transform: translate(-50%, -50%) scale(1);
}
|
static/js/main.js
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Front-end glue for the caption generator page:
// - previews the chosen image locally via FileReader
// - POSTs the image + model choice to /upload and renders the caption
document.addEventListener('DOMContentLoaded', function() {
    const form = document.getElementById('uploadForm');
    const imageInput = document.getElementById('imageInput');
    const submitBtn = document.getElementById('submitBtn');
    const spinner = submitBtn.querySelector('.spinner-border');
    const resultSection = document.getElementById('resultSection');
    const previewImage = document.getElementById('previewImage');
    const captionText = document.getElementById('captionText');
    const errorAlert = document.getElementById('errorAlert');
    const uploadApp = document.querySelector('.upload-app');

    // Preview image when selected
    imageInput.addEventListener('change', function(e) {
        const file = e.target.files[0];
        if (file) {
            const reader = new FileReader();
            reader.onload = function(e) {
                // Show the local preview and reset any previous result/error.
                previewImage.src = e.target.result;
                resultSection.classList.remove('d-none');
                captionText.textContent = '';
                errorAlert.classList.add('d-none');

                // Add success animation class
                uploadApp.classList.add('file-selected');
            };
            reader.readAsDataURL(file);
        }
    });

    form.addEventListener('submit', async function(e) {
        e.preventDefault();

        const formData = new FormData();
        const file = imageInput.files[0];

        if (!file) {
            showError('Please select an image first.');
            return;
        }

        // Add file and selected model to form data
        formData.append('image', file);
        formData.append('model', document.querySelector('input[name="model"]:checked').value);

        // Show loading state
        setLoading(true);

        try {
            const response = await fetch('/upload', {
                method: 'POST',
                body: formData
            });

            // Server replies with JSON: { caption } on success, { error } on failure.
            const data = await response.json();

            if (!response.ok) {
                throw new Error(data.error || 'Failed to generate caption');
            }

            // Display the caption
            captionText.textContent = data.caption;
            resultSection.classList.remove('d-none');
            errorAlert.classList.add('d-none');

        } catch (error) {
            showError(error.message || 'An error occurred while generating the caption');
        } finally {
            setLoading(false);
        }
    });

    // Toggle the submit button between idle and busy states.
    // Setting textContent removes the spinner node, so it is re-prepended
    // when entering the loading state.
    function setLoading(isLoading) {
        submitBtn.disabled = isLoading;
        spinner.classList.toggle('d-none', !isLoading);
        submitBtn.textContent = isLoading ? ' Processing...' : 'Generate Caption';
        if (isLoading) {
            submitBtn.prepend(spinner);
        }
    }

    // Show an error message in the dedicated alert box.
    function showError(message) {
        errorAlert.textContent = message;
        errorAlert.classList.remove('d-none');
    }
});
|
templates/index.html
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
<html lang="en" data-bs-theme="dark">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Image Caption Generator</title>
    <!-- Bootstrap dark theme + icon font, then project-specific overrides -->
    <link href="https://cdn.replit.com/agent/bootstrap-agent-dark-theme.min.css" rel="stylesheet">
    <link href="https://cdn.jsdelivr.net/npm/bootstrap-icons@1.7.2/font/bootstrap-icons.css" rel="stylesheet">
    <link href="{{ url_for('static', filename='css/custom.css') }}" rel="stylesheet">
</head>
<body>
    <div class="container py-5">
        <div class="row justify-content-center">
            <div class="col-md-8">
                <div class="card">
                    <div class="card-header">
                        <h2 class="text-center mb-0">Image Caption Generator</h2>
                    </div>
                    <div class="card-body">
                        <!-- Submission is intercepted by static/js/main.js (AJAX to /upload) -->
                        <form id="uploadForm">
                            <div class="mb-4">
                                <label class="form-label">Select Model:</label>
                                <div class="form-check">
                                    <input class="form-check-input" type="radio" name="model" id="efficientnet" value="efficientnet" checked>
                                    <label class="form-check-label" for="efficientnet">
                                        EfficientNet-B3
                                    </label>
                                </div>
                            </div>

                            <div class="mb-4">
                                <label class="form-label d-block text-center">Upload Image:</label>
                                <!-- Styled file picker: the input is hidden by CSS and
                                     triggered by clicking the surrounding label -->
                                <div class="upload-container">
                                    <label class="upload-app">
                                        <input type="file" id="imageInput" accept="image/png,image/jpeg,image/jpg" required>
                                        <div class="upload-btn">
                                            <div class="upload-arrow"></div>
                                            <div class="upload-success">
                                                <i class="bi bi-check"></i>
                                            </div>
                                        </div>
                                    </label>
                                </div>
                            </div>

                            <div class="text-center">
                                <button type="submit" class="btn btn-primary" id="submitBtn">
                                    <span class="spinner-border spinner-border-sm d-none" role="status" aria-hidden="true"></span>
                                    Generate Caption
                                </button>
                            </div>
                        </form>

                        <!-- Preview + caption area; JS unhides it once an image is selected -->
                        <div id="resultSection" class="mt-4 d-none">
                            <div class="text-center">
                                <img id="previewImage" class="img-fluid mb-3 rounded" alt="Uploaded image">
                                <div id="captionText" class="alert alert-info"></div>
                            </div>
                        </div>

                        <!-- Error messages surfaced by main.js -->
                        <div id="errorAlert" class="alert alert-danger mt-3 d-none"></div>
                    </div>
                </div>
            </div>
        </div>
    </div>

    <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/js/bootstrap.bundle.min.js"></script>
    <script src="{{ url_for('static', filename='js/main.js') }}"></script>
</body>
</html>
|
training/__pycache__/efficient_train.cpython-314.pyc
ADDED
|
Binary file (25.9 kB). View file
|
|
|
training/__pycache__/resnet_train.cpython-314.pyc
ADDED
|
Binary file (33.4 kB). View file
|
|
|
training/efficient_train.py
ADDED
|
@@ -0,0 +1,499 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.optim as optim
|
| 8 |
+
from torch.utils.data import Dataset, DataLoader, random_split
|
| 9 |
+
from torchvision import transforms
|
| 10 |
+
from timm import create_model
|
| 11 |
+
from transformers import AutoTokenizer
|
| 12 |
+
from pycocotools.coco import COCO
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
from PIL import Image
|
| 15 |
+
|
| 16 |
+
# Distributed training imports
|
| 17 |
+
import torch.distributed as dist
|
| 18 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 19 |
+
|
| 20 |
+
# ------------------- DDP Setup Functions ------------------- #
|
| 21 |
+
def setup_distributed():
    """Join the default process group for multi-GPU training (NCCL backend).

    Expects the standard torchrun/launch environment variables (RANK,
    WORLD_SIZE, MASTER_ADDR, ...) to be set by the launcher.
    """
    dist.init_process_group(backend='nccl')
|
| 23 |
+
|
| 24 |
+
def cleanup_distributed():
    """Tear down the distributed process group created by setup_distributed()."""
    dist.destroy_process_group()
|
| 26 |
+
|
| 27 |
+
# ------------------- Configuration and Constants ------------------- #
|
| 28 |
+
DEFAULT_MAX_SEQ_LENGTH = 64
|
| 29 |
+
DEFAULT_EMBED_DIM = 512
|
| 30 |
+
DEFAULT_NUM_LAYERS = 8
|
| 31 |
+
DEFAULT_NUM_HEADS = 8
|
| 32 |
+
|
| 33 |
+
# ------------------- Data Preparation ------------------- #
|
| 34 |
+
class CocoCaptionDataset(Dataset):
    """Custom COCO dataset that returns image-caption pairs with processing.

    Each item is ``(image_tensor, token_ids)`` where the caption is wrapped
    in ``<start>``/``<end>`` markers and tokenized with a GPT-2 tokenizer
    extended with those two special tokens.
    """

    def __init__(self, root, ann_file, transform=None, max_seq_length=DEFAULT_MAX_SEQ_LENGTH):
        # root: directory containing the split's image files
        # ann_file: COCO captions annotation JSON for this split
        self.coco = COCO(ann_file)
        self.root = root
        self.transform = transform
        self.max_seq_length = max_seq_length
        self.ids = list(self.coco.imgs.keys())

        # Initialize tokenizer with special tokens.
        # NOTE(review): each dataset instance downloads/loads its own GPT-2
        # tokenizer; they are configured identically, so vocab sizes match.
        self.tokenizer = AutoTokenizer.from_pretrained('gpt2')
        self.tokenizer.pad_token = self.tokenizer.eos_token
        special_tokens = {'additional_special_tokens': ['<start>', '<end>']}
        self.tokenizer.add_special_tokens(special_tokens)
        self.vocab_size = len(self.tokenizer)

    def __len__(self):
        # One item per image (captions are sampled per fetch, not enumerated).
        return len(self.ids)

    def __getitem__(self, idx):
        """Return (transformed_image, token_ids) for one image.

        A caption is chosen uniformly at random from the image's annotations
        on every access, so successive epochs see different captions for the
        same image (a cheap form of data augmentation).
        """
        img_id = self.ids[idx]
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.root, img_info['file_name'])
        img = Image.open(img_path).convert('RGB')

        # Get random caption from available annotations
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        anns = self.coco.loadAnns(ann_ids)
        caption = random.choice(anns)['caption']

        # Apply transforms
        if self.transform:
            img = self.transform(img)

        # Tokenize caption with special tokens; pad/truncate to a fixed
        # length so captions batch into a rectangular tensor.
        caption = f"<start> {caption} <end>"
        inputs = self.tokenizer(
            caption,
            padding='max_length',
            max_length=self.max_seq_length,
            truncation=True,
            return_tensors='pt',
        )
        return img, inputs.input_ids.squeeze(0)
|
| 78 |
+
|
| 79 |
+
class CocoTestDataset(Dataset):
    """COCO test-split dataset yielding (image, filename) pairs.

    The test split ships without annotations, so only the images are
    loaded; the filename accompanies each image for reference.
    """

    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        # Every entry in the directory is assumed to be an image file;
        # sorting keeps iteration order deterministic across runs.
        self.img_files = sorted(os.listdir(root))

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        filename = self.img_files[idx]
        image = Image.open(os.path.join(self.root, filename)).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, filename  # filename returned for reference
|
| 97 |
+
|
| 98 |
+
# ------------------- Model Architecture ------------------- #
|
| 99 |
+
class Encoder(nn.Module):
    """CNN encoder built on a timm backbone.

    Flattens the backbone's spatial feature map into a token sequence and
    projects it to the decoder's embedding size: (batch, H*W, embed_dim).
    """

    def __init__(self, model_name='efficientnet_b3', embed_dim=DEFAULT_EMBED_DIM):
        super().__init__()
        self.backbone = create_model(
            model_name,
            pretrained=True,
            num_classes=0,
            global_pool='',
            features_only=False
        )

        # Probe the backbone once with a dummy input to discover its
        # output channel count (varies per architecture).
        with torch.no_grad():
            probe = self.backbone(torch.randn(1, 3, 224, 224))
        channel_count = probe.shape[1]

        self.projection = nn.Linear(channel_count, embed_dim)

    def forward(self, x):
        feats = self.backbone(x)  # (batch, channels, height, width)
        b, c, h, w = feats.shape
        # Move channels last, then flatten the spatial grid into tokens.
        tokens = feats.permute(0, 2, 3, 1).reshape(b, -1, c)
        return self.projection(tokens)
|
| 124 |
+
|
| 125 |
+
class Decoder(nn.Module):
    """Transformer caption decoder with learned positional embeddings.

    An upper-triangular causal mask is registered once as a buffer and
    sliced per forward pass so a position cannot attend to later tokens.
    """

    def __init__(self, vocab_size, embed_dim, num_layers, num_heads, max_seq_length, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_encoding = nn.Embedding(max_seq_length, embed_dim)
        self.dropout = nn.Dropout(dropout)

        layer = nn.TransformerDecoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dropout=dropout,
            batch_first=False
        )
        self.layers = nn.TransformerDecoder(layer, num_layers)
        self.fc = nn.Linear(embed_dim, vocab_size)
        self.max_seq_length = max_seq_length

        # Precompute the full causal mask once; -inf above the diagonal.
        self.register_buffer(
            "causal_mask",
            torch.triu(torch.full((max_seq_length, max_seq_length), float('-inf')), diagonal=1)
        )

    def forward(self, x, memory, tgt_mask=None):
        length = x.size(1)
        pos = torch.arange(0, length, device=x.device).unsqueeze(0)
        tokens = self.dropout(self.embedding(x) + self.positional_encoding(pos))

        # These transformer modules expect (seq, batch, feature) layout.
        decoded = self.layers(
            tokens.permute(1, 0, 2),
            memory.permute(1, 0, 2),
            tgt_mask=self.causal_mask[:length, :length]
        )
        # Back to (batch, seq, vocab) logits.
        return self.fc(decoded.permute(1, 0, 2))
|
| 167 |
+
|
| 168 |
+
class ImageCaptioningModel(nn.Module):
    """Complete image captioning model: CNN encoder + transformer decoder.

    The encoder maps images to a memory sequence; the decoder attends over
    that memory while predicting caption-token logits.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, images, captions, tgt_mask=None):
        """Return token logits for `captions` conditioned on `images`.

        Fix: `tgt_mask` was previously accepted but silently dropped; it is
        now forwarded to the decoder (which may still fall back to its own
        causal mask when None).
        """
        memory = self.encoder(images)
        return self.decoder(captions, memory, tgt_mask=tgt_mask)
|
| 178 |
+
|
| 179 |
+
# ------------------- Inference Utility ------------------- #
|
| 180 |
+
def generate_caption(model, image, tokenizer, device, max_length=DEFAULT_MAX_SEQ_LENGTH):
    """
    Generate a caption for a single image using greedy decoding.
    Assumes the tokenizer has '<start>' and '<end>' as special tokens.

    The image is encoded once; tokens are then appended one at a time by
    re-running the decoder on the growing prefix and taking the argmax of
    the final position's logits (greedy, no beam search).

    NOTE(review): the image is assumed to already be on `device` — the
    function never calls .to(device) on it; confirm against callers.
    """
    model.eval()
    with torch.no_grad():
        image = image.unsqueeze(0)  # shape: (1, 3, H, W)
        # Unwrap DistributedDataParallel to reach the raw submodules.
        if isinstance(model, DDP):
            memory = model.module.encoder(image)
        else:
            memory = model.encoder(image)
        start_token = tokenizer.convert_tokens_to_ids("<start>")
        end_token = tokenizer.convert_tokens_to_ids("<end>")
        caption_ids = [start_token]
        # Greedy loop: re-decodes the whole prefix each step (O(L^2) overall),
        # stopping early when <end> is produced.
        for _ in range(max_length - 1):
            decoder_input = torch.tensor(caption_ids, device=device).unsqueeze(0)
            if isinstance(model, DDP):
                output = model.module.decoder(decoder_input, memory)
            else:
                output = model.decoder(decoder_input, memory)
            next_token_logits = output[0, -1, :]
            next_token = next_token_logits.argmax().item()
            caption_ids.append(next_token)
            if next_token == end_token:
                break
        # skip_special_tokens drops <start>/<end> from the returned text.
        caption_text = tokenizer.decode(caption_ids, skip_special_tokens=True)
        return caption_text
|
| 208 |
+
|
| 209 |
+
# ------------------- Training Utilities ------------------- #
|
| 210 |
+
def create_dataloaders(args):
    """Create train/val/test dataloaders with appropriate transforms.

    Args (attributes read from `args`):
        train_image_dir / train_ann_file: COCO train split paths.
        val_image_dir / val_ann_file: COCO val split paths.
        test_image_dir: directory of annotation-free test images.
        batch_size: per-loader batch size (test loader is fixed at 1).
        distributed: when True, shard the train set with DistributedSampler.

    Returns:
        (train_loader, val_loader, test_loader, tokenizer, train_set) — the
        tokenizer is the one built by the training dataset.
    """
    # Train-time augmentation: random crop + horizontal flip.
    train_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Deterministic eval-time transform (center crop, same normalization).
    eval_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Load datasets
    train_set = CocoCaptionDataset(
        root=args.train_image_dir,
        ann_file=args.train_ann_file,
        transform=train_transform
    )

    val_set = CocoCaptionDataset(
        root=args.val_image_dir,
        ann_file=args.val_ann_file,
        transform=eval_transform
    )

    test_set = CocoTestDataset(
        root=args.test_image_dir,
        transform=eval_transform
    )

    # For distributed training, use DistributedSampler
    # (only the train set is sharded; val/test run unsharded on each rank).
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_set)
    else:
        train_sampler = None

    # Optimize for GPU: use pin_memory and more workers if CUDA is available
    pin_memory = torch.cuda.is_available()
    num_workers = 8 if torch.cuda.is_available() else 4  # More workers for GPU
    persistent_workers = torch.cuda.is_available()  # Keep workers alive between epochs

    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),  # sampler and shuffle are mutually exclusive
        sampler=train_sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=persistent_workers,
        prefetch_factor=2 if num_workers > 0 else None  # Prefetch batches
    )
    val_loader = DataLoader(
        val_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=persistent_workers
    )
    test_loader = DataLoader(
        test_set,
        batch_size=1,  # For inference, process one image at a time
        shuffle=False,
        num_workers=num_workers
    )

    return train_loader, val_loader, test_loader, train_set.tokenizer, train_set
|
| 282 |
+
|
| 283 |
+
def train_epoch(model, loader, optimizer, criterion, scaler, scheduler, device, args):
    """Run one training epoch with teacher forcing, AMP, and gradient accumulation.

    Args:
        model: captioning model mapping (images, decoder_input) -> logits.
        loader: DataLoader yielding (images, captions) batches.
        optimizer: optimizer, stepped once per ``args.grad_accum`` batches.
        criterion: token-level loss (e.g. CrossEntropyLoss over the vocab).
        scaler: GradScaler used for mixed-precision training.
        scheduler: optional per-step LR scheduler; stepped on optimizer steps.
        device: target device for the batch tensors.
        args: namespace with ``distributed``, ``epoch``, ``use_amp``, ``grad_accum``.

    Returns:
        float: mean (unscaled) loss over all batches of the epoch.
    """
    model.train()
    total_loss = 0.0
    if args.distributed:
        loader.sampler.set_epoch(args.epoch)

    # BUGFIX: the original called optimizer.zero_grad() at the top of every
    # iteration, which wiped the gradients accumulated on non-step batches and
    # made grad_accum > 1 a no-op. Zero once before the loop, then only after
    # each optimizer step.
    optimizer.zero_grad()
    num_batches = len(loader)
    for batch_idx, (images, captions) in enumerate(loader):
        images = images.to(device)
        captions = captions.to(device)

        # Teacher forcing: decoder input is caption[:-1], target is caption[1:].
        decoder_input = captions[:, :-1]
        targets = captions[:, 1:].contiguous()

        # Prefer the torch.amp API (PyTorch 2.6+); fall back to torch.cuda.amp.
        if hasattr(torch.amp, 'autocast'):
            autocast_context = torch.amp.autocast('cuda', enabled=args.use_amp)
        else:
            autocast_context = torch.cuda.amp.autocast(enabled=args.use_amp)

        with autocast_context:
            logits = model(images, decoder_input)
            loss = criterion(
                logits.view(-1, logits.size(-1)),
                targets.view(-1)
            )

        # BUGFIX: scale the backprop'd loss by the accumulation factor so the
        # effective gradient matches a single large batch (previously the full
        # loss was backprop'd every micro-batch, inflating gradients).
        scaler.scale(loss / args.grad_accum).backward()

        # Step on accumulation boundaries, and also flush the final (possibly
        # partial) window so trailing gradients are not silently dropped.
        if (batch_idx + 1) % args.grad_accum == 0 or (batch_idx + 1) == num_batches:
            scaler.step(optimizer)
            scaler.update()
            # Only step the scheduler if it's provided and is per-step.
            if scheduler is not None:
                scheduler.step()  # update learning rate
            optimizer.zero_grad()

        total_loss += loss.item()

    return total_loss / len(loader)
|
| 323 |
+
|
| 324 |
+
def validate(model, loader, criterion, device):
    """Return the mean teacher-forced loss of *model* over *loader*.

    Runs in eval mode with gradients disabled; uses the same shifted-caption
    scheme as training (input = caption[:-1], target = caption[1:]).
    """
    model.eval()
    running = 0.0
    with torch.no_grad():
        for images, captions in loader:
            images, captions = images.to(device), captions.to(device)
            inputs = captions[:, :-1]
            targets = captions[:, 1:].contiguous()

            logits = model(images, inputs)
            vocab_size = logits.size(-1)
            batch_loss = criterion(
                logits.view(-1, vocab_size),
                targets.view(-1)
            )
            running += batch_loss.item()

    return running / len(loader)
|
| 342 |
+
|
| 343 |
+
def main(args):
    """End-to-end training driver: data, model, optimization, checkpointing,
    early stopping, and a short qualitative inference pass on the test set.

    Args:
        args: parsed CLI namespace (see the ``__main__`` argument parser).
    """
    if args.distributed:
        setup_distributed()

    device = torch.device("cuda", args.local_rank) if args.distributed else torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Reproducibility across torch / python / numpy RNGs.
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # Create dataloaders and obtain tokenizer and training dataset (for sampler)
    train_loader, val_loader, test_loader, tokenizer, train_set = create_dataloaders(args)

    # Initialize model (+2 vocab slots for the added special tokens).
    encoder = Encoder(args.model_name, args.embed_dim)
    decoder = Decoder(
        vocab_size=tokenizer.vocab_size + 2,
        embed_dim=args.embed_dim,
        num_layers=args.num_layers,
        num_heads=args.num_heads,
        max_seq_length=DEFAULT_MAX_SEQ_LENGTH,
        dropout=0.1
    )
    model = ImageCaptioningModel(encoder, decoder).to(device)

    if args.distributed:
        model = DDP(model, device_ids=[args.local_rank])

    # Set up training components
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
    # Prefer the torch.amp API (PyTorch 2.6+); fall back to torch.cuda.amp.
    if hasattr(torch.amp, 'GradScaler'):
        scaler = torch.amp.GradScaler('cuda', enabled=args.use_amp)
    else:
        scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=args.epochs * len(train_loader),
        eta_min=1e-6
    )
    best_val_loss = float('inf')
    patience_counter = 0

    # Support resume training
    start_epoch = 0
    if args.resume_checkpoint:
        # Handle PyTorch 2.6+ security: allow the pickled tokenizer class.
        try:
            from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
            torch.serialization.add_safe_globals([GPT2TokenizerFast])
        except ImportError:
            pass

        # weights_only=False is required to unpickle the stored tokenizer.
        checkpoint = torch.load(args.resume_checkpoint, map_location=device, weights_only=False)
        if args.distributed:
            model.module.load_state_dict(checkpoint['model_state'])
        else:
            model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        # BUGFIX: checkpoints save 'scheduler_state' but the original never
        # restored it, resetting the cosine LR schedule on every resume.
        if 'scheduler_state' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state'])
        start_epoch = checkpoint['epoch'] + 1
        best_val_loss = checkpoint.get('val_loss', best_val_loss)
        print(f"Resumed training from epoch {start_epoch}")

    # Training loop
    for epoch in range(start_epoch, args.epochs):
        args.epoch = epoch  # Useful for the sampler in distributed training
        if args.distributed:
            train_loader.sampler.set_epoch(epoch)
        if args.local_rank == 0 or not args.distributed:
            print(f"Epoch {epoch+1}/{args.epochs}")
        train_loss = train_epoch(
            model, train_loader, optimizer, criterion, scaler, scheduler, device, args
        )
        val_loss = validate(model, val_loader, criterion, device)
        if args.local_rank == 0 or not args.distributed:
            print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        # Checkpoint on improvement; otherwise count toward early stopping.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state': model.module.state_dict() if args.distributed else model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'scheduler_state': scheduler.state_dict(),
                'val_loss': val_loss,
                'tokenizer': tokenizer,
            }, os.path.join(args.checkpoint_dir, 'best_model.pth'))
        else:
            patience_counter += 1

        if patience_counter >= args.early_stopping_patience:
            print("Early stopping triggered")
            break

    # Qualitative inference on a handful of test images (rank 0 only).
    if args.local_rank == 0 or not args.distributed:
        print("\nGenerating captions on test set images:")
        model.eval()
        for idx, (image, filename) in enumerate(test_loader):
            image = image.to(device).squeeze(0)
            caption = generate_caption(model, image, tokenizer, device)
            # BUGFIX: report the actual file name (test batch_size=1, so the
            # collated filename is a 1-element list/tuple of strings).
            name = filename[0] if isinstance(filename, (list, tuple)) else filename
            print(f"{name}: {caption}")
            if idx >= 4:
                break

    if args.distributed:
        cleanup_distributed()
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # --- Dataset paths ---
    parser.add_argument('--train_image_dir', type=str, required=True)
    parser.add_argument('--train_ann_file', type=str, required=True)
    parser.add_argument('--val_image_dir', type=str, required=True)
    parser.add_argument('--val_ann_file', type=str, required=True)
    # The test split ships images only (no annotation file).
    parser.add_argument('--test_image_dir', type=str, required=True)

    # --- Model architecture ---
    parser.add_argument('--model_name', type=str, default='efficientnet_b3')
    parser.add_argument('--embed_dim', type=int, default=DEFAULT_EMBED_DIM)
    parser.add_argument('--num_layers', type=int, default=DEFAULT_NUM_LAYERS)
    parser.add_argument('--num_heads', type=int, default=DEFAULT_NUM_HEADS)

    # --- Optimization / run control ---
    parser.add_argument('--batch_size', type=int, default=96)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--use_amp', action='store_true')
    parser.add_argument('--grad_accum', type=int, default=1)
    parser.add_argument('--checkpoint_dir', type=str, default='/workspace')
    parser.add_argument('--early_stopping_patience', type=int, default=3)

    # --- Distributed training (torchrun may pass either flag spelling) ---
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0,
                        help="Local rank. Necessary for using distributed training.")
    parser.add_argument('--distributed', action='store_true', help="Use distributed training")

    # --- Resume support ---
    parser.add_argument('--resume_checkpoint', type=str, default=None, help="Path to checkpoint to resume training from.")

    args = parser.parse_args()

    # torchrun exports LOCAL_RANK; it takes precedence over the CLI flag.
    args.local_rank = int(os.environ.get("LOCAL_RANK", args.local_rank))

    # Ensure the checkpoint directory exists before training starts.
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    main(args)
|
training/hyperparameter_tuning.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hyperparameter Optimization using Optuna
|
| 3 |
+
Run this to find the best hyperparameters for your model
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import optuna
|
| 7 |
+
import torch
|
| 8 |
+
import argparse
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
from efficient_train import create_dataloaders, Encoder, Decoder, ImageCaptioningModel
|
| 12 |
+
from efficient_train import train_epoch, validate, generate_caption
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.optim as optim
|
| 15 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
|
| 16 |
+
|
| 17 |
+
def train_with_config(trial, args):
    """Train the captioning model with Optuna-suggested hyperparameters.

    Args:
        trial: ``optuna.Trial`` used to suggest values and report progress.
        args: ``argparse.Namespace`` with dataset paths and fixed settings;
            the hyperparameter fields are overwritten in place.

    Returns:
        float: best validation loss observed during the short search run.

    Raises:
        optuna.exceptions.TrialPruned: when the pruner kills the trial.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Suggest hyperparameters.
    # NOTE: suggest_float(..., log=True) replaces the deprecated
    # suggest_loguniform / suggest_uniform APIs (removed in Optuna 3.x).
    lr = trial.suggest_float('lr', 1e-5, 1e-3, log=True)
    batch_size = trial.suggest_categorical('batch_size', [32, 64, 96, 128])
    embed_dim = trial.suggest_categorical('embed_dim', [256, 512, 768])
    num_layers = trial.suggest_int('num_layers', 4, 12)
    num_heads = trial.suggest_categorical('num_heads', [4, 8, 12, 16])
    dropout = trial.suggest_float('dropout', 0.1, 0.5)
    weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-2, log=True)
    # NOTE(review): warmup_epochs is suggested but not consumed below — kept
    # so existing studies stay compatible; wire it into a warmup scheduler.
    warmup_epochs = trial.suggest_int('warmup_epochs', 0, 3)

    # Update args with suggested values
    args.lr = lr
    args.batch_size = batch_size
    args.embed_dim = embed_dim
    args.num_layers = num_layers
    args.num_heads = num_heads
    args.epochs = 5  # Fewer epochs for hyperparameter search

    # Create dataloaders
    train_loader, val_loader, test_loader, tokenizer, train_set = create_dataloaders(args)

    # Initialize model
    encoder = Encoder(args.model_name, embed_dim)
    decoder = Decoder(
        vocab_size=tokenizer.vocab_size + 2,
        embed_dim=embed_dim,
        num_layers=num_layers,
        num_heads=num_heads,
        max_seq_length=64,
        dropout=dropout
    )
    model = ImageCaptioningModel(encoder, decoder).to(device)

    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Epoch-level scheduler driven by the validation loss.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

    # Loss
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    # Mixed precision: prefer torch.amp (the torch.cuda.amp constructor is
    # deprecated on PyTorch 2.x).
    if hasattr(torch.amp, 'GradScaler'):
        scaler = torch.amp.GradScaler('cuda', enabled=args.use_amp)
    else:
        scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)

    best_val_loss = float('inf')

    for epoch in range(args.epochs):
        # BUGFIX: pass scheduler=None to train_epoch. ReduceLROnPlateau.step()
        # requires a metric argument, so letting train_epoch call step() per
        # batch would raise; the plateau schedule is applied once per epoch
        # below instead.
        train_loss = train_epoch(model, train_loader, optimizer, criterion, scaler,
                                 None, device, args)

        # Validate
        val_loss = validate(model, val_loader, criterion, device)

        # Update scheduler on the epoch's validation loss.
        scheduler.step(val_loss)

        # Report to Optuna so the pruner can act.
        trial.report(val_loss, epoch)

        # Prune trial if not promising
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

        if val_loss < best_val_loss:
            best_val_loss = val_loss

    return best_val_loss
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def objective(trial):
    """Optuna objective: build a fixed-path config, train with the trial's
    suggested hyperparameters, and return the validation loss (inf on failure).
    """
    settings = {
        'train_image_dir': 'Data/train2017/train2017',
        'train_ann_file': 'Data/annotations_trainval2017/annotations/captions_train2017.json',
        'val_image_dir': 'Data/val2017',
        'val_ann_file': 'Data/annotations_trainval2017/annotations/captions_val2017.json',
        'test_image_dir': 'Data/test2017/test2017',
        'model_name': 'efficientnet_b3',
        'embed_dim': 512,      # overridden by the trial
        'num_layers': 8,       # overridden by the trial
        'num_heads': 8,        # overridden by the trial
        'batch_size': 96,      # overridden by the trial
        'lr': 3e-4,            # overridden by the trial
        'epochs': 5,
        'seed': 42,
        'use_amp': True,
        'grad_accum': 1,
        'checkpoint_dir': 'checkpoints',
        'early_stopping_patience': 3,
        'distributed': False,
        'local_rank': 0,
        'resume_checkpoint': None,
    }
    args = argparse.Namespace(**settings)

    try:
        return train_with_config(trial, args)
    except Exception as e:
        print(f"Trial failed: {e}")
        return float('inf')
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def main():
    """Drive the Optuna study: parse CLI flags, run the optimization, then
    report, persist, and optionally plot the results."""
    parser = argparse.ArgumentParser(description='Hyperparameter optimization with Optuna')
    parser.add_argument('--n_trials', type=int, default=50, help='Number of trials')
    parser.add_argument('--timeout', type=int, default=3600*24, help='Timeout in seconds')
    parser.add_argument('--study_name', type=str, default='efficientnet_captioning',
                        help='Study name')
    parser.add_argument('--storage', type=str, default='sqlite:///optuna_study.db',
                        help='Storage URL for study')
    args = parser.parse_args()

    # Create (or resume) the persistent study; the median pruner kills trials
    # that underperform the running median after a short warmup.
    study = optuna.create_study(
        direction='minimize',
        study_name=args.study_name,
        storage=args.storage,
        load_if_exists=True,
        pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=3)
    )

    print(f"Starting optimization with {args.n_trials} trials...")
    print(f"Study: {args.study_name}")

    study.optimize(objective, n_trials=args.n_trials, timeout=args.timeout)

    # Summarize the search.
    print("\n" + "="*60)
    print("Optimization Complete!")
    print("="*60)
    print(f"Best trial: {study.best_trial.number}")
    print(f"Best validation loss: {study.best_value:.4f}")
    print("\nBest parameters:")
    for key, value in study.best_params.items():
        print(f"  {key}: {value}")

    # Persist the winning configuration.
    import json
    with open('best_hyperparameters.json', 'w') as f:
        json.dump(study.best_params, f, indent=2)

    print("\nBest hyperparameters saved to best_hyperparameters.json")

    # Optional visualizations (require plotly).
    try:
        import optuna.visualization as vis

        for plot_fn, fname in (
            (vis.plot_optimization_history, "optimization_history.png"),
            (vis.plot_param_importances, "param_importances.png"),
            (vis.plot_parallel_coordinate, "parallel_coordinate.png"),
        ):
            plot_fn(study).write_image(fname)
            print(f"Saved {fname}")

    except ImportError:
        print("Install plotly to generate visualizations: pip install plotly")


if __name__ == '__main__':
    main()
|
| 197 |
+
|
training/resnet_train.py
ADDED
|
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import subprocess
|
| 3 |
+
import json
|
| 4 |
+
import torch
|
| 5 |
+
import nltk
|
| 6 |
+
import numpy as np
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from pycocotools.coco import COCO
|
| 9 |
+
from torch.utils.data import Dataset, DataLoader
|
| 10 |
+
from torchvision import transforms
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.optim as optim
|
| 13 |
+
from collections import Counter
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
from torchvision import models
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
import torch.distributed as dist
|
| 18 |
+
import argparse
|
| 19 |
+
|
| 20 |
+
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
| 21 |
+
from nltk.translate.meteor_score import meteor_score
|
| 22 |
+
|
| 23 |
+
# Additional imports for extended metrics
|
| 24 |
+
from rouge import Rouge
|
| 25 |
+
from pycocoevalcap.cider.cider import Cider
|
| 26 |
+
|
| 27 |
+
# NLTK resources for tokenization and METEOR scoring (quiet + idempotent).
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)
nltk.download('wordnet', quiet=True)

# ===========================
# CONFIGURATION
# ===========================
# Single source of truth for paths, architecture sizes, and training knobs.
CONFIG = {
    # Paths
    "train_ann": r"B:/!S3/Computer Vision/Project/annotations/captions_train2017.json",
    "val_ann": r"B:/!S3/Computer Vision/Project/annotations/captions_val2017.json",
    "train_img_dir": "images/train2017",
    "val_img_dir": "images/val2017",

    # Model
    "img_size": 224,            # input image side length in pixels
    "embed_size": 256,          # word-embedding dimension
    "hidden_size": 512,         # LSTM hidden-state dimension
    "attention_dim": 512,       # additive-attention projection dimension
    "feature_map_size": 14,  # From ResNet feature maps
    "dropout": 0.5,  # Dropout probability added

    # Training
    "batch_size": 176,
    "num_epochs": 30,
    "lr": 0.005,
    "fine_tune_encoder": True,  # if False, the CNN backbone is frozen
    "grad_clip": 5.0,           # gradient-norm clipping threshold

    # Vocabulary
    "vocab_threshold": 5,       # minimum word frequency to enter the vocab
    "max_len": 20,              # fixed caption length (pad/truncate)

    # Beam search
    "beam_size": 3
}
|
| 63 |
+
|
| 64 |
+
# ===========================
|
| 65 |
+
# Vocabulary Builder
|
| 66 |
+
# ===========================
|
| 67 |
+
class Vocabulary:
    """Bidirectional word<->index mapping with special tokens and a
    frequency threshold for inclusion."""

    def __init__(self):
        self.word2idx = {}   # token -> integer id
        self.idx2word = {}   # integer id -> token
        self.idx = 0         # next id to assign

    def add_word(self, word):
        """Register *word* under the next free index (no-op if already known)."""
        if word in self.word2idx:
            return
        self.word2idx[word] = self.idx
        self.idx2word[self.idx] = word
        self.idx += 1

    def build(self, coco, threshold):
        """Populate the vocabulary from COCO captions.

        Special tokens come first, then every token occurring at least
        *threshold* times, in first-seen order.
        """
        counter = Counter()
        for ann_id in tqdm(list(coco.anns.keys())):
            tokens = nltk.word_tokenize(coco.anns[ann_id]['caption'].lower())
            counter.update(tokens)
        # Special tokens always occupy the lowest indices.
        for special in ('<pad>', '<start>', '<end>', '<unk>'):
            self.add_word(special)
        # Frequent enough regular tokens follow.
        for word, cnt in counter.items():
            if cnt >= threshold:
                self.add_word(word)
|
| 95 |
+
|
| 96 |
+
# Initialize vocab with full training data (only if training data exists)
# This allows the module to be imported for inference without training data
vocab = Vocabulary()
# Always add special tokens (needed for DecoderRNN class definition)
vocab.add_word('<pad>')
vocab.add_word('<start>')
vocab.add_word('<end>')
vocab.add_word('<unk>')

# Build the full vocabulary only when the COCO annotation file is reachable;
# otherwise keep the minimal 4-token vocab and rely on a checkpoint at load time.
if os.path.exists(CONFIG['train_ann']):
    try:
        coco_train = COCO(CONFIG['train_ann'])
        vocab.build(coco_train, CONFIG['vocab_threshold'])
        print(f"Vocabulary size: {len(vocab.word2idx)}")
    except (FileNotFoundError, OSError) as e:
        # Training data not available - vocab will be loaded from checkpoint
        # Keep minimal vocab with special tokens for class definition
        print(f"Warning: Could not load training data. Vocabulary will be loaded from checkpoint.")
else:
    # Training data path doesn't exist - keep minimal vocab for inference
    print(f"Warning: Training data not found at {CONFIG['train_ann']}. Vocabulary will be loaded from checkpoint.")
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# ===========================
|
| 120 |
+
# Attention-based Model
|
| 121 |
+
# ===========================
|
| 122 |
+
class EncoderCNN(nn.Module):
    """ResNet-50 backbone emitting a grid of spatial features for attention.

    Output shape: (batch, feature_map_size**2, 2048).
    """

    def __init__(self):
        super().__init__()
        # 'weights=' replaces the deprecated 'pretrained=' flag.
        from torchvision.models import resnet50, ResNet50_Weights
        backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        # Drop the final avgpool + classifier head; keep conv feature maps.
        self.cnn = nn.Sequential(*list(backbone.children())[:-2])
        side = CONFIG['feature_map_size']
        self.adaptive_pool = nn.AdaptiveAvgPool2d((side, side))
        if not CONFIG['fine_tune_encoder']:
            # Freeze the backbone when fine-tuning is disabled.
            for p in self.cnn.parameters():
                p.requires_grad = False

    def forward(self, x):
        """Map images (batch, 3, H, W) -> features (batch, S*S, 2048)."""
        fmap = self.adaptive_pool(self.cnn(x))              # (batch, 2048, S, S)
        fmap = fmap.permute(0, 2, 3, 1)                     # (batch, S, S, 2048)
        return fmap.view(fmap.size(0), -1, fmap.size(-1))   # (batch, S*S, 2048)
|
| 142 |
+
|
| 143 |
+
class Attention(nn.Module):
    """Additive (Bahdanau-style) attention over encoder feature locations."""

    def __init__(self):
        super().__init__()
        att_dim = CONFIG['attention_dim']
        self.U = nn.Linear(CONFIG['hidden_size'], att_dim)  # projects decoder state
        self.W = nn.Linear(2048, att_dim)                   # projects image features
        self.v = nn.Linear(att_dim, 1)                      # scores each location
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, features, hidden):
        """Return (context, alpha).

        Args:
            features: encoder grid, shape (batch, locations, 2048).
            hidden: decoder hidden state, shape (batch, hidden_size).

        Returns:
            context: attention-weighted feature sum, shape (batch, 2048).
            alpha: attention weights over locations, shape (batch, locations).
        """
        combined = self.tanh(self.W(features) + self.U(hidden).unsqueeze(1))
        scores = self.v(combined).squeeze(2)                  # (batch, locations)
        alpha = self.softmax(scores)                          # (batch, locations)
        context = (features * alpha.unsqueeze(2)).sum(dim=1)  # (batch, 2048)
        return context, alpha
|
| 160 |
+
|
| 161 |
+
class DecoderRNN(nn.Module):
    """LSTM decoder with additive attention over encoder feature locations.

    At every step the decoder attends over the encoder's feature grid,
    concatenates the resulting context vector with the current word embedding,
    and feeds the pair through a single-layer LSTM plus a vocab projection.
    """
    def __init__(self):
        super().__init__()
        # Word embeddings sized to the module-level vocabulary.
        self.embed = nn.Embedding(len(vocab.word2idx), CONFIG['embed_size'])
        # LSTM input = word embedding concatenated with the 2048-d context.
        self.lstm = nn.LSTM(CONFIG['embed_size'] + 2048,
                            CONFIG['hidden_size'], batch_first=True)
        self.attention = Attention()
        self.fc = nn.Linear(CONFIG['hidden_size'], len(vocab.word2idx))
        self.dropout = nn.Dropout(p=CONFIG['dropout'])

    def forward(self, features, captions, teacher_forcing_ratio=0.5):
        """Decode a caption with per-step scheduled teacher forcing.

        Args:
            features: encoder output, shape (batch, locations, 2048).
            captions: ground-truth token ids, shape (batch, seq_len);
                position 0 is expected to be the <start> token.
            teacher_forcing_ratio: per-step probability of feeding the
                ground-truth next token instead of the model's own argmax.

        Returns:
            Logits tensor of shape (batch, seq_len - 1, vocab_size).
        """
        batch_size = features.size(0)
        h, c = self.init_hidden(features)
        seq_length = captions.size(1) - 1
        outputs = torch.zeros(batch_size, seq_length, len(vocab.word2idx)).to(features.device)
        embeddings = self.dropout(self.embed(captions[:, 0]))
        for t in range(seq_length):
            # Attend using the current hidden state (drop the layer dimension).
            context, alpha = self.attention(features, h.squeeze(0))
            lstm_input = torch.cat([embeddings, context], dim=1).unsqueeze(1)
            out, (h, c) = self.lstm(lstm_input, (h, c))
            out = self.dropout(out)
            output = self.fc(out.squeeze(1))
            outputs[:, t] = output
            # Scheduled sampling: independent coin flip at each timestep.
            use_teacher_forcing = np.random.random() < teacher_forcing_ratio
            if use_teacher_forcing and t < seq_length - 1:
                embeddings = self.dropout(self.embed(captions[:, t+1]))
            else:
                embeddings = self.dropout(self.embed(output.argmax(dim=-1)))
        return outputs

    def init_hidden(self, features):
        """Zero-initialize (h, c) for the single-layer LSTM on features' device."""
        h = torch.zeros(1, features.size(0), CONFIG['hidden_size']).to(features.device)
        c = torch.zeros(1, features.size(0), CONFIG['hidden_size']).to(features.device)
        return h, c
|
| 195 |
+
|
| 196 |
+
# ===========================
|
| 197 |
+
# Enhanced Dataset Class
|
| 198 |
+
# ===========================
|
| 199 |
+
class CocoDataset(Dataset):
    """COCO captions dataset yielding (image_tensor, padded_caption_ids).

    Annotations whose image file is missing on disk are dropped at
    construction time (with a warning).
    """

    def __init__(self, ann_file, img_dir, vocab, transform=None):
        self.coco = COCO(ann_file)
        self.img_dir = img_dir
        self.vocab = vocab
        self.transform = transform or self.default_transform()
        # Keep only annotations whose image file actually exists.
        self.ids = []
        for ann_id in list(self.coco.anns.keys()):
            ann = self.coco.anns[ann_id]
            file_name = self.coco.loadImgs(ann['image_id'])[0]['file_name']
            img_path = os.path.join(self.img_dir, file_name)
            if os.path.exists(img_path):
                self.ids.append(ann_id)
            else:
                print(f"Warning: File {img_path} not found. Skipping annotation id {ann_id}.")

    def default_transform(self):
        """ImageNet-normalized resize pipeline."""
        return transforms.Compose([
            transforms.Resize((CONFIG['img_size'], CONFIG['img_size'])),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        """Load one pair; the caption is tokenized, indexed, then padded or
        truncated to exactly CONFIG['max_len'] ids."""
        ann = self.coco.anns[self.ids[idx]]
        img_info = self.coco.loadImgs(ann['image_id'])[0]
        image = Image.open(os.path.join(self.img_dir, img_info['file_name'])).convert('RGB')
        image = self.transform(image)

        tokens = ['<start>'] + nltk.word_tokenize(ann['caption'].lower()) + ['<end>']
        unk = self.vocab.word2idx['<unk>']
        ids = [self.vocab.word2idx.get(tok, unk) for tok in tokens]
        pad = self.vocab.word2idx['<pad>']
        ids += [pad] * (CONFIG['max_len'] - len(ids))
        return image, torch.tensor(ids[:CONFIG['max_len']])
|
| 243 |
+
|
| 244 |
+
# ===========================
|
| 245 |
+
# Distributed Setup Functions
|
| 246 |
+
# ===========================
|
| 247 |
+
def setup_distributed():
    """Join the default process group for multi-GPU training (NCCL backend).

    Expects the usual torchrun/torch.distributed environment variables
    (RANK, WORLD_SIZE, MASTER_ADDR, ...) to be set by the launcher.
    """
    dist.init_process_group(backend='nccl')
|
| 249 |
+
|
| 250 |
+
def cleanup_distributed():
    """Tear down the process group created by setup_distributed()."""
    dist.destroy_process_group()
|
| 252 |
+
|
| 253 |
+
# ===========================
|
| 254 |
+
# Training & Evaluation
|
| 255 |
+
# ===========================
|
| 256 |
+
def evaluate(encoder, decoder, loader, device, criterion, compute_extended=False):
    """Compute validation loss, optionally with caption-quality metrics.

    Args:
        encoder, decoder: CNN encoder and RNN decoder (possibly DDP-wrapped).
        loader: DataLoader yielding (images, captions) batches.
        device: device to run inference on.
        criterion: loss on decoder logits vs. captions shifted by one token.
        compute_extended: when True, beam-search a caption for every image and
            compute BLEU, METEOR, ROUGE-1, ROUGE-L and CIDEr (much slower).

    Returns:
        Mean loss (float) when compute_extended is False, otherwise a
        (mean_loss, metrics_dict) tuple.

    Fix: the original duplicated the entire no_grad loss loop in both
    branches; the loops are now unified, with metric computation conditional.

    NOTE(review): relies on the module-level `vocab`; under distributed
    training the returned loss is per-rank (not all-reduced) — confirm that
    is intended.
    """
    encoder.eval()
    decoder.eval()
    total_loss = 0
    if compute_extended:
        # Smoothing avoids zero BLEU scores when higher-order n-grams miss.
        smoothing_fn = SmoothingFunction().method1
        rouge = Rouge()
        cider_scorer = Cider()
        bleu_scores, meteor_scores = [], []
        rouge1_scores, rougeL_scores = [], []
        ref_dict, hyp_dict = {}, {}
        sample_id = 0
        # Tokens excluded from both hypotheses and references.
        special = {vocab.word2idx['<start>'], vocab.word2idx['<end>'], vocab.word2idx['<pad>']}
    with torch.no_grad():
        for imgs, caps in loader:
            imgs = imgs.to(device)
            caps = caps.to(device)
            features = encoder(imgs)
            # Teacher forcing disabled so validation mirrors inference.
            outputs = decoder(features, caps, teacher_forcing_ratio=0)
            loss = criterion(outputs.view(-1, len(vocab.word2idx)), caps[:, 1:].reshape(-1))
            total_loss += loss.item()
            if not compute_extended:
                continue
            for i in range(imgs.size(0)):
                predicted_ids = beam_search(features[i].unsqueeze(0), decoder, device)
                predicted_caption = [vocab.idx2word[idx] for idx in predicted_ids
                                     if idx not in special]
                reference_caption = [vocab.idx2word[idx] for idx in caps[i].tolist()
                                     if idx not in special]
                bleu_scores.append(sentence_bleu([reference_caption], predicted_caption,
                                                 smoothing_function=smoothing_fn))
                meteor_scores.append(meteor_score([reference_caption], predicted_caption))
                pred_str = " ".join(predicted_caption)
                ref_str = " ".join(reference_caption)
                rouge_scores = rouge.get_scores(pred_str, ref_str)
                rouge1_scores.append(rouge_scores[0]['rouge-1']['f'])
                rougeL_scores.append(rouge_scores[0]['rouge-l']['f'])
                ref_dict[sample_id] = [ref_str]
                hyp_dict[sample_id] = [pred_str]
                sample_id += 1
    if not compute_extended:
        return total_loss / len(loader)
    avg_bleu = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0
    avg_meteor = sum(meteor_scores) / len(meteor_scores) if meteor_scores else 0
    avg_rouge1 = sum(rouge1_scores) / len(rouge1_scores) if rouge1_scores else 0
    avg_rougeL = sum(rougeL_scores) / len(rougeL_scores) if rougeL_scores else 0
    cider_score, _ = cider_scorer.compute_score(ref_dict, hyp_dict)
    metrics = {'BLEU': avg_bleu, 'METEOR': avg_meteor,
               'ROUGE-1': avg_rouge1, 'ROUGE-L': avg_rougeL, 'CIDEr': cider_score}
    if dist.is_initialized() and dist.get_rank() == 0:
        print(f"Extended Metrics: {metrics}")
    return total_loss / len(loader), metrics
|
| 319 |
+
|
| 320 |
+
def beam_search(features, decoder, device):
    """Decode one image's features into a caption via beam search.

    Args:
        features: encoder output for a single image, batch dimension of 1.
        decoder: the RNN decoder (plain module or DDP-wrapped).
        device: device to run the decoding step on.

    Returns:
        List of vocabulary token ids, starting with <start> (may or may not
        reach <end> within CONFIG['max_len'] steps).

    Fix: the DDP and non-DDP decoding paths were duplicated line-for-line;
    the wrapped module is now unwrapped once and a single path is used.
    """
    k = CONFIG['beam_size']
    # Unwrap the DDP container once so a single decoding path suffices.
    dec = decoder.module if isinstance(decoder, torch.nn.parallel.DistributedDataParallel) else decoder
    start_token = vocab.word2idx['<start>']
    end_token = vocab.word2idx['<end>']
    h, c = dec.init_hidden(features)
    # Each entry: [token list, cumulative log-prob, hidden state, cell state].
    sequences = [[[start_token], 0.0, h, c]]
    for _ in range(CONFIG['max_len'] - 1):
        all_candidates = []
        for tokens, score, h, c in sequences:
            # Finished beams are carried forward unchanged.
            if tokens[-1] == end_token:
                all_candidates.append([tokens, score, h, c])
                continue
            input_tensor = torch.LongTensor([tokens[-1]]).to(device)
            context, _ = dec.attention(features, h.squeeze(0))
            emb = dec.embed(input_tensor)
            lstm_input = torch.cat([emb, context], dim=1).unsqueeze(1)
            out, (h, c) = dec.lstm(lstm_input, (h, c))
            output = dec.fc(out.squeeze(1))
            log_probs = torch.log_softmax(output, dim=1)
            top_probs, top_indices = log_probs.topk(k)
            for i in range(k):
                all_candidates.append([tokens + [top_indices[0][i].item()],
                                       score + top_probs[0][i].item(), h, c])
        # Length-normalized score so longer captions are not unfairly penalized.
        sequences = sorted(all_candidates, key=lambda s: s[1] / len(s[0]), reverse=True)[:k]
    return sequences[0][0]
|
| 356 |
+
|
| 357 |
+
def visualize_attention(image_path, encoder, decoder, device):
    """Generate a caption string for one image file using beam search.

    NOTE(review): despite the name, no attention maps are visualized here —
    only the decoded caption is returned.
    """
    preprocess = transforms.Compose([
        transforms.Resize((CONFIG['img_size'], CONFIG['img_size'])),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    batch = preprocess(Image.open(image_path).convert('RGB')).unsqueeze(0).to(device)
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        caption_ids = beam_search(encoder(batch), decoder, device)
    # Strip the control tokens before joining into a readable sentence.
    skip = {vocab.word2idx['<start>'], vocab.word2idx['<end>'], vocab.word2idx['<pad>']}
    words = [vocab.idx2word[idx] for idx in caption_ids if idx not in skip]
    return ' '.join(words)
|
| 373 |
+
|
| 374 |
+
def train(distributed=False, local_rank=0, device=torch.device('cpu'), resume_checkpoint=None):
    """Full training loop for the ResNet encoder / attention-LSTM decoder.

    Args:
        distributed: wrap models in DistributedDataParallel and use
            DistributedSampler for both splits.
        local_rank: this process's GPU index; rank 0 does all logging/saving.
        device: device to train on.
        resume_checkpoint: optional path to a checkpoint saved by this
            function; restores model/optimizer state and the epoch counter.

    Side effects: writes `caption_model_best_epoch{N}.pth` checkpoints and
    appends to `metrics_log_Resnet.txt` (rank 0 only).
    """
    train_set = CocoDataset(CONFIG['train_ann'], CONFIG['train_img_dir'], vocab)
    val_set = CocoDataset(CONFIG['val_ann'], CONFIG['val_img_dir'], vocab)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_set) if distributed else None
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_set, shuffle=False) if distributed else None
    train_loader = DataLoader(train_set,
                              batch_size=CONFIG['batch_size'],
                              shuffle=(train_sampler is None),
                              sampler=train_sampler,
                              num_workers=8)
    val_loader = DataLoader(val_set,
                            batch_size=CONFIG['batch_size'],
                            sampler=val_sampler,
                            num_workers=8)
    encoder = EncoderCNN().to(device)
    decoder = DecoderRNN().to(device)
    if distributed:
        encoder = torch.nn.parallel.DistributedDataParallel(encoder, device_ids=[local_rank], output_device=local_rank)
        decoder = torch.nn.parallel.DistributedDataParallel(decoder, device_ids=[local_rank], output_device=local_rank)
    # <pad> positions contribute nothing to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx['<pad>'])
    if CONFIG['fine_tune_encoder']:
        params = list(decoder.parameters()) + list(encoder.parameters())
    else:
        params = list(decoder.parameters())
    optimizer = optim.Adam(params, lr=CONFIG['lr'])
    # Initialize training state variables
    start_epoch = 0
    best_val_loss = float('inf')
    epochs_without_improvement = 0
    # Resume from checkpoint if provided
    if resume_checkpoint is not None:
        print(f"Loading checkpoint from {resume_checkpoint}")
        # Allow Vocabulary as a safe global so it can be unpickled
        # (PyTorch 2.6+ restricts what torch.load may deserialize).
        torch.serialization.add_safe_globals([Vocabulary])
        checkpoint = torch.load(resume_checkpoint, map_location=device, weights_only=False)
        # NOTE(review): if the checkpoint was saved from a DDP run, the state
        # dict keys carry a 'module.' prefix — loading into a non-DDP model
        # would fail; confirm save/load modes match.
        encoder.load_state_dict(checkpoint['encoder'])
        decoder.load_state_dict(checkpoint['decoder'])
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            print("Warning: 'optimizer' state not found in checkpoint. Starting with fresh optimizer state.")
        start_epoch = checkpoint['epoch'] + 1
        best_val_loss = checkpoint.get('best_val_loss', float('inf'))
        epochs_without_improvement = checkpoint.get('epochs_without_improvement', 0)
        print(f"Resumed training from epoch {start_epoch}")
    for epoch in range(start_epoch, CONFIG['num_epochs']):
        if distributed:
            # Reshuffle the sharded data differently each epoch.
            train_sampler.set_epoch(epoch)
        encoder.train()
        decoder.train()
        total_loss = 0
        for imgs, caps in tqdm(train_loader):
            imgs = imgs.to(device)
            caps = caps.to(device)
            optimizer.zero_grad()
            features = encoder(imgs)
            outputs = decoder(features, caps)
            # Targets are the captions shifted by one (predict the next token).
            loss = criterion(outputs.view(-1, len(vocab.word2idx)),
                             caps[:, 1:].reshape(-1))
            loss.backward()
            if CONFIG['grad_clip'] is not None:
                # Clip only decoder gradients against exploding LSTM gradients.
                nn.utils.clip_grad_norm_(decoder.parameters(), CONFIG['grad_clip'])
            optimizer.step()
            total_loss += loss.item()
        # Expensive extended metrics (beam search per image) every 5th epoch only.
        if epoch % 5 == 0:
            val_loss, metrics = evaluate(encoder, decoder, val_loader, device, criterion, compute_extended=True)
            if local_rank == 0:
                print(f"Epoch {epoch+1}/{CONFIG['num_epochs']} | Train Loss: {total_loss/len(train_loader):.4f} | Val Loss: {val_loss:.4f}")
                with open("metrics_log_Resnet.txt", "a") as f:
                    f.write(f"Epoch {epoch+1}: {metrics}\n")
        else:
            val_loss = evaluate(encoder, decoder, val_loader, device, criterion, compute_extended=False)
            if local_rank == 0:
                print(f"Epoch {epoch+1}/{CONFIG['num_epochs']} | Train Loss: {total_loss/len(train_loader):.4f} | Val Loss: {val_loss:.4f}")
        if local_rank == 0:
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_without_improvement = 0
                checkpoint_path = f'caption_model_best_epoch{epoch}.pth'
                # Save everything needed to resume or run inference later.
                torch.save({
                    'epoch': epoch,
                    'encoder': encoder.state_dict(),
                    'decoder': decoder.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_val_loss': best_val_loss,
                    'epochs_without_improvement': epochs_without_improvement,
                    'vocab': vocab,
                    'config': CONFIG
                }, checkpoint_path)
                #upload_files(epoch)
            else:
                epochs_without_improvement += 1
                if epochs_without_improvement >= 3:
                    # NOTE(review): under distributed training only rank 0
                    # breaks here — other ranks keep looping; confirm this
                    # desync is acceptable or broadcast a stop signal.
                    print("Early stopping triggered.")
                    break
|
| 469 |
+
|
| 470 |
+
def upload_files(i):
    """Copy the epoch-*i* checkpoint and the metrics log to OneDrive via rclone.

    Requires a configured `rclone` remote named `onedrive`; reports success or
    the rclone stderr on failure, without raising.
    """
    for file in (f"caption_model_best_epoch{i}.pth", "metrics_log_Resnet.txt"):
        # List-form invocation (no shell) keeps the filename safe as a single arg.
        outcome = subprocess.run(
            ["rclone", "copy", file, "onedrive:/Computer_Viz/"],
            capture_output=True, text=True
        )
        if outcome.returncode != 0:
            print(f"Error during upload of {file}:", outcome.stderr)
        else:
            print(f"{file} uploaded successfully.")
|
| 481 |
+
|
| 482 |
+
if __name__ == '__main__':
    # CLI entry point: optional distributed mode and checkpoint resume.
    parser = argparse.ArgumentParser()
    parser.add_argument("--distributed", action="store_true", help="Enable distributed training")
    parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint to resume training")
    args = parser.parse_args()
    if args.distributed:
        # Join the NCCL process group; LOCAL_RANK is set by torchrun.
        setup_distributed()
        local_rank = int(os.environ['LOCAL_RANK'])
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        local_rank = 0
    train(distributed=args.distributed, local_rank=local_rank, device=device, resume_checkpoint=args.resume)
    if args.distributed:
        cleanup_distributed()
|
training/train_advanced.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Advanced Training Script with Best Practices
|
| 3 |
+
- Learning rate scheduling
|
| 4 |
+
- Mixed precision training
|
| 5 |
+
- Experiment tracking (W&B optional)
|
| 6 |
+
- Comprehensive evaluation
|
| 7 |
+
- Model checkpointing
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import os
|
| 12 |
+
import random
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.optim as optim
|
| 17 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau, LambdaLR
|
| 18 |
+
import math
|
| 19 |
+
from efficient_train import (
|
| 20 |
+
create_dataloaders, Encoder, Decoder, ImageCaptioningModel,
|
| 21 |
+
train_epoch, validate, generate_caption
|
| 22 |
+
)
|
| 23 |
+
from datetime import datetime
|
| 24 |
+
|
| 25 |
+
# Optional: Weights & Biases
|
| 26 |
+
try:
|
| 27 |
+
import wandb
|
| 28 |
+
WANDB_AVAILABLE = True
|
| 29 |
+
except ImportError:
|
| 30 |
+
WANDB_AVAILABLE = False
|
| 31 |
+
print("W&B not available. Install with: pip install wandb")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Build a LambdaLR schedule: linear warmup, then cosine decay to zero.

    The returned multiplier rises from 0 to 1 over `num_warmup_steps`, then
    follows half a cosine down to 0 at `num_training_steps`.
    """
    def lr_lambda(step):
        # Warmup phase: scale linearly from 0 up to the base learning rate.
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        # Cosine phase: fraction of the post-warmup steps completed so far.
        frac = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * frac)))

    return LambdaLR(optimizer, lr_lambda)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def train_advanced(args):
    """Train the EfficientNet/transformer captioner with LR scheduling, AMP,
    checkpointing, early stopping, and optional W&B logging.

    Args:
        args: parsed argparse.Namespace from main(); must also carry the
            attributes create_dataloaders/train_epoch expect (e.g. `epoch`).

    Side effects: writes `best_model.pth` and periodic checkpoints under
    `args.checkpoint_dir`; mutates `args.epoch` each epoch.
    """

    # Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # GPU optimizations
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True  # Optimize for consistent input sizes
        torch.backends.cudnn.deterministic = False  # Faster, but non-deterministic
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    # Initialize W&B experiment tracking (only if installed and requested).
    if args.use_wandb and WANDB_AVAILABLE:
        wandb.init(
            project=args.wandb_project,
            name=f"{args.model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            config=vars(args)
        )

    # Create dataloaders
    train_loader, val_loader, test_loader, tokenizer, train_set = create_dataloaders(args)

    # Initialize model.
    # NOTE(review): vocab_size is tokenizer.vocab_size + 2 — presumably for
    # two added special tokens; confirm against the tokenizer setup in
    # efficient_train.create_dataloaders.
    encoder = Encoder(args.model_name, args.embed_dim)
    decoder = Decoder(
        vocab_size=tokenizer.vocab_size + 2,
        embed_dim=args.embed_dim,
        num_layers=args.num_layers,
        num_heads=args.num_heads,
        max_seq_length=64,
        dropout=args.dropout
    )
    model = ImageCaptioningModel(encoder, decoder).to(device)

    # Resume from checkpoint if provided
    start_epoch = 0
    best_val_loss = float('inf')
    best_metrics = {}  # NOTE(review): never populated/read below — dead state.

    if args.resume_checkpoint:
        print(f"Loading checkpoint from {args.resume_checkpoint}")
        # Handle PyTorch 2.6+ security: allow tokenizer classes
        try:
            from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
            torch.serialization.add_safe_globals([GPT2TokenizerFast])
        except ImportError:
            pass

        checkpoint = torch.load(args.resume_checkpoint, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state'])
        start_epoch = checkpoint.get('epoch', 0) + 1
        best_val_loss = checkpoint.get('val_loss', float('inf'))
        print(f"Resumed from epoch {start_epoch}, best val loss: {best_val_loss:.4f}")

    # Optimizer with different learning rates for encoder/decoder.
    # NOTE(review): the 'encoder'/'decoder' substring match relies on the
    # ImageCaptioningModel attribute names; a rename would silently drop params.
    encoder_params = [p for n, p in model.named_parameters() if 'encoder' in n]
    decoder_params = [p for n, p in model.named_parameters() if 'decoder' in n]

    if args.different_lr:
        # Lower learning rate for encoder (fine-tuning)
        optimizer = optim.AdamW([
            {'params': encoder_params, 'lr': args.lr * 0.1},
            {'params': decoder_params, 'lr': args.lr}
        ], weight_decay=args.weight_decay)
    else:
        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Learning rate scheduler (cosine/warmup_cosine are stepped per batch
    # inside train_epoch; plateau is stepped per epoch below).
    if args.scheduler == 'cosine':
        scheduler = CosineAnnealingLR(
            optimizer,
            T_max=args.epochs * len(train_loader),
            eta_min=args.min_lr
        )
    elif args.scheduler == 'plateau':
        scheduler = ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=args.patience
        )
    elif args.scheduler == 'warmup_cosine':
        num_training_steps = args.epochs * len(train_loader)
        num_warmup_steps = args.warmup_epochs * len(train_loader)
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps, num_training_steps
        )
    else:
        scheduler = None

    # Loss function — pad positions are excluded from the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    # Mixed precision training - Use new API for PyTorch 2.6+
    if hasattr(torch.amp, 'GradScaler'):
        scaler = torch.amp.GradScaler('cuda', enabled=args.use_amp)
    else:
        scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)

    # Create checkpoint directory
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    # Training loop
    patience_counter = 0

    for epoch in range(start_epoch, args.epochs):
        args.epoch = epoch  # Set epoch for train_epoch function
        print(f"\nEpoch {epoch+1}/{args.epochs}")
        print("-" * 60)

        # Train one epoch (per-batch schedulers are passed through).
        train_loss = train_epoch(
            model, train_loader, optimizer, criterion, scaler,
            scheduler if args.scheduler == 'cosine' or args.scheduler == 'warmup_cosine' else None,
            device, args
        )

        # Validate
        val_loss = validate(model, val_loader, criterion, device)

        # Update scheduler
        if args.scheduler == 'plateau':
            scheduler.step(val_loss)
        elif args.scheduler in ['cosine', 'warmup_cosine']:
            # Already updated in train_epoch
            pass

        current_lr = optimizer.param_groups[0]['lr']

        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | LR: {current_lr:.6f}")

        # Log to W&B
        log_dict = {
            'epoch': epoch,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'learning_rate': current_lr
        }

        if args.use_wandb and WANDB_AVAILABLE:
            wandb.log(log_dict)

        # Checkpointing
        is_best = val_loss < best_val_loss

        if is_best:
            best_val_loss = val_loss
            patience_counter = 0

            # Save best model (tokenizer + config included so inference can
            # be done from the checkpoint alone).
            checkpoint = {
                'epoch': epoch,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'scheduler_state': scheduler.state_dict() if scheduler else None,
                'val_loss': val_loss,
                'train_loss': train_loss,
                'tokenizer': tokenizer,
                'config': vars(args)
            }

            best_path = os.path.join(args.checkpoint_dir, 'best_model.pth')
            torch.save(checkpoint, best_path)
            print(f"✓ Saved best model (val_loss: {val_loss:.4f})")

        else:
            patience_counter += 1

        # Save periodic checkpoints
        if (epoch + 1) % args.save_every == 0:
            checkpoint = {
                'epoch': epoch,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'scheduler_state': scheduler.state_dict() if scheduler else None,
                'val_loss': val_loss,
                'train_loss': train_loss,
                'tokenizer': tokenizer,
                'config': vars(args)
            }
            checkpoint_path = os.path.join(
                args.checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth'
            )
            torch.save(checkpoint, checkpoint_path)
            print(f"✓ Saved periodic checkpoint (epoch {epoch+1})")

        # Early stopping
        if patience_counter >= args.early_stopping_patience:
            print(f"\nEarly stopping triggered after {args.early_stopping_patience} epochs without improvement")
            break

    print("\n" + "="*60)
    print("Training Complete!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Best model saved to: {os.path.join(args.checkpoint_dir, 'best_model.pth')}")
    print("="*60)

    if args.use_wandb and WANDB_AVAILABLE:
        wandb.finish()
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def main():
    """Parse command-line arguments for advanced training and launch it."""
    cli = argparse.ArgumentParser(description='Advanced training with best practices')

    # --- Data paths ---
    cli.add_argument('--train_image_dir', type=str, required=True)
    cli.add_argument('--train_ann_file', type=str, required=True)
    cli.add_argument('--val_image_dir', type=str, required=True)
    cli.add_argument('--val_ann_file', type=str, required=True)
    cli.add_argument('--test_image_dir', type=str, required=True)

    # --- Model architecture ---
    cli.add_argument('--model_name', type=str, default='efficientnet_b3')
    cli.add_argument('--embed_dim', type=int, default=512)
    cli.add_argument('--num_layers', type=int, default=8)
    cli.add_argument('--num_heads', type=int, default=8)
    cli.add_argument('--dropout', type=float, default=0.1)

    # --- Training hyperparameters ---
    cli.add_argument('--batch_size', type=int, default=96)
    cli.add_argument('--lr', type=float, default=3e-4)
    cli.add_argument('--epochs', type=int, default=20)
    cli.add_argument('--seed', type=int, default=42)
    cli.add_argument('--use_amp', action='store_true', help='Use mixed precision')
    cli.add_argument('--grad_accum', type=int, default=1)
    cli.add_argument('--weight_decay', type=float, default=1e-4)
    cli.add_argument('--different_lr', action='store_true',
                     help='Use different LR for encoder/decoder')

    # --- Learning-rate schedule ---
    cli.add_argument('--scheduler', type=str, default='plateau',
                     choices=['cosine', 'plateau', 'warmup_cosine', 'none'])
    cli.add_argument('--patience', type=int, default=3)
    cli.add_argument('--min_lr', type=float, default=1e-6)
    cli.add_argument('--warmup_epochs', type=int, default=2)

    # --- Checkpointing / early stopping ---
    cli.add_argument('--checkpoint_dir', type=str, default='checkpoints')
    cli.add_argument('--resume_checkpoint', type=str, default=None)
    cli.add_argument('--save_every', type=int, default=5)
    cli.add_argument('--early_stopping_patience', type=int, default=5)

    # --- Experiment tracking ---
    cli.add_argument('--use_wandb', action='store_true', help='Use Weights & Biases')
    cli.add_argument('--wandb_project', type=str, default='image-captioning')

    # --- Extra flags consumed by create_dataloaders / train_epoch ---
    cli.add_argument('--distributed', action='store_true', help='Use distributed training')
    cli.add_argument('--local_rank', type=int, default=0, help='Local rank for distributed training')

    args = cli.parse_args()

    # train_epoch reads args.epoch; seed it here, updated during training.
    args.epoch = 0

    train_advanced(args)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
# Script entry point: parse CLI arguments and run the advanced training loop.
if __name__ == '__main__':
    main()
|
| 306 |
+
|