AOUNZakaria committed on
Commit
32d4a86
·
1 Parent(s): 400b4a4

Deploy image captioner

Browse files
README.md CHANGED
@@ -1,11 +1,41 @@
1
  ---
2
- title: ImageCaptionner
3
- emoji: 🚀
4
- colorFrom: indigo
5
  colorTo: purple
6
  sdk: docker
 
 
7
  pinned: false
8
- license: apache-2.0
9
  ---
10
 
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Image Caption Generator
3
+ emoji: 🖼️
4
+ colorFrom: blue
5
  colorTo: purple
6
  sdk: docker
7
+ sdk_version: latest
8
+ app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
+ # Image Caption Generator
14
+
15
+ Generate captions for images using an optimized EfficientNet-B3 model.
16
+
17
+ ## Features
18
+
19
+ - ✅ EfficientNet-B3 model for high-quality captions
20
+ - ✅ Optimized quantized model (~245MB)
21
+ - ✅ Fast inference
22
+ - ✅ Simple web interface
23
+
24
+ ## How to Use
25
+
26
+ 1. Upload an image (PNG, JPG, JPEG)
27
+ 2. Click "Generate Caption"
28
+ 3. Get your caption!
29
+
30
+ ## Model
31
+
32
+ - **Architecture:** EfficientNet-B3
33
+ - **Optimization:** INT8 Quantization
34
+ - **Size:** ~245MB
35
+
36
+ ## Technical Details
37
+
38
+ - Built with PyTorch and Transformers
39
+ - Uses GPT-2 tokenizer
40
+ - Optimized for production deployment
41
+
app/__init__.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Image Caption Generator - Flask Application
Production-ready application with model caching and security.
"""

import logging
import os
from pathlib import Path

from flask import Flask

# Configure logging once for the whole package; sub-modules use
# logging.getLogger(__name__) and inherit this setup.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def _should_load_models():
    """Return True when models must be downloaded/loaded at startup.

    True in production, or whenever LOAD_MODELS is unset or "true".
    (Previously this condition was duplicated and evaluated twice.)
    """
    return (os.environ.get("FLASK_ENV") == "production"
            or os.environ.get("LOAD_MODELS", "true").lower() == "true")


def _download_via_script(base_dir):
    """Fallback download path: run the bundled download script.

    The script reads EFFICIENTNET_MODEL_URL; base_dir is put on sys.path so
    `scripts.download_model` is importable from any working directory.
    """
    import sys
    sys.path.insert(0, str(base_dir))
    from scripts.download_model import download_efficientnet_model
    download_efficientnet_model()


def _download_model_if_needed(base_dir):
    """Fetch the quantized EfficientNet checkpoint before model loading.

    Tries the Hugging Face Hub first (HF_MODEL_REPO), then falls back to the
    direct-URL download script. All failures are logged, never raised, so the
    app can still start with a pre-existing local model.
    """
    try:
        model_repo = os.environ.get("HF_MODEL_REPO")
        if model_repo:
            try:
                from huggingface_hub import hf_hub_download
                logger.info(f"Downloading model from HF Hub: {model_repo}")
                model_path = hf_hub_download(
                    repo_id=model_repo,
                    filename="efficientnet_efficient_best_model_quantized.pth",
                    cache_dir=str(base_dir / "models" / "optimized_models")
                )
                logger.info(f"Model downloaded from HF Hub: {model_path}")
            except Exception as e:
                logger.warning(f"Could not download from HF Hub: {e}. Trying download URL...")
                _download_via_script(base_dir)
        else:
            # No HF repo configured - go straight to the URL-based script.
            _download_via_script(base_dir)
    except Exception as e:
        logger.warning(f"Could not download model: {e}. Will try to use existing model if available.")


def create_app(config=None):
    """
    Application factory pattern.
    Creates and configures the Flask application.

    Args:
        config: unused placeholder kept for interface compatibility.

    Returns:
        A configured Flask app with routes registered and (when enabled)
        the EfficientNet captioning model loaded into the process cache.

    Raises:
        ValueError: in production when SESSION_SECRET is missing/default.
    """
    # Project root is two levels above app/__init__.py.
    base_dir = Path(__file__).resolve().parent.parent

    app = Flask(__name__,
                template_folder=str(base_dir / 'templates'),
                static_folder=str(base_dir / 'static'))

    # Secret key: refuse to run in production without a real secret.
    app.secret_key = os.environ.get("SESSION_SECRET")
    if not app.secret_key or app.secret_key == "default-secret-key":
        if os.environ.get("FLASK_ENV") == "production":
            raise ValueError("SESSION_SECRET must be set in production environment!")
        logger.warning("Using default secret key. Set SESSION_SECRET in production!")
        app.secret_key = "default-secret-key"

    # Upload configuration (MAX_CONTENT_LENGTH makes Flask reject oversized
    # request bodies before they reach our handlers).
    app.config['UPLOAD_FOLDER'] = os.environ.get('UPLOAD_FOLDER', 'uploads')
    app.config['MAX_CONTENT_LENGTH'] = int(os.environ.get('MAX_FILE_SIZE', 10 * 1024 * 1024))
    app.config['ALLOWED_EXTENSIONS'] = {'png', 'jpg', 'jpeg'}

    # Ensure the uploads directory exists (idempotent).
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

    # Register blueprints/routes.
    from app.routes import bp
    app.register_blueprint(bp)

    # Download (if needed) and load the EfficientNet model at startup.
    if _should_load_models():
        _download_model_if_needed(base_dir)
        logger.info("Initializing models...")
        try:
            from app.utils.model_cache import model_cache
            # Only load the EfficientNet model; ResNet is loaded lazily.
            model_cache.load_efficientnet_model_only(use_optimized=True)
            logger.info("EfficientNet model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load models: {e}", exc_info=True)
            # Don't raise here - let the app start and handle errors gracefully.

    return app


# For backward compatibility: module-level instance used by WSGI servers
# (e.g. Gunicorn's "app:app" entry point).
app = create_app()
app/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (3.47 kB). View file
 
app/__pycache__/config.cpython-314.pyc ADDED
Binary file (1.69 kB). View file
 
app/__pycache__/routes.cpython-314.pyc ADDED
Binary file (9.07 kB). View file
 
app/config.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Application configuration.

Plain module-level constants; every value can be overridden through an
environment variable of the same name (except ALLOWED_EXTENSIONS and the
model paths, which are fixed relative to the project layout).
"""

import os
from pathlib import Path

# Base directory: project root, two levels above app/config.py.
BASE_DIR = Path(__file__).resolve().parent.parent

# Flask configuration
SECRET_KEY = os.environ.get("SESSION_SECRET", "dev-secret-key-change-in-production")
FLASK_ENV = os.environ.get("FLASK_ENV", "development")
# Debug mode everywhere except production.
DEBUG = FLASK_ENV != "production"

# Upload configuration
UPLOAD_FOLDER = os.environ.get("UPLOAD_FOLDER", str(BASE_DIR / "uploads"))
MAX_FILE_SIZE = int(os.environ.get("MAX_FILE_SIZE", 10 * 1024 * 1024))  # 10MB
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}

# Model paths (checkpoints live under <project>/models)
MODELS_DIR = BASE_DIR / "models"
OPTIMIZED_MODELS_DIR = MODELS_DIR / "optimized_models"
RESNET_MODEL_PATH = MODELS_DIR / "resnet_best_model.pth"
EFFICIENTNET_MODEL_PATH = MODELS_DIR / "efficient_best_model.pth"
VOCAB_PATH = MODELS_DIR / "vocab.pkl"

# Model configuration: both flags default to enabled ("true").
USE_OPTIMIZED_MODELS = os.environ.get("USE_OPTIMIZED_MODELS", "true").lower() == "true"
LOAD_MODELS_ON_STARTUP = os.environ.get("LOAD_MODELS", "true").lower() == "true"
app/routes.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Application routes.

Defines the `main` blueprint: index page, health/readiness probes and the
image-upload captioning endpoint.
"""

import os
import logging
import time
from datetime import datetime
from flask import Blueprint, render_template, request, jsonify
from werkzeug.utils import secure_filename
from torchvision import transforms
from PIL import Image
import torch

from app.utils.model_cache import model_cache
from app.config import MAX_FILE_SIZE, ALLOWED_EXTENSIONS

# Import training functions (handle both old and new locations)
try:
    from training.resnet_train import visualize_attention
    from training.efficient_train import generate_caption
except ImportError:
    # Fallback for backward compatibility (before reorganization)
    from resnet_train import visualize_attention
    from efficient_train import generate_caption

logger = logging.getLogger(__name__)

bp = Blueprint('main', __name__)

# Image transformation for EfficientNet: resize/center-crop to 224x224,
# then normalize with the ImageNet mean/std used by torchvision backbones
# (assumes it matches the preprocessing used during training - TODO confirm).
efficientnet_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
38
+
39
+
40
def allowed_file(filename):
    """Return True when *filename* carries an allowed image extension."""
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in ALLOWED_EXTENSIONS
43
+
44
+
45
def validate_file_type(file_path):
    """Validate that the file at *file_path* is a real image, not just a
    file with an image extension.

    Uses PIL's verify(), which checks the file's integrity without fully
    decoding it. Returns True for a valid image, False otherwise.
    """
    try:
        # Context manager guarantees the underlying file handle is closed;
        # the previous bare Image.open() leaked the descriptor until GC.
        with Image.open(file_path) as img:
            img.verify()
        return True
    except Exception:
        # Any parse/IO failure means "not a usable image".
        return False
53
+
54
+
55
@bp.before_request
def before_request():
    """Stamp the incoming request with its arrival time (read back in
    after_request to log the total handling duration)."""
    setattr(request, 'start_time', time.time())
59
+
60
+
61
@bp.after_request
def after_request(response):
    """Add security headers and log the request's duration.

    Args:
        response: the outgoing Flask response.

    Returns:
        The same response, with security headers attached.
    """
    # Security headers
    response.headers['X-Content-Type-Options'] = 'nosniff'
    response.headers['X-Frame-Options'] = 'DENY'
    response.headers['X-XSS-Protection'] = '1; mode=block'

    # before_request may not have run (e.g. the request was rejected before
    # reaching the blueprint), so don't assume start_time exists - the old
    # unconditional read raised AttributeError in that case.
    started = getattr(request, 'start_time', None)
    if started is not None:
        duration = time.time() - started
        logger.info(f"{request.method} {request.path} - {response.status_code} - {duration:.3f}s")
    else:
        logger.info(f"{request.method} {request.path} - {response.status_code}")

    return response
74
+
75
+
76
@bp.route('/')
def index():
    """Render and return the application's main page."""
    template_name = 'index.html'
    return render_template(template_name)
80
+
81
+
82
@bp.route('/health')
def health_check():
    """Health check endpoint for load balancers.

    Always answers 200 (liveness, not readiness) and reports which models
    are currently resident in the process cache.
    """
    # Local import: the module only imports `datetime` from datetime.
    from datetime import timezone

    return jsonify({
        'status': 'healthy',
        # Timezone-aware UTC: datetime.utcnow() is naive and deprecated
        # since Python 3.12; the ISO string now carries a +00:00 offset.
        'timestamp': datetime.now(timezone.utc).isoformat(),
        'models_loaded': {
            'resnet': model_cache.is_resnet_loaded(),
            'efficientnet': model_cache.is_efficientnet_loaded()
        }
    }), 200
93
+
94
+
95
@bp.route('/ready')
def readiness_check():
    """Readiness probe: 200 once at least one captioning model is loaded,
    503 otherwise (so orchestrators keep traffic away until startup ends)."""
    any_model_ready = model_cache.is_resnet_loaded() or model_cache.is_efficientnet_loaded()
    if any_model_ready:
        return jsonify({'status': 'ready'}), 200
    return jsonify({'status': 'not ready', 'reason': 'models not loaded'}), 503
101
+
102
+
103
@bp.route('/upload', methods=['POST'])
def upload_file():
    """Handle an image upload and generate a caption for it.

    Expects a multipart form with an 'image' file field and an optional
    'model' field ('efficientnet' default, anything else selects ResNet).

    Returns JSON: {success, caption, model, inference_time} on success, or
    {error} with 400 (bad input), 503 (model unavailable) or 500 (failure).
    """
    if 'image' not in request.files:
        logger.warning("Upload request missing 'image' field")
        return jsonify({'error': 'No file part'}), 400

    file = request.files['image']
    model_choice = request.form.get('model', 'efficientnet')  # Default to EfficientNet

    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400

    if not file or not allowed_file(file.filename):
        return jsonify({'error': 'Invalid file type. Only PNG, JPG, JPEG allowed.'}), 400

    # Get upload folder from current app (set in __init__.py)
    from flask import current_app
    upload_folder = current_app.config['UPLOAD_FOLDER']

    # secure_filename() can return '' for fully-unsafe names (e.g. "../../"),
    # in which case filepath would have been the upload folder itself; fall
    # back to a generated temp name.
    filename = secure_filename(file.filename) or f"upload_{int(time.time() * 1000)}.img"
    filepath = os.path.join(upload_folder, filename)

    try:
        file.save(filepath)

        # Validate actual on-disk size (MAX_CONTENT_LENGTH guards the request
        # body, this double-checks against the configured limit).
        file_size = os.path.getsize(filepath)
        if file_size > MAX_FILE_SIZE:
            return jsonify({'error': f'File too large. Maximum size: {MAX_FILE_SIZE / 1024 / 1024}MB'}), 400

        # Validate file is actually an image (not just its extension).
        if not validate_file_type(filepath):
            return jsonify({'error': 'Invalid image file'}), 400

        start_time = time.time()

        if model_choice == 'efficientnet':
            if not model_cache.is_efficientnet_loaded():
                # NOTE: previously this early return leaked the uploaded file;
                # the `finally` block below now cleans up on every exit path.
                return jsonify({'error': 'EfficientNet model not available'}), 503

            model, tokenizer = model_cache.get_efficientnet_model()

            # Load, force RGB and preprocess for the encoder.
            with Image.open(filepath) as img:
                image = img.convert('RGB')
            image_tensor = efficientnet_transform(image).to(model_cache._device)

            with torch.no_grad():
                caption = generate_caption(
                    model,
                    image_tensor,
                    tokenizer,
                    model_cache._device,
                    max_length=64
                )
        else:  # resnet50
            if not model_cache.is_resnet_loaded():
                return jsonify({'error': 'ResNet model not available'}), 503

            encoder, decoder, vocab = model_cache.get_resnet_models()

            with torch.no_grad():
                caption = visualize_attention(filepath, encoder, decoder, model_cache._device)

        inference_time = time.time() - start_time
        logger.info(f"Caption generated in {inference_time:.3f}s using {model_choice}")

        return jsonify({
            'success': True,
            'caption': caption,
            'model': model_choice,
            'inference_time': round(inference_time, 3)
        })

    except Exception as e:
        logger.error(f"Error generating caption: {e}", exc_info=True)
        return jsonify({'error': 'Failed to generate caption. Please try again.'}), 500
    finally:
        # Single cleanup point: the upload is temporary on every path
        # (success, validation failure, 503 and unexpected errors alike).
        if os.path.exists(filepath):
            os.remove(filepath)
app/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ """Utilities package."""
2
+
app/utils/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (198 Bytes). View file
 
app/utils/__pycache__/model_cache.cpython-314.pyc ADDED
Binary file (16.1 kB). View file
 
app/utils/model_cache.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Model Caching Module for Production
Loads models once at startup and reuses them for all requests.
This eliminates the overhead of loading models per-request.
"""

import torch
import os
import logging
from pathlib import Path

logger = logging.getLogger(__name__)

# Get base directory (project root, three levels above app/utils/model_cache.py)
BASE_DIR = Path(__file__).resolve().parent.parent.parent
MODELS_DIR = BASE_DIR / "models"


class ModelCache:
    """Singleton class to cache loaded models in memory.

    A single module-level instance (`model_cache`, bottom of file) holds the
    ResNet encoder/decoder pair and the EfficientNet captioning model plus
    its tokenizer, so they are loaded at most once per process.

    NOTE(review): there is no locking here - assumes single-threaded loading
    at startup (matches the single-worker Gunicorn deployment); confirm before
    loading lazily from concurrent requests.
    """

    def __init__(self):
        # Lazily-populated model handles; None means "not loaded yet".
        self._resnet_encoder = None
        self._resnet_decoder = None
        self._resnet_vocab = None
        self._efficientnet_model = None
        self._efficientnet_tokenizer = None
        # Prefer GPU when available; all checkpoints are mapped onto this.
        self._device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Guard flag used by the bulk loaders to make them idempotent.
        self._models_loaded = False

        logger.info(f"ModelCache initialized on device: {self._device}")

    def load_all_models(self,
                        resnet_path=None,
                        efficientnet_path=None,
                        use_optimized=True):
        """
        Load all models at startup.

        Args:
            resnet_path: Path to ResNet checkpoint (default: models/resnet_best_model.pth)
            efficientnet_path: Path to EfficientNet checkpoint (default: models/efficient_best_model.pth)
            use_optimized: If True, try to load optimized models first
        """
        if self._models_loaded:
            logger.warning("Models already loaded, skipping")
            return

        # Set default paths
        if resnet_path is None:
            resnet_path = str(MODELS_DIR / "resnet_best_model.pth")
        if efficientnet_path is None:
            efficientnet_path = str(MODELS_DIR / "efficient_best_model.pth")

        # Try optimized models first if requested
        if use_optimized:
            # Check multiple possible locations for optimized models
            # (covers both old and reorganized directory layouts).
            optimized_resnet_paths = [
                str(MODELS_DIR / "optimized_models" / "resnet_resnet_best_model_quantized.pth"),
                str(BASE_DIR / "optimized_models" / "resnet_resnet_best_model_quantized.pth"),
                resnet_path.replace('.pth', '_quantized.pth'),
                resnet_path.replace('resnet_best_model.pth', 'resnet_resnet_best_model_quantized.pth'),
            ]

            optimized_efficient_paths = [
                str(MODELS_DIR / "optimized_models" / "efficientnet_efficient_best_model_quantized.pth"),
                str(BASE_DIR / "optimized_models" / "efficientnet_efficient_best_model_quantized.pth"),
                efficientnet_path.replace('.pth', '_quantized.pth'),
                efficientnet_path.replace('efficient_best_model.pth', 'efficientnet_efficient_best_model_quantized.pth'),
            ]

            # Find optimized ResNet model - first existing candidate wins.
            for opt_path in optimized_resnet_paths:
                if os.path.exists(opt_path):
                    resnet_path = opt_path
                    logger.info(f"Using optimized ResNet model: {resnet_path}")
                    break

            # Find optimized EfficientNet model
            for opt_path in optimized_efficient_paths:
                if os.path.exists(opt_path):
                    efficientnet_path = opt_path
                    logger.info(f"Using optimized EfficientNet model: {efficientnet_path}")
                    break

        # Load EfficientNet only (ResNet skipped - despite the method name,
        # only the EfficientNet model is actually loaded here).
        try:
            self.load_efficientnet_model(efficientnet_path)
            logger.info("EfficientNet model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load EfficientNet model: {e}", exc_info=True)

        self._models_loaded = True

    def load_efficientnet_model_only(self, use_optimized=True):
        """
        Load only EfficientNet model (skip ResNet).
        Useful when only EfficientNet is needed.
        """
        if self._models_loaded:
            logger.warning("Models already loaded, skipping")
            return

        efficientnet_path = str(MODELS_DIR / "efficient_best_model.pth")

        # Try optimized model first if requested
        if use_optimized:
            optimized_efficient_paths = [
                str(MODELS_DIR / "optimized_models" / "efficientnet_efficient_best_model_quantized.pth"),
                str(BASE_DIR / "optimized_models" / "efficientnet_efficient_best_model_quantized.pth"),
                efficientnet_path.replace('.pth', '_quantized.pth'),
                efficientnet_path.replace('efficient_best_model.pth', 'efficientnet_efficient_best_model_quantized.pth'),
            ]

            # Find optimized EfficientNet model - first existing candidate wins.
            for opt_path in optimized_efficient_paths:
                if os.path.exists(opt_path):
                    efficientnet_path = opt_path
                    logger.info(f"Using optimized EfficientNet model: {efficientnet_path}")
                    break

        # Load EfficientNet; failure is logged, not raised, so the app can
        # start degraded and report via the /ready endpoint.
        try:
            self.load_efficientnet_model(efficientnet_path)
            logger.info("EfficientNet model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load EfficientNet model: {e}", exc_info=True)

        self._models_loaded = True

    def load_resnet_models(self, checkpoint_path=None):
        """Load ResNet encoder and decoder models.

        Returns the cached (encoder, decoder, vocab) triple, loading it from
        *checkpoint_path* on first call. The checkpoint is expected to hold
        'encoder' and 'decoder' state dicts and optionally a 'vocab' entry.
        """
        if self._resnet_encoder is not None:
            return self._resnet_encoder, self._resnet_decoder, self._resnet_vocab

        if checkpoint_path is None:
            checkpoint_path = str(MODELS_DIR / "resnet_best_model.pth")

        # Resolve path - try multiple locations
        checkpoint_path = self._resolve_model_path(checkpoint_path)

        logger.info(f"Loading ResNet models from {checkpoint_path}")

        # Import from training module (handles both old and new locations)
        # Need to do this BEFORE loading checkpoint to avoid pickle issues
        try:
            from training.resnet_train import EncoderCNN, DecoderRNN
            # Add to sys.modules to help with pickle loading
            # (old checkpoints may reference the bare 'resnet_train' module).
            import sys
            if 'resnet_train' not in sys.modules:
                sys.modules['resnet_train'] = sys.modules['training.resnet_train']
        except ImportError:
            try:
                # Fallback for backward compatibility
                import sys
                sys.path.insert(0, str(BASE_DIR))
                from resnet_train import EncoderCNN, DecoderRNN
            except ImportError:
                logger.error("Could not import ResNet model classes. Make sure resnet_train.py exists in training/ or root.")
                raise

        # Load checkpoint with proper module mapping
        import sys
        import importlib.util

        # Map old module names for pickle compatibility
        if 'resnet_train' not in sys.modules:
            try:
                spec = importlib.util.spec_from_file_location("resnet_train", str(BASE_DIR / "training" / "resnet_train.py"))
                if spec and spec.loader:
                    resnet_module = importlib.util.module_from_spec(spec)
                    sys.modules['resnet_train'] = resnet_module
                    spec.loader.exec_module(resnet_module)
            except Exception:
                # Best-effort alias only; torch.load below will surface any
                # unresolved pickle references with a clearer error.
                pass

        # weights_only=False: checkpoint may contain pickled vocab objects.
        checkpoint = torch.load(checkpoint_path, map_location=self._device, weights_only=False)

        # Initialize models
        self._resnet_encoder = EncoderCNN().to(self._device)
        self._resnet_decoder = DecoderRNN().to(self._device)

        # Load weights
        self._resnet_encoder.load_state_dict(checkpoint['encoder'])
        self._resnet_decoder.load_state_dict(checkpoint['decoder'])

        # Set to eval mode
        self._resnet_encoder.eval()
        self._resnet_decoder.eval()

        # Store vocabulary (may be absent; .get returns None then)
        self._resnet_vocab = checkpoint.get('vocab')

        # Warm up models (first inference is slower)
        logger.info("Warming up ResNet models...")
        dummy_input = torch.randn(1, 3, 224, 224).to(self._device)
        with torch.no_grad():
            _ = self._resnet_encoder(dummy_input)
        logger.info("ResNet models warmed up")

        return self._resnet_encoder, self._resnet_decoder, self._resnet_vocab

    def load_efficientnet_model(self, checkpoint_path=None):
        """Load EfficientNet model.

        Returns the cached (model, tokenizer) pair, loading on first call.
        Handles both plain and INT8-dynamically-quantized checkpoints, with
        a fallback to the non-quantized checkpoint on state-dict mismatch.
        """
        if self._efficientnet_model is not None:
            return self._efficientnet_model, self._efficientnet_tokenizer

        if checkpoint_path is None:
            checkpoint_path = str(MODELS_DIR / "efficient_best_model.pth")

        # Resolve path - try multiple locations
        checkpoint_path = self._resolve_model_path(checkpoint_path)

        logger.info(f"Loading EfficientNet model from {checkpoint_path}")

        # Import from training module (handles both old and new locations)
        try:
            from training.efficient_train import Encoder, Decoder, ImageCaptioningModel
        except ImportError:
            try:
                # Fallback for backward compatibility
                import sys
                sys.path.insert(0, str(BASE_DIR))
                from efficient_train import Encoder, Decoder, ImageCaptioningModel
            except ImportError:
                logger.error("Could not import EfficientNet model classes. Make sure efficient_train.py exists in training/ or root.")
                raise

        from transformers import AutoTokenizer

        # Initialize tokenizer: GPT-2 with <start>/<end> added, padding with EOS.
        # Must match the tokenizer configuration used during training.
        tokenizer = AutoTokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
        special_tokens = {'additional_special_tokens': ['<start>', '<end>']}
        tokenizer.add_special_tokens(special_tokens)
        self._efficientnet_tokenizer = tokenizer

        # Initialize model (hyperparameters assumed to match training - TODO confirm)
        encoder = Encoder(model_name='efficientnet_b3', embed_dim=512)
        decoder = Decoder(
            vocab_size=len(tokenizer),
            embed_dim=512,
            num_layers=8,
            num_heads=8,
            max_seq_length=64
        )
        self._efficientnet_model = ImageCaptioningModel(encoder, decoder).to(self._device)

        # Load weights (weights_only=False: checkpoint may hold extra objects)
        checkpoint = torch.load(checkpoint_path, map_location=self._device, weights_only=False)

        # Check if this is a quantized model (has _packed_params keys)
        is_quantized = any('_packed_params' in key for key in checkpoint.get('model_state', checkpoint).keys())

        if is_quantized:
            # For quantized models, we need to prepare the model for quantization first
            # so its module structure matches the quantized state dict.
            logger.info("Detected quantized model, preparing model for quantization...")
            try:
                # Prepare model for quantization (dynamic INT8 on Linear layers)
                import torch.quantization as quant
                self._efficientnet_model = quant.quantize_dynamic(
                    self._efficientnet_model, {torch.nn.Linear}, dtype=torch.qint8
                )
                logger.info("Model prepared for quantization")
            except Exception as e:
                logger.warning(f"Could not prepare model for quantization: {e}. Trying to load anyway...")

        if 'model_state' in checkpoint:
            try:
                self._efficientnet_model.load_state_dict(checkpoint['model_state'], strict=False)
            except Exception as e:
                logger.warning(f"Could not load quantized state dict: {e}. Trying regular model...")
                # Try loading non-quantized model instead
                regular_path = checkpoint_path.replace('_quantized.pth', '.pth').replace('efficientnet_efficient_best_model', 'efficient_best_model')
                if os.path.exists(regular_path) and regular_path != checkpoint_path:
                    logger.info(f"Trying regular model: {regular_path}")
                    checkpoint = torch.load(regular_path, map_location=self._device, weights_only=False)
                    if 'model_state' in checkpoint:
                        self._efficientnet_model.load_state_dict(checkpoint['model_state'])
                    else:
                        self._efficientnet_model.load_state_dict(checkpoint)
        else:
            # Fallback: try loading directly (checkpoint is the raw state dict)
            try:
                self._efficientnet_model.load_state_dict(checkpoint, strict=False)
            except Exception:
                logger.warning("Could not load state dict. Model may not work correctly.")

        self._efficientnet_model.eval()

        # Warm up: run one dummy forward pass through the encoder since the
        # first inference is typically slower.
        logger.info("Warming up EfficientNet model...")
        dummy_input = torch.randn(1, 3, 224, 224).to(self._device)
        with torch.no_grad():
            _ = self._efficientnet_model.encoder(dummy_input)
        logger.info("EfficientNet model warmed up")

        return self._efficientnet_model, self._efficientnet_tokenizer

    def _resolve_model_path(self, checkpoint_path):
        """Resolve model path, trying multiple locations.

        Order: the given path, then MODELS_DIR, then the project root.
        Returns the first existing path, or the original (so the subsequent
        load fails with a clear file-not-found error).
        """
        # If path exists, use it
        if os.path.exists(checkpoint_path):
            return checkpoint_path

        # Try in models directory
        alt_path = str(MODELS_DIR / os.path.basename(checkpoint_path))
        if os.path.exists(alt_path):
            logger.info(f"Found model at: {alt_path}")
            return alt_path

        # Try in root directory (backward compatibility)
        alt_path = str(BASE_DIR / os.path.basename(checkpoint_path))
        if os.path.exists(alt_path):
            logger.info(f"Found model at: {alt_path}")
            return alt_path

        # Return original path (will fail with clear error)
        return checkpoint_path

    def get_resnet_models(self):
        """Get cached ResNet models; raises RuntimeError if not loaded."""
        if self._resnet_encoder is None:
            raise RuntimeError("ResNet models not loaded. Call load_resnet_models() first.")
        return self._resnet_encoder, self._resnet_decoder, self._resnet_vocab

    def get_efficientnet_model(self):
        """Get cached EfficientNet model; raises RuntimeError if not loaded."""
        if self._efficientnet_model is None:
            raise RuntimeError("EfficientNet model not loaded. Call load_efficientnet_model() first.")
        return self._efficientnet_model, self._efficientnet_tokenizer

    def is_resnet_loaded(self):
        """Check if ResNet models are loaded."""
        return self._resnet_encoder is not None

    def is_efficientnet_loaded(self):
        """Check if EfficientNet model is loaded."""
        return self._efficientnet_model is not None


# Singleton instance shared by the whole process (imported by app/routes.py
# and app/__init__.py).
model_cache = ModelCache()
hf_space_Dockerfile ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Dockerfile for Hugging Face Spaces
# Based on: https://huggingface.co/docs/hub/spaces-sdks-docker

FROM python:3.10-slim

# Create user (HF Spaces requirement: run as non-root UID 1000)
RUN useradd -m -u 1000 user
USER user
ENV PATH="/home/user/.local/bin:$PATH"

WORKDIR /app

# Install system dependencies (temporarily back to root for apt)
USER root
RUN apt-get update && apt-get install -y \
    build-essential \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Drop back to the unprivileged user for everything else
USER user

# Copy and install Python dependencies (copied first so this layer is
# cached across code-only rebuilds)
COPY --chown=user requirements.txt requirements.txt
RUN pip install --no-cache-dir --user --upgrade -r requirements.txt

# Download NLTK data at build time so the container starts without network access
RUN python -c "import nltk; nltk.download('punkt', quiet=True)"

# Copy application files
COPY --chown=user app/ /app/app/
COPY --chown=user training/ /app/training/
COPY --chown=user scripts/ /app/scripts/
COPY --chown=user templates/ /app/templates/
COPY --chown=user static/ /app/static/
COPY --chown=user app.py /app/

# Create necessary directories (model cache + upload scratch space)
RUN mkdir -p /app/models/optimized_models /app/uploads

# HF Spaces uses port 7860
EXPOSE 7860

# Set environment variables
ENV FLASK_ENV=production
ENV PORT=7860

# Run the application on port 7860 (HF Spaces requirement)
# Use app.py as entry point (HF Spaces looks for app.py); "app:app" resolves
# to the module-level Flask instance created in app/__init__.py
CMD ["gunicorn", "app:app", "--bind", "0.0.0.0:7860", "--workers", "1", "--timeout", "120", "--threads", "2"]
hf_space_app.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Hugging Face Spaces - Flask Application Entry Point
HF Spaces expects app.py with an 'app' variable
"""

import os
import sys
from pathlib import Path

# Add project root to path so `app`, `training` and `scripts` packages
# resolve regardless of the working directory Gunicorn starts in.
BASE_DIR = Path(__file__).resolve().parent
sys.path.insert(0, str(BASE_DIR))

# Import Flask app from app package
# This import has side effects: it triggers model download/loading at startup
# (see app/__init__.py, which runs create_app() at module level).
from app import app

# HF Spaces requires 'app' variable to be available
# The app is already created in app/__init__.py
# No need to run it here - Gunicorn will handle it
hf_space_requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ transformers>=4.30.0
4
+ Pillow>=10.0.0
5
+ timm>=0.9.0
6
+ numpy>=1.24.0
7
+ flask>=2.3.0
8
+ gunicorn>=21.2.0
9
+ werkzeug>=2.3.0
10
+ nltk>=3.8.1
11
+ requests>=2.31.0
12
+ huggingface_hub>=0.20.0
13
+
scripts/download_model.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Download EfficientNet model from cloud storage if not present.
This script runs at application startup to download the model if needed.
"""

import os
import sys
import logging
from pathlib import Path
import requests

logger = logging.getLogger(__name__)


def download_efficientnet_model():
    """
    Download EfficientNet optimized model if it doesn't exist.
    Supports two methods:
    1. Hugging Face Hub (set HF_MODEL_REPO environment variable)
    2. Direct URL download (set EFFICIENTNET_MODEL_URL environment variable)

    Returns:
        True when the model is present (already existed or downloaded),
        False when no source is configured or the download failed.
    """
    # Get base directory (project root) and ensure the target folder exists.
    base_dir = Path(__file__).resolve().parent.parent
    models_dir = base_dir / "models" / "optimized_models"
    models_dir.mkdir(parents=True, exist_ok=True)

    model_path = models_dir / "efficientnet_efficient_best_model_quantized.pth"

    # Check if model already exists - nothing to do then.
    if model_path.exists():
        size_mb = model_path.stat().st_size / (1024 * 1024)
        logger.info(f"EfficientNet model already exists ({size_mb:.1f}MB)")
        return True

    # Try Hugging Face Hub first
    hf_repo = os.environ.get("HF_MODEL_REPO")
    if hf_repo:
        try:
            from huggingface_hub import hf_hub_download
            logger.info(f"Downloading model from Hugging Face Hub: {hf_repo}")
            downloaded_path = hf_hub_download(
                repo_id=hf_repo,
                filename="efficientnet_efficient_best_model_quantized.pth",
                cache_dir=str(models_dir),
                local_dir=str(models_dir),
                local_dir_use_symlinks=False
            )
            # Move to expected location if needed
            if downloaded_path != str(model_path):
                import shutil
                shutil.move(downloaded_path, model_path)
            size_mb = model_path.stat().st_size / (1024 * 1024)
            logger.info(f"Model downloaded from HF Hub successfully ({size_mb:.1f}MB)")
            return True
        except ImportError:
            logger.warning("huggingface_hub not installed. Install with: pip install huggingface_hub")
        except Exception as e:
            logger.warning(f"Failed to download from HF Hub: {e}. Trying direct URL...")

    # Fallback to direct URL download
    model_url = os.environ.get("EFFICIENTNET_MODEL_URL")

    if not model_url:
        logger.warning("Neither HF_MODEL_REPO nor EFFICIENTNET_MODEL_URL is set.")
        logger.warning("Model will not be downloaded. Set one of these environment variables.")
        return False

    try:
        logger.info(f"Downloading EfficientNet model from {model_url}...")
        logger.info("This may take a few minutes (model is ~245MB)...")

        # Stream the download so the whole file never sits in memory.
        response = requests.get(model_url, stream=True, timeout=300)
        response.raise_for_status()

        total_size = int(response.headers.get('content-length', 0))
        downloaded = 0
        log_step = 10 * 1024 * 1024  # log roughly every 10MB
        # BUG FIX: the previous `downloaded % (10MB) == 0` check only fired
        # when the running byte count landed exactly on a 10MB boundary,
        # which almost never happens with variable-size chunks. Track a
        # crossing threshold instead.
        next_log = log_step

        with open(model_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
                    downloaded += len(chunk)
                    if downloaded >= next_log:
                        next_log += log_step
                        if total_size > 0:
                            percent = (downloaded / total_size) * 100
                            logger.info(f"Downloaded {downloaded / (1024 * 1024):.1f}MB / {total_size / (1024 * 1024):.1f}MB ({percent:.1f}%)")
                        else:
                            logger.info(f"Downloaded {downloaded / (1024 * 1024):.1f}MB")

        size_mb = model_path.stat().st_size / (1024 * 1024)
        logger.info(f"EfficientNet model downloaded successfully ({size_mb:.1f}MB)")
        return True

    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to download model: {e}")
        # Clean up partial download so the next attempt starts fresh.
        if model_path.exists():
            model_path.unlink()
        return False
    except Exception as e:
        logger.error(f"Error downloading model: {e}", exc_info=True)
        # Clean up partial download
        if model_path.exists():
            model_path.unlink()
        return False


if __name__ == "__main__":
    # Configure logging when run as a standalone script.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    download_efficientnet_model()
scripts/efficient_caption.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Standalone CLI for captioning one image with the EfficientNet-B3 model.

NOTE: the checkpoint is loaded at import time — importing this module
requires `efficient_best_model.pth` to exist and be readable.
"""
import argparse
import logging
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoTokenizer
from efficient_train import Encoder, Decoder, ImageCaptioningModel, generate_caption
import os

# Configuration
MODEL_PATH = 'efficient_best_model.pth'  # Path to your saved model
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MAX_SEQ_LENGTH = 64  # Ensure this matches the value used during training

# Image transformation (ensure it matches the preprocessing used during training)
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load the tokenizer; GPT-2 has no pad token, so reuse EOS, and add the
# <start>/<end> markers the training pipeline wraps captions with.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
special_tokens = {'additional_special_tokens': ['<start>', '<end>']}
tokenizer.add_special_tokens(special_tokens)

# Initialize the model components — hyperparameters must mirror training
# (embed_dim=512, 8 layers, 8 heads) or load_state_dict below will fail.
encoder = Encoder(model_name='efficientnet_b3', embed_dim=512)
decoder = Decoder(
    vocab_size=len(tokenizer),
    embed_dim=512,
    num_layers=8,
    num_heads=8,
    max_seq_length=MAX_SEQ_LENGTH
)
model = ImageCaptioningModel(encoder, decoder).to(DEVICE)

# Load the trained model weights — validate existence and non-emptiness
# first so errors are actionable instead of opaque torch failures.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model file not found at: {MODEL_PATH}. Please ensure you have a trained model checkpoint at this location.")

# Add a check for the size of the file
if os.path.getsize(MODEL_PATH) == 0:
    raise ValueError(f"Model file at {MODEL_PATH} is empty. Please check the saved model.")

# weights_only=False: the checkpoint is a pickled dict, not a bare state dict.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)

# Check if the checkpoint has the model_state key
if 'model_state' not in checkpoint:
    raise KeyError("The checkpoint file does not contain the key 'model_state'. Please ensure the model was saved correctly using 'torch.save(model.state_dict(), path)'.")

model.load_state_dict(checkpoint['model_state'])
model.eval()
59
+
60
+
61
+
62
def caption(image_path):
    """Generate a caption for the image stored at *image_path*.

    Uses the module-level model, tokenizer, and preprocessing transform.
    """
    # Force RGB (handles grayscale/RGBA inputs), preprocess, move to DEVICE.
    img_tensor = transform(Image.open(image_path).convert('RGB')).to(DEVICE)
    # Decode and return the caption string.
    return generate_caption(model, img_tensor, tokenizer, DEVICE, max_length=MAX_SEQ_LENGTH)
71
+
72
if __name__ == '__main__':
    # CLI entry point: caption a single image and print the result.
    parser = argparse.ArgumentParser(description="Generate a caption for the provided image.")
    parser.add_argument('--image_dir', type=str, required=True, help="Path to the input image file")
    args = parser.parse_args()

    try:
        result = caption(args.image_dir)
        print(result)
    except Exception as e:
        logging.error(f"Error generating caption: {str(e)}")
        # FIX: use `raise SystemExit(1)` instead of the built-in `exit()`.
        # `exit` is injected by the `site` module and is not guaranteed to
        # exist (e.g. under `python -S` or frozen interpreters).
        raise SystemExit(1)
scripts/optimize_models.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model Optimization Script for Production Deployment
3
+ Reduces model size and improves inference speed through:
4
+ 1. Quantization (INT8)
5
+ 2. TorchScript compilation
6
+ 3. Model pruning (optional)
7
+ 4. State dict optimization
8
+ """
9
+
10
+ import torch
11
+ import os
12
+ import argparse
13
+ from pathlib import Path
14
+
15
+ # Import model classes BEFORE loading checkpoints (needed for unpickling)
16
+ # This ensures PyTorch can find the class definitions when loading saved objects
17
+ # Note: resnet_train.py has module-level code that loads COCO data, which may fail
18
+ # if training files aren't present. We'll handle this in the functions.
19
+
20
def quantize_model(checkpoint_path, output_path, model_type='resnet'):
    """
    Quantize a model to INT8 for ~4x size reduction and faster CPU inference.

    Args:
        checkpoint_path: Path to the original (float) checkpoint file.
        output_path: Where the quantized checkpoint is written.
        model_type: 'resnet' or 'efficientnet'.

    Returns:
        output_path, for chaining.

    Raises:
        ValueError: For an unknown model_type, or a ResNet checkpoint that
            lacks the 'vocab' entry.

    Note: Slight accuracy loss (usually <1%).
    """
    print(f"Quantizing {model_type} model...")

    device = torch.device('cpu')  # Quantization typically done on CPU

    # Import classes before loading the checkpoint (required for unpickling).
    # resnet_train.py handles missing training data gracefully at import.
    if model_type == 'resnet':
        # Import the module itself so we can update its global vocab later.
        import resnet_train
        from resnet_train import EncoderCNN, DecoderRNN, Vocabulary

        # Make Vocabulary available in __main__ for unpickling — handles
        # checkpoints saved when Vocabulary lived in the __main__ module.
        import __main__
        if not hasattr(__main__, 'Vocabulary'):
            __main__.Vocabulary = Vocabulary
    elif model_type == 'efficientnet':
        from efficient_train import Encoder, Decoder, ImageCaptioningModel
        from transformers import AutoTokenizer
    else:
        # FIX: fail fast with a clear message instead of hitting an
        # UnboundLocalError on `quantized_checkpoint` further down.
        raise ValueError(f"Unknown model_type: {model_type!r} (expected 'resnet' or 'efficientnet')")

    # Load checkpoint (all classes are now importable for unpickling).
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    if model_type == 'resnet':
        # For ResNet, quantize encoder and decoder separately.

        # IMPORTANT: Update vocab from checkpoint BEFORE creating DecoderRNN —
        # DecoderRNN.__init__ sizes layers from len(resnet_train.vocab.word2idx).
        if 'vocab' in checkpoint and checkpoint['vocab'] is not None:
            resnet_train.vocab = checkpoint['vocab']
            print(f"  Updated vocab size: {len(checkpoint['vocab'].word2idx)}")
        else:
            raise ValueError("Checkpoint does not contain 'vocab' key. Cannot proceed.")

        encoder = EncoderCNN()
        decoder = DecoderRNN()  # Now uses the correct vocab size from checkpoint

        encoder.load_state_dict(checkpoint['encoder'])
        decoder.load_state_dict(checkpoint['decoder'])

        # Set to eval mode before quantization.
        encoder.eval()
        decoder.eval()

        # Quantize encoder (only Linear and Conv2d layers).
        encoder_quantized = torch.quantization.quantize_dynamic(
            encoder, {torch.nn.Linear, torch.nn.Conv2d}, dtype=torch.qint8
        )

        # Quantize decoder (only Linear layers — Embedding requires special
        # config and is typically too small to benefit).
        decoder_quantized = torch.quantization.quantize_dynamic(
            decoder, {torch.nn.Linear}, dtype=torch.qint8
        )

        quantized_checkpoint = {
            'encoder': encoder_quantized.state_dict(),
            'decoder': decoder_quantized.state_dict(),
            'vocab': checkpoint.get('vocab'),
            'quantized': True
        }

    else:  # efficientnet — classes already imported above
        tokenizer = AutoTokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
        special_tokens = {'additional_special_tokens': ['<start>', '<end>']}
        tokenizer.add_special_tokens(special_tokens)

        # Architecture hyperparameters must match training.
        encoder = Encoder(model_name='efficientnet_b3', embed_dim=512)
        decoder = Decoder(
            vocab_size=len(tokenizer),
            embed_dim=512,
            num_layers=8,
            num_heads=8,
            max_seq_length=64
        )
        model = ImageCaptioningModel(encoder, decoder)

        # Handle both {'model_state': ...} checkpoints and raw state dicts.
        if 'model_state' in checkpoint:
            model.load_state_dict(checkpoint['model_state'])
        else:
            model.load_state_dict(checkpoint)

        model.eval()

        # Quantize the full model in one pass.
        model_quantized = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear, torch.nn.Conv2d}, dtype=torch.qint8
        )

        quantized_checkpoint = {
            'model_state': model_quantized.state_dict(),
            'quantized': True
        }

    torch.save(quantized_checkpoint, output_path)

    # Report the size reduction.
    original_size = os.path.getsize(checkpoint_path) / (1024 * 1024)  # MB
    quantized_size = os.path.getsize(output_path) / (1024 * 1024)  # MB
    reduction = (1 - quantized_size / original_size) * 100

    print(f"✓ Quantization complete!")
    print(f"  Original size: {original_size:.2f} MB")
    print(f"  Quantized size: {quantized_size:.2f} MB")
    print(f"  Size reduction: {reduction:.1f}%")

    return output_path
140
+
141
+
142
def optimize_state_dict(checkpoint_path, output_path):
    """
    Strip training-only entries from a checkpoint to shrink it.

    Everything except optimizer/scheduler/epoch/loss/metrics is kept and
    re-saved with the zipfile serializer.
    """
    print(f"Optimizing state dict...")

    # Import classes before loading (required for unpickling); skip silently
    # when the training module is unavailable in this environment.
    try:
        from resnet_train import Vocabulary
        import __main__
        if not hasattr(__main__, 'Vocabulary'):
            __main__.Vocabulary = Vocabulary
    except ImportError:
        pass

    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    # Keep only inference-relevant entries.
    training_only = {'optimizer', 'scheduler', 'epoch', 'loss', 'metrics'}
    optimized = {key: value for key, value in checkpoint.items() if key not in training_only}

    # Save with highest compression (zipfile serializer).
    torch.save(optimized, output_path, _use_new_zipfile_serialization=True)

    mb = 1024 * 1024
    original_size = os.path.getsize(checkpoint_path) / mb
    optimized_size = os.path.getsize(output_path) / mb
    reduction = (1 - optimized_size / original_size) * 100

    print(f"✓ State dict optimized!")
    print(f"  Original: {original_size:.2f} MB")
    print(f"  Optimized: {optimized_size:.2f} MB")
    print(f"  Reduction: {reduction:.1f}%")

    return output_path
179
+
180
+
181
def create_torchscript(checkpoint_path, output_path, model_type='resnet'):
    """
    Convert model to TorchScript for faster loading and inference.

    Only the ENCODER is traced for both model types; the decoder's
    dynamic/recurrent inputs are not traceable here. The traced module is
    written to output_path with '.pth' replaced by '_encoder.pt'.

    Args:
        checkpoint_path: Path to the trained checkpoint.
        output_path: Base output path (suffix rewritten as noted above).
        model_type: 'resnet' or 'efficientnet'.

    Note: Requires example input for tracing.
    """
    print(f"Creating TorchScript model...")

    device = torch.device('cpu')

    # Import classes before loading (required for unpickling)
    if model_type == 'resnet':
        import resnet_train
        from resnet_train import EncoderCNN, DecoderRNN, Vocabulary

        # Make Vocabulary available in __main__ for unpickling — handles
        # checkpoints saved when Vocabulary lived in the __main__ module.
        import __main__
        if not hasattr(__main__, 'Vocabulary'):
            __main__.Vocabulary = Vocabulary
    elif model_type == 'efficientnet':
        from efficient_train import Encoder, Decoder, ImageCaptioningModel
        from transformers import AutoTokenizer

    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    if model_type == 'resnet':
        # Update vocab from checkpoint BEFORE creating DecoderRNN:
        # DecoderRNN sizes its layers from the module-level vocab.
        if 'vocab' in checkpoint and checkpoint['vocab'] is not None:
            resnet_train.vocab = checkpoint['vocab']
            print(f"  Updated vocab size: {len(checkpoint['vocab'].word2idx)}")
        else:
            raise ValueError("Checkpoint does not contain 'vocab' key. Cannot proceed.")

        encoder = EncoderCNN().eval()
        decoder = DecoderRNN().eval()  # Now uses the correct vocab size

        encoder.load_state_dict(checkpoint['encoder'])
        decoder.load_state_dict(checkpoint['decoder'])

        # Trace encoder with a dummy 224x224 RGB batch.
        dummy_image = torch.randn(1, 3, 224, 224)
        encoder_traced = torch.jit.trace(encoder, dummy_image)

        # For decoder, we need to trace with proper inputs
        # This is more complex due to RNN structure
        print("  ⚠ TorchScript for RNN decoder may require manual scripting")
        print("  ✓ Encoder traced successfully")

        torch.jit.save(encoder_traced, output_path.replace('.pth', '_encoder.pt'))

    elif model_type == 'efficientnet':
        # Classes already imported above; rebuild tokenizer to recover
        # the training-time vocab size for the Decoder constructor.
        tokenizer = AutoTokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
        special_tokens = {'additional_special_tokens': ['<start>', '<end>']}
        tokenizer.add_special_tokens(special_tokens)

        encoder = Encoder(model_name='efficientnet_b3', embed_dim=512)
        decoder = Decoder(
            vocab_size=len(tokenizer),
            embed_dim=512,
            num_layers=8,
            num_heads=8,
            max_seq_length=64
        )
        model = ImageCaptioningModel(encoder, decoder).eval()

        model.load_state_dict(checkpoint['model_state'])

        # Trace encoder only (decoder has dynamic inputs)
        dummy_image = torch.randn(1, 3, 224, 224)
        encoder_traced = torch.jit.trace(model.encoder, dummy_image)

        torch.jit.save(encoder_traced, output_path.replace('.pth', '_encoder.pt'))
        print("  ✓ Encoder traced successfully")

    print(f"✓ TorchScript saved to {output_path}")
    return output_path
258
+
259
+
260
def main():
    """CLI entry point: parse arguments and run the selected optimizations."""
    parser = argparse.ArgumentParser(description='Optimize models for production deployment')
    parser.add_argument('--model', type=str, choices=['resnet', 'efficientnet', 'both'],
                        default='both', help='Model to optimize')
    parser.add_argument('--method', type=str, choices=['quantize', 'optimize', 'torchscript', 'all'],
                        default='all', help='Optimization method')
    parser.add_argument('--resnet-path', type=str, default='resnet_best_model.pth',
                        help='Path to ResNet checkpoint')
    parser.add_argument('--efficientnet-path', type=str, default='efficient_best_model.pth',
                        help='Path to EfficientNet checkpoint')
    parser.add_argument('--output-dir', type=str, default='optimized_models',
                        help='Output directory for optimized models')
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Collect (type, path) pairs for every requested model that exists on disk.
    labels = {'resnet': 'ResNet', 'efficientnet': 'EfficientNet'}
    candidates = [('resnet', args.resnet_path), ('efficientnet', args.efficientnet_path)]
    models_to_process = []
    for kind, path in candidates:
        if args.model not in (kind, 'both'):
            continue
        if os.path.exists(path):
            models_to_process.append((kind, path))
        else:
            print(f"⚠ Warning: {path} not found, skipping {labels[kind]}")

    if not models_to_process:
        print("❌ No models found to optimize!")
        return

    banner = '=' * 60
    for model_type, model_path in models_to_process:
        print(f"\n{banner}")
        print(f"Processing {model_type.upper()} model")
        print(f"{banner}")

        # Output files are named <type>_<original stem>_<method suffix>.
        output_base = os.path.join(args.output_dir, f"{model_type}_{Path(model_path).stem}")

        if args.method in ('quantize', 'all'):
            quantize_model(model_path, f"{output_base}_quantized.pth", model_type)

        if args.method in ('optimize', 'all'):
            optimize_state_dict(model_path, f"{output_base}_optimized.pth")

        if args.method in ('torchscript', 'all'):
            create_torchscript(model_path, f"{output_base}_torchscript.pt", model_type)

    print(f"\n{banner}")
    print("✓ Optimization complete!")
    print(f"Optimized models saved to: {args.output_dir}")
    print(f"{banner}")
319
+
320
+
321
if __name__ == '__main__':
    # CLI entry point.
    main()
323
+
scripts/resnet_caption.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python
"""CLI for captioning one image with the ResNet encoder/decoder checkpoint."""
import argparse
import torch
from PIL import Image  # NOTE(review): unused in this file's visible code — presumably needed by resnet_train imports; confirm before removing
import nltk
# Ensure the NLTK 'punkt' tokenizer data is present (downloaded quietly if missing).
nltk.download('punkt', quiet=True)

# Import the necessary components from resnet_train.py
from resnet_train import EncoderCNN, DecoderRNN, visualize_attention, CONFIG, Vocabulary
import resnet_train  # To update its global vocab variable
11
+
12
def main():
    """Load a ResNet captioning checkpoint and print a caption for one image."""
    parser = argparse.ArgumentParser(description="Generate image caption from a trained model.")
    parser.add_argument("--image", type=str, required=True, help="Path to the input image")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to the trained model checkpoint")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load checkpoint (weights_only=False: it contains a pickled Vocabulary object).
    checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)

    # BUGFIX: update the global vocabulary BEFORE constructing DecoderRNN.
    # DecoderRNN.__init__ sizes its layers from resnet_train.vocab (see the
    # matching note in scripts/optimize_models.py), so building the decoder
    # first caused a state-dict size mismatch whenever the checkpoint's
    # vocab differed from the module-level default.
    resnet_train.vocab = checkpoint['vocab']

    # Initialize models (now sized against the checkpoint vocabulary).
    encoder = EncoderCNN().to(device)
    decoder = DecoderRNN().to(device)

    # Load state dictionaries
    encoder.load_state_dict(checkpoint['encoder'])
    decoder.load_state_dict(checkpoint['decoder'])

    # Generate caption using the provided image path
    caption = visualize_attention(args.image, encoder, decoder, device)
    print(caption)
37
+
38
if __name__ == "__main__":
    # CLI entry point.
    main()
static/css/custom.css ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/* Card container: rounded corners with a drop shadow. */
.card {
    border-radius: 1rem;
    box-shadow: 0 0.5rem 1rem rgba(0, 0, 0, 0.15);
}

/* Header matches the card radius so corners stay rounded. */
.card-header {
    border-top-left-radius: 1rem !important;
    border-top-right-radius: 1rem !important;
    background-color: var(--bs-dark);
}

/* Keep the uploaded-image preview bounded and undistorted. */
#previewImage {
    max-height: 400px;
    width: auto;
    object-fit: contain;
}

.form-check {
    margin-bottom: 0.5rem;
}

.alert {
    margin-bottom: 0;
}

.btn-primary {
    padding: 0.5rem 1.5rem;
}

/* Custom upload button styling:
   the real <input type="file"> is hidden; .upload-app is a styled label
   that forwards clicks to it, and the .file-selected class (added by
   main.js on file choice) drives the checkmark animation below. */
.upload-container {
    position: relative;
    width: 120px;
    height: 42px;
    margin: 0 auto;
}

.upload-container input[type="file"] {
    display: none;
}

.upload-app {
    display: block;
    position: relative;
    width: 120px;
    height: 42px;
    transition: 0.3s ease width;
    cursor: pointer;
}

.upload-btn {
    position: absolute;
    top: 0;
    right: 0;
    bottom: 0;
    left: 0;
    background-color: var(--bs-dark);
    border: 2px solid var(--bs-border-color);
    border-radius: 0.375rem;
    overflow: hidden;
}

/* "Upload" label rendered via a pseudo-element. */
.upload-btn:before {
    content: "Upload";
    position: absolute;
    top: 50%;
    left: 45%;
    transform: translate(-50%, -50%);
    color: var(--bs-body-color);
    font-size: 14px;
    font-weight: bold;
    transition: opacity 0.3s ease;
}

.file-selected .upload-btn:before {
    opacity: 0;
}

.upload-arrow {
    position: absolute;
    top: 0;
    right: 0;
    width: 38px;
    height: 38px;
    background-color: var(--bs-dark);
    transition: opacity 0.3s ease;
}

.file-selected .upload-arrow {
    opacity: 0;
}

/* Arrowhead drawn from two rotated bars. */
.upload-arrow:before,
.upload-arrow:after {
    content: "";
    position: absolute;
    top: 18px;
    width: 10px;
    height: 2px;
    background-color: var(--bs-body-color);
}

.upload-arrow:before {
    right: 17px;
    transform: rotateZ(-45deg);
}

.upload-arrow:after {
    right: 11px;
    transform: rotateZ(45deg);
}

/* Green success badge: hidden (scale 0) until .file-selected is applied. */
.upload-success {
    position: absolute;
    top: 50%;
    left: 50%;
    width: 24px;
    height: 24px;
    margin: 0;
    background-color: var(--bs-success);
    transform: translate(-50%, -50%) scale(0);
    border-radius: 50%;
    opacity: 0;
    transition: transform 0.3s ease, opacity 0.3s ease;
}

.upload-success i {
    font-size: 16px;
    color: #fff;
    position: absolute;
    top: 50%;
    left: 50%;
    transform: translate(-50%, -50%) scale(0);
    transition: transform 0.3s ease 0.1s;
}

.file-selected .upload-success {
    transform: translate(-50%, -50%) scale(1);
    opacity: 1;
}

.file-selected .upload-success i {
    transform: translate(-50%, -50%) scale(1);
}
static/js/main.js ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
// Client-side logic for the caption page: previews the chosen image,
// POSTs it (plus the selected model) to /upload, and renders the caption
// or an error message.
document.addEventListener('DOMContentLoaded', function() {
    const form = document.getElementById('uploadForm');
    const imageInput = document.getElementById('imageInput');
    const submitBtn = document.getElementById('submitBtn');
    const spinner = submitBtn.querySelector('.spinner-border');
    const resultSection = document.getElementById('resultSection');
    const previewImage = document.getElementById('previewImage');
    const captionText = document.getElementById('captionText');
    const errorAlert = document.getElementById('errorAlert');
    const uploadApp = document.querySelector('.upload-app');

    // Preview image when selected (local data URL — no server round-trip).
    imageInput.addEventListener('change', function(e) {
        const file = e.target.files[0];
        if (file) {
            const reader = new FileReader();
            reader.onload = function(e) {
                previewImage.src = e.target.result;
                resultSection.classList.remove('d-none');
                captionText.textContent = '';
                errorAlert.classList.add('d-none');

                // Add success animation class (drives the CSS checkmark).
                uploadApp.classList.add('file-selected');
            };
            reader.readAsDataURL(file);
        }
    });

    form.addEventListener('submit', async function(e) {
        e.preventDefault();

        const formData = new FormData();
        const file = imageInput.files[0];

        if (!file) {
            showError('Please select an image first.');
            return;
        }

        // Add file and selected model to form data
        formData.append('image', file);
        formData.append('model', document.querySelector('input[name="model"]:checked').value);

        // Show loading state
        setLoading(true);

        try {
            const response = await fetch('/upload', {
                method: 'POST',
                body: formData
            });

            const data = await response.json();

            if (!response.ok) {
                // Prefer the server-provided error message when present.
                throw new Error(data.error || 'Failed to generate caption');
            }

            // Display the caption
            captionText.textContent = data.caption;
            resultSection.classList.remove('d-none');
            errorAlert.classList.add('d-none');

        } catch (error) {
            showError(error.message || 'An error occurred while generating the caption');
        } finally {
            setLoading(false);
        }
    });

    // Toggle the busy state of the submit button.
    // NOTE(review): assigning textContent detaches the spinner element; the
    // retained `spinner` reference is re-attached via prepend() on the next
    // loading cycle, so the widget still works across toggles.
    function setLoading(isLoading) {
        submitBtn.disabled = isLoading;
        spinner.classList.toggle('d-none', !isLoading);
        submitBtn.textContent = isLoading ? ' Processing...' : 'Generate Caption';
        if (isLoading) {
            submitBtn.prepend(spinner);
        }
    }

    // Surface an error message in the alert box.
    function showError(message) {
        errorAlert.textContent = message;
        errorAlert.classList.remove('d-none');
    }
});
+ });
templates/index.html ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
<!DOCTYPE html>
<html lang="en" data-bs-theme="dark">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Image Caption Generator</title>
    <link href="https://cdn.replit.com/agent/bootstrap-agent-dark-theme.min.css" rel="stylesheet">
    <link href="https://cdn.jsdelivr.net/npm/bootstrap-icons@1.7.2/font/bootstrap-icons.css" rel="stylesheet">
    <link href="{{ url_for('static', filename='css/custom.css') }}" rel="stylesheet">
</head>
<body>
    <div class="container py-5">
        <div class="row justify-content-center">
            <div class="col-md-8">
                <div class="card">
                    <div class="card-header">
                        <h2 class="text-center mb-0">Image Caption Generator</h2>
                    </div>
                    <div class="card-body">
                        <!-- Upload form: model choice + hidden file input; submitted via fetch() in static/js/main.js -->
                        <form id="uploadForm">
                            <div class="mb-4">
                                <label class="form-label">Select Model:</label>
                                <div class="form-check">
                                    <input class="form-check-input" type="radio" name="model" id="efficientnet" value="efficientnet" checked>
                                    <label class="form-check-label" for="efficientnet">
                                        EfficientNet-B3
                                    </label>
                                </div>
                            </div>

                            <div class="mb-4">
                                <label class="form-label d-block text-center">Upload Image:</label>
                                <!-- Styled upload widget; the real input is hidden by custom.css -->
                                <div class="upload-container">
                                    <label class="upload-app">
                                        <input type="file" id="imageInput" accept="image/png,image/jpeg,image/jpg" required>
                                        <div class="upload-btn">
                                            <div class="upload-arrow"></div>
                                            <div class="upload-success">
                                                <i class="bi bi-check"></i>
                                            </div>
                                        </div>
                                    </label>
                                </div>
                            </div>

                            <div class="text-center">
                                <button type="submit" class="btn btn-primary" id="submitBtn">
                                    <span class="spinner-border spinner-border-sm d-none" role="status" aria-hidden="true"></span>
                                    Generate Caption
                                </button>
                            </div>
                        </form>

                        <!-- Populated by main.js after a successful /upload response -->
                        <div id="resultSection" class="mt-4 d-none">
                            <div class="text-center">
                                <img id="previewImage" class="img-fluid mb-3 rounded" alt="Uploaded image">
                                <div id="captionText" class="alert alert-info"></div>
                            </div>
                        </div>

                        <div id="errorAlert" class="alert alert-danger mt-3 d-none"></div>
                    </div>
                </div>
            </div>
        </div>
    </div>

    <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/js/bootstrap.bundle.min.js"></script>
    <script src="{{ url_for('static', filename='js/main.js') }}"></script>
</body>
</html>
training/__pycache__/efficient_train.cpython-314.pyc ADDED
Binary file (25.9 kB). View file
 
training/__pycache__/resnet_train.cpython-314.pyc ADDED
Binary file (33.4 kB). View file
 
training/efficient_train.py ADDED
@@ -0,0 +1,499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import random
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.optim as optim
8
+ from torch.utils.data import Dataset, DataLoader, random_split
9
+ from torchvision import transforms
10
+ from timm import create_model
11
+ from transformers import AutoTokenizer
12
+ from pycocotools.coco import COCO
13
+ from datetime import datetime
14
+ from PIL import Image
15
+
16
+ # Distributed training imports
17
+ import torch.distributed as dist
18
+ from torch.nn.parallel import DistributedDataParallel as DDP
19
+
20
+ # ------------------- DDP Setup Functions ------------------- #
21
def setup_distributed():
    """Join the default process group for DDP training (NCCL backend, GPU-only)."""
    dist.init_process_group(backend='nccl')
23
+
24
def cleanup_distributed():
    """Tear down the default process group after distributed training."""
    dist.destroy_process_group()
26
+
27
# ------------------- Configuration and Constants ------------------- #
# Shared defaults for dataset tokenization and model construction.
DEFAULT_MAX_SEQ_LENGTH = 64  # caption token budget, incl. <start>/<end>
DEFAULT_EMBED_DIM = 512      # encoder projection / decoder embedding width
DEFAULT_NUM_LAYERS = 8       # transformer decoder depth
DEFAULT_NUM_HEADS = 8        # attention heads per decoder layer
32
+
33
+ # ------------------- Data Preparation ------------------- #
34
class CocoCaptionDataset(Dataset):
    """Custom COCO dataset that returns image-caption pairs with processing.

    Each item is (image_tensor, token_ids) where token_ids is a fixed-length
    (max_seq_length) GPT-2 encoding of "<start> {caption} <end>", padded with
    the EOS token. A random caption is drawn per access, so repeated epochs
    see different captions for the same image.
    """
    def __init__(self, root, ann_file, transform=None, max_seq_length=DEFAULT_MAX_SEQ_LENGTH):
        self.coco = COCO(ann_file)
        self.root = root
        self.transform = transform
        self.max_seq_length = max_seq_length
        self.ids = list(self.coco.imgs.keys())

        # Initialize tokenizer with special tokens; GPT-2 has no pad token,
        # so EOS doubles as padding.
        self.tokenizer = AutoTokenizer.from_pretrained('gpt2')
        self.tokenizer.pad_token = self.tokenizer.eos_token
        special_tokens = {'additional_special_tokens': ['<start>', '<end>']}
        self.tokenizer.add_special_tokens(special_tokens)
        self.vocab_size = len(self.tokenizer)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        img_id = self.ids[idx]
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.root, img_info['file_name'])
        img = Image.open(img_path).convert('RGB')

        # Get random caption from available annotations
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        anns = self.coco.loadAnns(ann_ids)
        caption = random.choice(anns)['caption']

        # Apply transforms
        if self.transform:
            img = self.transform(img)

        # Tokenize caption with special tokens, padded/truncated to a fixed length.
        caption = f"<start> {caption} <end>"
        inputs = self.tokenizer(
            caption,
            padding='max_length',
            max_length=self.max_seq_length,
            truncation=True,
            return_tensors='pt',
        )
        return img, inputs.input_ids.squeeze(0)
78
+
79
class CocoTestDataset(Dataset):
    """COCO test-split dataset: images only, no caption annotations.

    Yields (image, filename) pairs so downstream code can match predictions
    back to their source files.
    """
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        # Sorted listing gives a stable ordering; every entry in `root`
        # is assumed to be an image file.
        self.img_files = sorted(os.listdir(root))

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        filename = self.img_files[idx]
        image = Image.open(os.path.join(self.root, filename)).convert('RGB')
        if self.transform:
            image = self.transform(image)
        # The filename travels with the tensor for later reference.
        return image, filename
97
+
98
+ # ------------------- Model Architecture ------------------- #
99
class Encoder(nn.Module):
    """CNN encoder using timm models.

    forward() returns a sequence of projected spatial features with shape
    (batch, H*W, embed_dim), usable as cross-attention memory for the
    transformer Decoder.
    """
    def __init__(self, model_name='efficientnet_b3', embed_dim=DEFAULT_EMBED_DIM):
        super().__init__()
        # num_classes=0 + global_pool='' keeps the raw spatial feature map
        # instead of a pooled classification vector.
        self.backbone = create_model(
            model_name,
            pretrained=True,
            num_classes=0,
            global_pool='',
            features_only=False
        )

        # Get output channels from backbone via a throwaway forward pass
        # (avoids hard-coding per-architecture channel counts).
        with torch.no_grad():
            dummy = torch.randn(1, 3, 224, 224)
            features = self.backbone(dummy)
            in_features = features.shape[1]

        self.projection = nn.Linear(in_features, embed_dim)

    def forward(self, x):
        features = self.backbone(x)  # (batch, channels, height, width)
        batch_size, channels, height, width = features.shape
        # Flatten the spatial grid into a token sequence: (batch, H*W, channels).
        features = features.permute(0, 2, 3, 1).reshape(batch_size, -1, channels)
        return self.projection(features)
124
+
125
class Decoder(nn.Module):
    """Transformer decoder with learned positional embeddings and causal masking.

    Args:
        vocab_size: output vocabulary size (logit dimension).
        embed_dim: token/positional embedding width (model dimension).
        num_layers: stacked TransformerDecoder layers.
        num_heads: attention heads (embed_dim must be divisible by this).
        max_seq_length: longest decodable sequence; sizes the positional
            table and the pre-built causal mask.
        dropout: dropout on embeddings and inside the decoder layers.

    Bug fix: ``forward`` previously accepted ``tgt_mask`` but silently
    ignored it, always using the internal causal mask. An explicitly
    passed mask is now honored; the causal mask remains the default.
    """

    def __init__(self, vocab_size, embed_dim, num_layers, num_heads, max_seq_length, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_encoding = nn.Embedding(max_seq_length, embed_dim)
        self.dropout = nn.Dropout(dropout)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dropout=dropout,
            batch_first=False  # we permute to (seq, batch, dim) in forward
        )
        self.layers = nn.TransformerDecoder(decoder_layer, num_layers)
        self.fc = nn.Linear(embed_dim, vocab_size)
        self.max_seq_length = max_seq_length

        # Upper-triangular -inf mask; a buffer so it follows .to(device).
        self.register_buffer(
            "causal_mask",
            torch.triu(torch.full((max_seq_length, max_seq_length), float('-inf')), diagonal=1)
        )

    def forward(self, x, memory, tgt_mask=None):
        """Decode token ids `x` (batch, seq) against encoder `memory`.

        Returns logits of shape (batch, seq, vocab_size). When `tgt_mask`
        is None the pre-computed causal mask is used.
        """
        seq_length = x.size(1)
        positions = torch.arange(0, seq_length, device=x.device).unsqueeze(0)
        x_emb = self.embedding(x) + self.positional_encoding(positions)
        x_emb = self.dropout(x_emb)

        # Reshape to the (seq, batch, features) transformer convention.
        x_emb = x_emb.permute(1, 0, 2)
        memory = memory.permute(1, 0, 2)

        # Honor an explicit mask; fall back to causal masking.
        mask = tgt_mask if tgt_mask is not None else self.causal_mask[:seq_length, :seq_length]
        output = self.layers(
            x_emb,
            memory,
            tgt_mask=mask
        )
        return self.fc(output.permute(1, 0, 2))
167
+
168
class ImageCaptioningModel(nn.Module):
    """Complete image captioning model: CNN encoder + transformer decoder."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, images, captions, tgt_mask=None):
        """Encode `images` into memory, then decode `captions` against it.

        Bug fix: `tgt_mask` used to be accepted but dropped; it is now
        forwarded to the decoder (which defaults to causal masking when
        it is None).
        """
        memory = self.encoder(images)
        return self.decoder(captions, memory, tgt_mask=tgt_mask)
178
+
179
+ # ------------------- Inference Utility ------------------- #
180
def generate_caption(model, image, tokenizer, device, max_length=DEFAULT_MAX_SEQ_LENGTH):
    """
    Generate a caption for a single image using greedy decoding.
    Assumes the tokenizer has '<start>' and '<end>' as special tokens.
    """
    model.eval()
    # Unwrap DDP once instead of branching at every call site.
    core = model.module if isinstance(model, DDP) else model
    with torch.no_grad():
        batched = image.unsqueeze(0)  # shape: (1, 3, H, W)
        memory = core.encoder(batched)

        start_id = tokenizer.convert_tokens_to_ids("<start>")
        end_id = tokenizer.convert_tokens_to_ids("<end>")
        generated = [start_id]
        # Greedily append tokens until <end> or the length budget is hit.
        while len(generated) < max_length:
            decoder_input = torch.tensor(generated, device=device).unsqueeze(0)
            logits = core.decoder(decoder_input, memory)
            best = logits[0, -1, :].argmax().item()
            generated.append(best)
            if best == end_id:
                break
    return tokenizer.decode(generated, skip_special_tokens=True)
208
+
209
+ # ------------------- Training Utilities ------------------- #
210
def create_dataloaders(args):
    """Build train/val/test dataloaders with appropriate transforms.

    Returns (train_loader, val_loader, test_loader, tokenizer, train_set);
    the training set is returned so distributed callers can reach its
    sampler.
    """
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Augmented pipeline for training, deterministic pipeline for eval.
    train_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    eval_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    # Datasets: captions for train/val, images only for test.
    train_set = CocoCaptionDataset(
        root=args.train_image_dir,
        ann_file=args.train_ann_file,
        transform=train_transform,
    )
    val_set = CocoCaptionDataset(
        root=args.val_image_dir,
        ann_file=args.val_ann_file,
        transform=eval_transform,
    )
    test_set = CocoTestDataset(root=args.test_image_dir, transform=eval_transform)

    # Distributed runs shard the training set via a sampler.
    train_sampler = (
        torch.utils.data.distributed.DistributedSampler(train_set)
        if args.distributed
        else None
    )

    # GPU runs get pinned memory, more workers, and persistent workers.
    on_gpu = torch.cuda.is_available()
    num_workers = 8 if on_gpu else 4

    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=train_sampler is None,
        sampler=train_sampler,
        num_workers=num_workers,
        pin_memory=on_gpu,
        persistent_workers=on_gpu,
        prefetch_factor=2 if num_workers > 0 else None,  # Prefetch batches
    )
    val_loader = DataLoader(
        val_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=on_gpu,
        persistent_workers=on_gpu,
    )
    # Inference processes one image at a time.
    test_loader = DataLoader(
        test_set,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
    )

    return train_loader, val_loader, test_loader, train_set.tokenizer, train_set
282
+
283
def train_epoch(model, loader, optimizer, criterion, scaler, scheduler, device, args):
    """Run one training epoch and return the mean (unscaled) batch loss.

    Fixes over the previous version:
      * gradients are no longer zeroed on every batch, so gradient
        accumulation (``args.grad_accum`` > 1) actually accumulates —
        previously the per-batch ``optimizer.zero_grad()`` discarded all
        gradients from non-stepping batches;
      * the loss used for backward is divided by ``args.grad_accum`` so
        the accumulated gradient matches the large-batch gradient;
      * a trailing partial accumulation window is flushed at epoch end
        instead of being dropped.
    The reported loss remains the unscaled per-batch value.
    """
    model.train()
    total_loss = 0.0
    if args.distributed:
        loader.sampler.set_epoch(args.epoch)

    optimizer.zero_grad()
    pending = False  # True while gradients are accumulated but not yet applied

    for batch_idx, (images, captions) in enumerate(loader):
        images = images.to(device)
        captions = captions.to(device)

        # Teacher forcing: use shifted captions as decoder input
        decoder_input = captions[:, :-1]
        targets = captions[:, 1:].contiguous()

        # Use new API for PyTorch 2.6+
        if hasattr(torch.amp, 'autocast'):
            autocast_context = torch.amp.autocast('cuda', enabled=args.use_amp)
        else:
            autocast_context = torch.cuda.amp.autocast(enabled=args.use_amp)

        with autocast_context:
            logits = model(images, decoder_input)
            loss = criterion(
                logits.view(-1, logits.size(-1)),
                targets.view(-1)
            )

        # Scale down so accumulated gradients average over the window.
        scaler.scale(loss / args.grad_accum).backward()
        pending = True

        if (batch_idx + 1) % args.grad_accum == 0:
            scaler.step(optimizer)
            scaler.update()
            # Only step scheduler if it's provided and supports per-step updates
            if scheduler is not None:
                scheduler.step()  # Update learning rate
            optimizer.zero_grad()
            pending = False

        total_loss += loss.item()

    # Flush a leftover partial window so its gradients are not dropped.
    if pending:
        scaler.step(optimizer)
        scaler.update()
        if scheduler is not None:
            scheduler.step()
        optimizer.zero_grad()

    return total_loss / len(loader)
323
+
324
def validate(model, loader, criterion, device):
    """Evaluate `model` on `loader`; return the mean batch loss."""
    model.eval()
    running = 0.0
    with torch.no_grad():
        for images, captions in loader:
            images, captions = images.to(device), captions.to(device)

            # Same teacher-forcing shift as in training.
            decoder_input, targets = captions[:, :-1], captions[:, 1:].contiguous()

            logits = model(images, decoder_input)
            batch_loss = criterion(
                logits.view(-1, logits.size(-1)),
                targets.view(-1)
            )
            running += batch_loss.item()

    return running / len(loader)
342
+
343
def main(args):
    """End-to-end training entry point: data, model, optimisation, test demo.

    Honors `args.distributed` (DDP over NCCL), mixed precision
    (`args.use_amp`), checkpoint resume (`args.resume_checkpoint`) and
    early stopping on validation loss.
    """
    if args.distributed:
        setup_distributed()

    # One CUDA device per local rank under DDP; otherwise best available.
    device = torch.device("cuda", args.local_rank) if args.distributed else torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Seed every RNG the pipeline touches for reproducibility.
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # Create dataloaders and obtain tokenizer and training dataset (for sampler)
    train_loader, val_loader, test_loader, tokenizer, train_set = create_dataloaders(args)

    # Initialize model
    encoder = Encoder(args.model_name, args.embed_dim)
    decoder = Decoder(
        vocab_size=tokenizer.vocab_size + 2,  # +2 for the added <start>/<end> specials
        embed_dim=args.embed_dim,
        num_layers=args.num_layers,
        num_heads=args.num_heads,
        max_seq_length=DEFAULT_MAX_SEQ_LENGTH,
        dropout=0.1
    )
    model = ImageCaptioningModel(encoder, decoder).to(device)

    if args.distributed:
        model = DDP(model, device_ids=[args.local_rank])

    # Set up training components
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
    # Use new API for PyTorch 2.6+
    if hasattr(torch.amp, 'GradScaler'):
        scaler = torch.amp.GradScaler('cuda', enabled=args.use_amp)
    else:
        scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
    # Per-step cosine decay spanning the whole run.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=args.epochs * len(train_loader),
        eta_min=1e-6
    )
    best_val_loss = float('inf')
    patience_counter = 0

    # Support resume training
    start_epoch = 0
    if args.resume_checkpoint:
        # Handle PyTorch 2.6+ security: allow tokenizer classes
        try:
            from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
            torch.serialization.add_safe_globals([GPT2TokenizerFast])
        except ImportError:
            pass

        # Load checkpoint (weights_only=False for backward compatibility with tokenizer)
        checkpoint = torch.load(args.resume_checkpoint, map_location=device, weights_only=False)
        if args.distributed:
            model.module.load_state_dict(checkpoint['model_state'])
        else:
            model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        # NOTE(review): 'scheduler_state' is saved below but never restored
        # here, so a resumed run restarts the cosine schedule — confirm intended.
        start_epoch = checkpoint['epoch'] + 1
        best_val_loss = checkpoint.get('val_loss', best_val_loss)
        print(f"Resumed training from epoch {start_epoch}")

    # Training loop
    for epoch in range(start_epoch, args.epochs):
        args.epoch = epoch  # Useful for the sampler in distributed training
        if args.distributed:
            train_loader.sampler.set_epoch(epoch)
        if args.local_rank == 0 or not args.distributed:
            print(f"Epoch {epoch+1}/{args.epochs}")
        train_loss = train_epoch(
            model, train_loader, optimizer, criterion, scaler, scheduler, device, args
        )
        val_loss = validate(model, val_loader, criterion, device)
        if args.local_rank == 0 or not args.distributed:
            print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        # Checkpointing: keep only the best-validation model.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state': model.module.state_dict() if args.distributed else model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'scheduler_state': scheduler.state_dict(),
                'val_loss': val_loss,
                'tokenizer': tokenizer,
            }, os.path.join(args.checkpoint_dir, 'best_model.pth'))
        else:
            patience_counter += 1

        if patience_counter >= args.early_stopping_patience:
            print("Early stopping triggered")
            break

    # Inference demo: caption the first few test images (rank 0 only).
    if args.local_rank == 0 or not args.distributed:
        print("\nGenerating captions on test set images:")
        model.eval()
        for idx, (image, filename) in enumerate(test_loader):
            image = image.to(device).squeeze(0)
            caption = generate_caption(model, image, tokenizer, device)
            # NOTE(review): `filename` is unpacked but not printed — the
            # literal "(unknown)" below looks like a lost f-string field.
            print(f"(unknown): {caption}")
            if idx >= 4:
                break

    if args.distributed:
        cleanup_distributed()
454
+
455
+
456
if __name__ == "__main__":
    # Command-line interface; all paths are required, everything else has
    # sensible defaults matching the constants above.
    parser = argparse.ArgumentParser()
    # Data arguments
    parser.add_argument('--train_image_dir', type=str, required=True)
    parser.add_argument('--train_ann_file', type=str, required=True)
    parser.add_argument('--val_image_dir', type=str, required=True)
    parser.add_argument('--val_ann_file', type=str, required=True)
    parser.add_argument('--test_image_dir', type=str, required=True)  # Test set images only

    # Model arguments
    parser.add_argument('--model_name', type=str, default='efficientnet_b3')
    parser.add_argument('--embed_dim', type=int, default=DEFAULT_EMBED_DIM)
    parser.add_argument('--num_layers', type=int, default=DEFAULT_NUM_LAYERS)
    parser.add_argument('--num_heads', type=int, default=DEFAULT_NUM_HEADS)

    # Training arguments
    parser.add_argument('--batch_size', type=int, default=96)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--use_amp', action='store_true')
    parser.add_argument('--grad_accum', type=int, default=1)
    parser.add_argument('--checkpoint_dir', type=str, default='/workspace')
    parser.add_argument('--early_stopping_patience', type=int, default=3)

    # Distributed training arguments
    # Accept both --local_rank and --local-rank (torchrun passes the dashed form)
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0,
                        help="Local rank. Necessary for using distributed training.")
    parser.add_argument('--distributed', action='store_true', help="Use distributed training")

    # Resume training argument
    parser.add_argument('--resume_checkpoint', type=str, default=None, help="Path to checkpoint to resume training from.")

    args = parser.parse_args()

    # Override local_rank from environment variable if set (torchrun convention)
    if "LOCAL_RANK" in os.environ:
        args.local_rank = int(os.environ["LOCAL_RANK"])

    # Create checkpoint directory
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    main(args)
training/hyperparameter_tuning.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hyperparameter Optimization using Optuna
3
+ Run this to find the best hyperparameters for your model
4
+ """
5
+
6
+ import optuna
7
+ import torch
8
+ import argparse
9
+ import os
10
+ import sys
11
+ from efficient_train import create_dataloaders, Encoder, Decoder, ImageCaptioningModel
12
+ from efficient_train import train_epoch, validate, generate_caption
13
+ import torch.nn as nn
14
+ import torch.optim as optim
15
+ from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
16
+
17
def train_with_config(trial, args):
    """Train a model with Optuna-suggested hyperparameters; return best val loss.

    Fixes over the previous version:
      * deprecated ``suggest_loguniform`` / ``suggest_uniform`` replaced by
        ``suggest_float(..., log=True)`` / ``suggest_float(...)``;
      * (embed_dim, num_heads) combinations where embed_dim is not divisible
        by num_heads would crash ``nn.TransformerDecoderLayer`` mid-trial —
        they are now pruned up-front;
      * ``ReduceLROnPlateau`` was passed into ``train_epoch``, which calls
        ``scheduler.step()`` without a metric (a TypeError); the plateau
        scheduler is now stepped only here, with the validation loss.
    """

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Suggest hyperparameters (modern Optuna API)
    lr = trial.suggest_float('lr', 1e-5, 1e-3, log=True)
    batch_size = trial.suggest_categorical('batch_size', [32, 64, 96, 128])
    embed_dim = trial.suggest_categorical('embed_dim', [256, 512, 768])
    num_layers = trial.suggest_int('num_layers', 4, 12)
    num_heads = trial.suggest_categorical('num_heads', [4, 8, 12, 16])
    dropout = trial.suggest_float('dropout', 0.1, 0.5)
    weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-2, log=True)
    # NOTE(review): warmup_epochs is suggested (kept for search-space
    # compatibility) but not used by the training loop below.
    warmup_epochs = trial.suggest_int('warmup_epochs', 0, 3)

    # d_model must be divisible by nhead for nn.TransformerDecoderLayer.
    if embed_dim % num_heads != 0:
        raise optuna.exceptions.TrialPruned(
            f"embed_dim={embed_dim} not divisible by num_heads={num_heads}"
        )

    # Update args with suggested values
    args.lr = lr
    args.batch_size = batch_size
    args.embed_dim = embed_dim
    args.num_layers = num_layers
    args.num_heads = num_heads
    args.epochs = 5  # Fewer epochs for hyperparameter search

    # Create dataloaders
    train_loader, val_loader, test_loader, tokenizer, train_set = create_dataloaders(args)

    # Initialize model
    encoder = Encoder(args.model_name, embed_dim)
    decoder = Decoder(
        vocab_size=tokenizer.vocab_size + 2,
        embed_dim=embed_dim,
        num_layers=num_layers,
        num_heads=num_heads,
        max_seq_length=64,
        dropout=dropout
    )
    model = ImageCaptioningModel(encoder, decoder).to(device)

    # Optimizer / scheduler / loss / AMP scaler
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
    scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)

    # Training loop (fewer epochs for hyperparameter search)
    best_val_loss = float('inf')

    for epoch in range(args.epochs):
        # scheduler=None: the plateau scheduler must only be stepped with a
        # metric, which happens after validation below.
        train_loss = train_epoch(model, train_loader, optimizer, criterion, scaler,
                                 None, device, args)

        # Validate
        val_loss = validate(model, val_loader, criterion, device)

        # Update scheduler with the metric it watches
        scheduler.step(val_loss)

        # Report to Optuna
        trial.report(val_loss, epoch)

        # Prune trial if not promising
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

        if val_loss < best_val_loss:
            best_val_loss = val_loss

    return best_val_loss
92
+
93
+
94
def objective(trial):
    """Optuna objective: build a fixed args namespace, run one trial.

    Bug fix: ``optuna.exceptions.TrialPruned`` inherits from ``Exception``
    and was being swallowed by the broad handler below, so pruning never
    reached Optuna. It is now re-raised; only genuine failures return inf.
    """

    # Create minimal args object
    args = argparse.Namespace(
        train_image_dir='Data/train2017/train2017',
        train_ann_file='Data/annotations_trainval2017/annotations/captions_train2017.json',
        val_image_dir='Data/val2017',
        val_ann_file='Data/annotations_trainval2017/annotations/captions_val2017.json',
        test_image_dir='Data/test2017/test2017',
        model_name='efficientnet_b3',
        embed_dim=512,      # Will be overridden
        num_layers=8,       # Will be overridden
        num_heads=8,        # Will be overridden
        batch_size=96,      # Will be overridden
        lr=3e-4,            # Will be overridden
        epochs=5,
        seed=42,
        use_amp=True,
        grad_accum=1,
        checkpoint_dir='checkpoints',
        early_stopping_patience=3,
        distributed=False,
        local_rank=0,
        resume_checkpoint=None
    )

    try:
        return train_with_config(trial, args)
    except optuna.exceptions.TrialPruned:
        raise  # let Optuna record the trial as pruned
    except Exception as e:
        print(f"Trial failed: {e}")
        return float('inf')
127
+
128
+
129
def main():
    """CLI driver: create/load the Optuna study, optimise, report and plot."""
    parser = argparse.ArgumentParser(description='Hyperparameter optimization with Optuna')
    parser.add_argument('--n_trials', type=int, default=50, help='Number of trials')
    parser.add_argument('--timeout', type=int, default=3600*24, help='Timeout in seconds')
    parser.add_argument('--study_name', type=str, default='efficientnet_captioning',
                        help='Study name')
    parser.add_argument('--storage', type=str, default='sqlite:///optuna_study.db',
                        help='Storage URL for study')

    args = parser.parse_args()

    # Create or load study (SQLite storage makes runs resumable)
    study = optuna.create_study(
        direction='minimize',
        study_name=args.study_name,
        storage=args.storage,
        load_if_exists=True,
        pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=3)
    )

    print(f"Starting optimization with {args.n_trials} trials...")
    print(f"Study: {args.study_name}")

    # Optimize
    study.optimize(objective, n_trials=args.n_trials, timeout=args.timeout)

    # Print results
    print("\n" + "="*60)
    print("Optimization Complete!")
    print("="*60)
    print(f"Best trial: {study.best_trial.number}")
    print(f"Best validation loss: {study.best_value:.4f}")
    print("\nBest parameters:")
    for key, value in study.best_params.items():
        print(f"  {key}: {value}")

    # Save results
    import json
    with open('best_hyperparameters.json', 'w') as f:
        json.dump(study.best_params, f, indent=2)

    print("\nBest hyperparameters saved to best_hyperparameters.json")

    # Visualize (optional, requires plotly)
    try:
        import optuna.visualization as vis

        # Optimization history
        fig = vis.plot_optimization_history(study)
        fig.write_image("optimization_history.png")
        print("Saved optimization_history.png")

        # Parameter importances
        fig = vis.plot_param_importances(study)
        fig.write_image("param_importances.png")
        print("Saved param_importances.png")

        # Parallel coordinate plot
        fig = vis.plot_parallel_coordinate(study)
        fig.write_image("parallel_coordinate.png")
        print("Saved parallel_coordinate.png")

    except ImportError:
        print("Install plotly to generate visualizations: pip install plotly")


if __name__ == '__main__':
    main()
197
+
training/resnet_train.py ADDED
@@ -0,0 +1,497 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import json
4
+ import torch
5
+ import nltk
6
+ import numpy as np
7
+ from PIL import Image
8
+ from pycocotools.coco import COCO
9
+ from torch.utils.data import Dataset, DataLoader
10
+ from torchvision import transforms
11
+ import torch.nn as nn
12
+ import torch.optim as optim
13
+ from collections import Counter
14
+ import matplotlib.pyplot as plt
15
+ from torchvision import models
16
+ from tqdm import tqdm
17
+ import torch.distributed as dist
18
+ import argparse
19
+
20
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
21
+ from nltk.translate.meteor_score import meteor_score
22
+
23
+ # Additional imports for extended metrics
24
+ from rouge import Rouge
25
+ from pycocoevalcap.cider.cider import Cider
26
+
27
+ nltk.download('punkt', quiet=True)
28
+ nltk.download('punkt_tab', quiet=True)
29
+ nltk.download('wordnet', quiet=True)
30
+
31
+ # ===========================
32
+ # CONFIGURATION
33
+ # ===========================
34
# Central configuration dictionary shared by the whole training script.
# NOTE(review): annotation paths are absolute Windows paths while image
# directories are relative — confirm they match the deployment layout.
CONFIG = {
    # Paths
    "train_ann": r"B:/!S3/Computer Vision/Project/annotations/captions_train2017.json",
    "val_ann": r"B:/!S3/Computer Vision/Project/annotations/captions_val2017.json",
    "train_img_dir": "images/train2017",
    "val_img_dir": "images/val2017",

    # Model
    "img_size": 224,            # input resolution fed to the CNN encoder
    "embed_size": 256,          # word-embedding width
    "hidden_size": 512,         # LSTM hidden-state width
    "attention_dim": 512,       # additive-attention projection width
    "feature_map_size": 14,  # From ResNet feature maps
    "dropout": 0.5,  # Dropout probability added

    # Training
    "batch_size": 176,
    "num_epochs": 30,
    "lr": 0.005,
    "fine_tune_encoder": True,  # False freezes the ResNet trunk
    "grad_clip": 5.0,

    # Vocabulary
    "vocab_threshold": 5,       # min word frequency to enter the vocabulary
    "max_len": 20,              # captions padded/truncated to this length

    # Beam search
    "beam_size": 3
}
+ }
63
+
64
+ # ===========================
65
+ # Vocabulary Builder
66
+ # ===========================
67
class Vocabulary:
    """Bidirectional word <-> index mapping built from COCO captions."""

    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def build(self, coco, threshold):
        """Count caption tokens; register the special tokens plus every
        word that occurs at least `threshold` times."""
        counter = Counter()
        for ann_id in tqdm(list(coco.anns.keys())):
            tokens = nltk.word_tokenize(coco.anns[ann_id]['caption'].lower())
            counter.update(tokens)
        # Special tokens first so they get stable, low indices.
        for special in ('<pad>', '<start>', '<end>', '<unk>'):
            self.add_word(special)
        # Then everything frequent enough.
        for word, cnt in counter.items():
            if cnt >= threshold:
                self.add_word(word)

    def add_word(self, word):
        """Register `word` if unseen; existing words keep their index."""
        if word in self.word2idx:
            return
        self.word2idx[word] = self.idx
        self.idx2word[self.idx] = word
        self.idx += 1
95
+
96
# Initialize vocab with full training data (only if training data exists).
# This allows the module to be imported for inference without training data.
vocab = Vocabulary()
# Always add special tokens (DecoderRNN sizes its layers from len(vocab)
# at construction time, so the vocab must never be empty).
vocab.add_word('<pad>')
vocab.add_word('<start>')
vocab.add_word('<end>')
vocab.add_word('<unk>')

if os.path.exists(CONFIG['train_ann']):
    try:
        # Building the vocabulary tokenises every training caption — slow
        # but done once at import.
        coco_train = COCO(CONFIG['train_ann'])
        vocab.build(coco_train, CONFIG['vocab_threshold'])
        print(f"Vocabulary size: {len(vocab.word2idx)}")
    except (FileNotFoundError, OSError) as e:
        # Training data not available - vocab will be loaded from checkpoint
        # Keep minimal vocab with special tokens for class definition
        print(f"Warning: Could not load training data. Vocabulary will be loaded from checkpoint.")
else:
    # Training data path doesn't exist - keep minimal vocab for inference
    print(f"Warning: Training data not found at {CONFIG['train_ann']}. Vocabulary will be loaded from checkpoint.")
117
+
118
+
119
+ # ===========================
120
+ # Attention-based Model
121
+ # ===========================
122
class EncoderCNN(nn.Module):
    """ResNet-50 trunk producing a flattened spatial feature grid.

    Output shape: (batch, feature_map_size**2, 2048).
    """

    def __init__(self):
        super().__init__()
        # torchvision's `weights` API replaces the deprecated 'pretrained' flag.
        from torchvision.models import resnet50, ResNet50_Weights
        backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        # Drop avgpool + fc: keep only the convolutional trunk.
        self.cnn = nn.Sequential(*list(backbone.children())[:-2])
        side = CONFIG['feature_map_size']
        self.adaptive_pool = nn.AdaptiveAvgPool2d((side, side))
        if not CONFIG['fine_tune_encoder']:
            # Freeze the trunk when fine-tuning is disabled.
            for param in self.cnn.parameters():
                param.requires_grad = False

    def forward(self, x):
        grid = self.adaptive_pool(self.cnn(x))   # (batch, 2048, 14, 14)
        grid = grid.permute(0, 2, 3, 1)          # (batch, 14, 14, 2048)
        # Merge the two spatial dims: (batch, 196, 2048)
        return grid.view(grid.size(0), -1, grid.size(-1))
142
+
143
class Attention(nn.Module):
    """Additive (Bahdanau-style) attention over the encoder feature grid."""

    def __init__(self):
        super().__init__()
        self.U = nn.Linear(CONFIG['hidden_size'], CONFIG['attention_dim'])
        self.W = nn.Linear(2048, CONFIG['attention_dim'])
        self.v = nn.Linear(CONFIG['attention_dim'], 1)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, features, hidden):
        """Return (context, alpha): the attention-weighted feature sum and
        the attention weights themselves."""
        # score_j = v . tanh(W s_j + U h)   for each spatial position j
        scores = self.v(self.tanh(self.W(features) + self.U(hidden).unsqueeze(1))).squeeze(2)
        alpha = self.softmax(scores)                          # (batch, positions)
        context = (features * alpha.unsqueeze(2)).sum(dim=1)  # (batch, 2048)
        return context, alpha
160
+
161
class DecoderRNN(nn.Module):
    """LSTM decoder with additive attention over encoder features.

    NOTE(review): the embedding and output layers are sized from the
    module-level `vocab` at construction time — the vocabulary must be
    fully built (or loaded) before instantiating this class.
    """
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(len(vocab.word2idx), CONFIG['embed_size'])
        # LSTM input = word embedding concatenated with the 2048-dim
        # attention context vector.
        self.lstm = nn.LSTM(CONFIG['embed_size'] + 2048,
                            CONFIG['hidden_size'], batch_first=True)
        self.attention = Attention()
        self.fc = nn.Linear(CONFIG['hidden_size'], len(vocab.word2idx))
        self.dropout = nn.Dropout(p=CONFIG['dropout'])

    def forward(self, features, captions, teacher_forcing_ratio=0.5):
        """Decode step-by-step; returns logits (batch, seq_len-1, vocab).

        With probability `teacher_forcing_ratio` the ground-truth next
        token is fed at each step; otherwise the model's own argmax
        prediction. NOTE(review): the coin flip uses np.random at every
        step, so results are only reproducible with a seeded NumPy RNG;
        a ratio of 0 (used in evaluation) is fully deterministic.
        """
        batch_size = features.size(0)
        h, c = self.init_hidden(features)
        seq_length = captions.size(1) - 1  # predict tokens 1..T from 0..T-1
        outputs = torch.zeros(batch_size, seq_length, len(vocab.word2idx)).to(features.device)
        embeddings = self.dropout(self.embed(captions[:, 0]))
        for t in range(seq_length):
            # Attend over spatial features conditioned on the LSTM state.
            context, alpha = self.attention(features, h.squeeze(0))
            lstm_input = torch.cat([embeddings, context], dim=1).unsqueeze(1)
            out, (h, c) = self.lstm(lstm_input, (h, c))
            out = self.dropout(out)
            output = self.fc(out.squeeze(1))
            outputs[:, t] = output
            use_teacher_forcing = np.random.random() < teacher_forcing_ratio
            if use_teacher_forcing and t < seq_length - 1:
                embeddings = self.dropout(self.embed(captions[:, t+1]))
            else:
                embeddings = self.dropout(self.embed(output.argmax(dim=-1)))
        return outputs

    def init_hidden(self, features):
        """Zero-initialised (h, c) for the single-layer LSTM."""
        h = torch.zeros(1, features.size(0), CONFIG['hidden_size']).to(features.device)
        c = torch.zeros(1, features.size(0), CONFIG['hidden_size']).to(features.device)
        return h, c
195
+
196
+ # ===========================
197
+ # Enhanced Dataset Class
198
+ # ===========================
199
class CocoDataset(Dataset):
    """COCO caption dataset: one sample per annotation (image, caption ids).

    On construction every annotation is checked against the image
    directory and annotations whose image file is missing are dropped.
    NOTE(review): this check calls loadImgs per annotation and prints per
    missing file — O(#annotations) and noisy for partial downloads.
    """
    def __init__(self, ann_file, img_dir, vocab, transform=None):
        self.coco = COCO(ann_file)
        self.img_dir = img_dir
        self.vocab = vocab
        self.transform = transform or self.default_transform()
        all_ids = list(self.coco.anns.keys())
        valid_ids = []
        for ann_id in all_ids:
            ann = self.coco.anns[ann_id]
            img_id = ann['image_id']
            file_name = self.coco.loadImgs(img_id)[0]['file_name']
            img_path = os.path.join(self.img_dir, file_name)
            if os.path.exists(img_path):
                valid_ids.append(ann_id)
            else:
                print(f"Warning: File {img_path} not found. Skipping annotation id {ann_id}.")
        self.ids = valid_ids

    def default_transform(self):
        # Standard ImageNet preprocessing at CONFIG['img_size'].
        return transforms.Compose([
            transforms.Resize((CONFIG['img_size'], CONFIG['img_size'])),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        """Return (image_tensor, caption_ids tensor of length CONFIG['max_len'])."""
        ann_id = self.ids[idx]
        ann = self.coco.anns[ann_id]
        img_id = ann['image_id']
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.img_dir, img_info['file_name'])
        img = Image.open(img_path).convert('RGB')
        img = self.transform(img)
        caption = ann['caption']
        tokens = ['<start>'] + nltk.word_tokenize(caption.lower()) + ['<end>']
        caption_ids = [self.vocab.word2idx.get(token, self.vocab.word2idx['<unk>']) for token in tokens]
        # Pad to max_len, then truncate.
        # NOTE(review): captions longer than max_len lose their '<end>'
        # token after truncation — confirm downstream eval tolerates this.
        caption_ids += [self.vocab.word2idx['<pad>']] * (CONFIG['max_len'] - len(caption_ids))
        caption_ids = caption_ids[:CONFIG['max_len']]
        return img, torch.tensor(caption_ids)
243
+
244
+ # ===========================
245
+ # Distributed Setup Functions
246
+ # ===========================
247
def setup_distributed():
    """Join the default process group (NCCL backend, GPU-only).

    Expects the launcher (torchrun) to have set MASTER_ADDR/MASTER_PORT,
    RANK and WORLD_SIZE in the environment.
    """
    dist.init_process_group(backend='nccl')
249
+
250
def cleanup_distributed():
    """Tear down the default process group at the end of training."""
    dist.destroy_process_group()
252
+
253
+ # ===========================
254
+ # Training & Evaluation
255
+ # ===========================
256
def evaluate(encoder, decoder, loader, device, criterion, compute_extended=False):
    """Compute mean teacher-forcing-free validation loss over `loader`.

    With compute_extended=True, additionally beam-search a caption per sample and
    score BLEU / METEOR / ROUGE-1 / ROUGE-L / CIDEr against the single reference.

    NOTE(review): the return shape is inconsistent -- (loss, metrics) when
    compute_extended=True, a bare loss otherwise; callers must branch on it.
    Relies on module-level `vocab` and on beam_search (CONFIG-driven).
    """
    encoder.eval()
    decoder.eval()
    total_loss = 0
    # Instantiate smoothing function for BLEU score.
    smoothing_fn = SmoothingFunction().method1
    if compute_extended:
        bleu_scores = []
        meteor_scores = []
        rouge = Rouge()
        rouge1_scores = []
        rougeL_scores = []
        cider_scorer = Cider()
        # CIDEr wants {sample_id: [caption string]} maps for refs and hypotheses.
        ref_dict = {}
        hyp_dict = {}
        sample_id = 0
        with torch.no_grad():
            for imgs, caps in loader:
                imgs = imgs.to(device)
                caps = caps.to(device)
                features = encoder(imgs)
                # teacher_forcing_ratio=0: decode from the model's own predictions.
                outputs = decoder(features, caps, teacher_forcing_ratio=0)
                loss = criterion(outputs.view(-1, len(vocab.word2idx)), caps[:, 1:].reshape(-1))
                total_loss += loss.item()
                # Per-sample caption metrics (beam search is per-image, hence the loop).
                for i in range(imgs.size(0)):
                    predicted_ids = beam_search(features[i].unsqueeze(0), decoder, device)
                    # Strip special tokens before scoring.
                    predicted_caption = [vocab.idx2word[idx] for idx in predicted_ids
                                         if idx not in [vocab.word2idx['<start>'], vocab.word2idx['<end>'], vocab.word2idx['<pad>']]]
                    reference_ids = caps[i].tolist()
                    reference_caption = [vocab.idx2word[idx] for idx in reference_ids
                                         if idx not in [vocab.word2idx['<start>'], vocab.word2idx['<end>'], vocab.word2idx['<pad>']]]
                    bleu = sentence_bleu([reference_caption], predicted_caption, smoothing_function=smoothing_fn)
                    bleu_scores.append(bleu)
                    meteor = meteor_score([reference_caption], predicted_caption)
                    meteor_scores.append(meteor)
                    pred_str = " ".join(predicted_caption)
                    ref_str = " ".join(reference_caption)
                    rouge_scores = rouge.get_scores(pred_str, ref_str)
                    rouge1_scores.append(rouge_scores[0]['rouge-1']['f'])
                    rougeL_scores.append(rouge_scores[0]['rouge-l']['f'])
                    ref_dict[sample_id] = [ref_str]
                    hyp_dict[sample_id] = [pred_str]
                    sample_id += 1
        # Guard against an empty loader when averaging the per-sample scores.
        avg_bleu = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0
        avg_meteor = sum(meteor_scores) / len(meteor_scores) if meteor_scores else 0
        avg_rouge1 = sum(rouge1_scores) / len(rouge1_scores) if rouge1_scores else 0
        avg_rougeL = sum(rougeL_scores) / len(rougeL_scores) if rougeL_scores else 0
        cider_score, _ = cider_scorer.compute_score(ref_dict, hyp_dict)
        metrics = {'BLEU': avg_bleu, 'METEOR': avg_meteor,
                   'ROUGE-1': avg_rouge1, 'ROUGE-L': avg_rougeL, 'CIDEr': cider_score}
        # Only rank 0 prints in the distributed case.
        if dist.is_initialized() and dist.get_rank() == 0:
            print(f"Extended Metrics: {metrics}")
        return total_loss / len(loader), metrics
    else:
        # Fast path: loss only, no per-sample decoding or metrics.
        with torch.no_grad():
            for imgs, caps in loader:
                imgs = imgs.to(device)
                caps = caps.to(device)
                features = encoder(imgs)
                outputs = decoder(features, caps, teacher_forcing_ratio=0)
                loss = criterion(outputs.view(-1, len(vocab.word2idx)), caps[:, 1:].reshape(-1))
                total_loss += loss.item()
        return total_loss / len(loader)
319
+
320
def beam_search(features, decoder, device):
    """Beam-search decode one image's encoder features into a token-id sequence.

    Uses CONFIG['beam_size'] beams and ranks candidates by length-normalised
    cumulative log-probability. `features` is expected to carry a batch
    dimension of 1 (callers pass features[i].unsqueeze(0)).
    Relies on module-level `vocab` and CONFIG.
    """
    k = CONFIG['beam_size']
    start_token = vocab.word2idx['<start>']
    # Unwrap DDP (decoder.module) to reach the raw decoder's attributes.
    h, c = (decoder.module.init_hidden(features) if isinstance(decoder, torch.nn.parallel.DistributedDataParallel)
            else decoder.init_hidden(features))
    # Each beam entry: [token list, cumulative log-prob, hidden state, cell state].
    sequences = [[[start_token], 0.0, h, c]]
    for _ in range(CONFIG['max_len'] - 1):
        all_candidates = []
        for seq in sequences:
            tokens, score, h, c = seq
            # Finished beams are carried over unchanged.
            if tokens[-1] == vocab.word2idx['<end>']:
                all_candidates.append(seq)
                continue
            input_tensor = torch.LongTensor([tokens[-1]]).to(device)
            # One attention + LSTM + projection step; duplicated for the DDP wrapper.
            if isinstance(decoder, torch.nn.parallel.DistributedDataParallel):
                context, _ = decoder.module.attention(features, h.squeeze(0))
                emb = decoder.module.embed(input_tensor)
                lstm_input = torch.cat([emb, context], dim=1).unsqueeze(1)
                out, (h, c) = decoder.module.lstm(lstm_input, (h, c))
                output = decoder.module.fc(out.squeeze(1))
            else:
                context, _ = decoder.attention(features, h.squeeze(0))
                emb = decoder.embed(input_tensor)
                lstm_input = torch.cat([emb, context], dim=1).unsqueeze(1)
                out, (h, c) = decoder.lstm(lstm_input, (h, c))
                output = decoder.fc(out.squeeze(1))
            log_probs = torch.log_softmax(output, dim=1)
            top_probs, top_indices = log_probs.topk(k)
            # Expand this beam into its k best continuations; all k children
            # share the (h, c) produced by this parent's LSTM step.
            for i in range(k):
                token = top_indices[0][i].item()
                new_score = score + top_probs[0][i].item()
                new_seq = tokens + [token]
                all_candidates.append([new_seq, new_score, h, c])
        # Length-normalised ranking avoids a bias toward short captions.
        ordered = sorted(all_candidates, key=lambda x: x[1] / len(x[0]), reverse=True)
        sequences = ordered[:k]
    return sequences[0][0]
356
+
357
def visualize_attention(image_path, encoder, decoder, device):
    """Generate a caption string for a single image file.

    NOTE(review): despite the name, no attention map is computed or rendered
    here -- the function only returns the beam-searched caption text.
    """
    img = Image.open(image_path).convert('RGB')
    # Same ImageNet preprocessing as CocoDataset.default_transform.
    transform = transforms.Compose([
        transforms.Resize((CONFIG['img_size'], CONFIG['img_size'])),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    img_tensor = transform(img).unsqueeze(0).to(device)
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        features = encoder(img_tensor)
        caption_ids = beam_search(features, decoder, device)
    # Strip special tokens before joining into the final caption.
    caption = [vocab.idx2word[idx] for idx in caption_ids
               if idx not in [vocab.word2idx['<start>'], vocab.word2idx['<end>'], vocab.word2idx['<pad>']]]
    return ' '.join(caption)
373
+
374
def train(distributed=False, local_rank=0, device=torch.device('cpu'), resume_checkpoint=None):
    """Train the encoder/decoder captioning model, optionally under DDP.

    Saves a checkpoint whenever validation loss improves, logs extended metrics
    every 5th epoch, and early-stops after 3 epochs without improvement.

    NOTE(review): in the distributed case only rank 0 updates the early-stop
    counter and executes `break`; other ranks keep looping, which can deadlock
    collective ops. Confirm intended behaviour before multi-node use.
    """
    train_set = CocoDataset(CONFIG['train_ann'], CONFIG['train_img_dir'], vocab)
    val_set = CocoDataset(CONFIG['val_ann'], CONFIG['val_img_dir'], vocab)
    # Distributed samplers shard the data across ranks; None falls back to shuffle.
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_set) if distributed else None
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_set, shuffle=False) if distributed else None
    train_loader = DataLoader(train_set,
                              batch_size=CONFIG['batch_size'],
                              shuffle=(train_sampler is None),
                              sampler=train_sampler,
                              num_workers=8)
    val_loader = DataLoader(val_set,
                            batch_size=CONFIG['batch_size'],
                            sampler=val_sampler,
                            num_workers=8)
    encoder = EncoderCNN().to(device)
    decoder = DecoderRNN().to(device)
    if distributed:
        encoder = torch.nn.parallel.DistributedDataParallel(encoder, device_ids=[local_rank], output_device=local_rank)
        decoder = torch.nn.parallel.DistributedDataParallel(decoder, device_ids=[local_rank], output_device=local_rank)
    # Padding positions are excluded from the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx['<pad>'])
    # Optionally fine-tune the encoder alongside the decoder.
    if CONFIG['fine_tune_encoder']:
        params = list(decoder.parameters()) + list(encoder.parameters())
    else:
        params = list(decoder.parameters())
    optimizer = optim.Adam(params, lr=CONFIG['lr'])
    # Initialize training state variables
    start_epoch = 0
    best_val_loss = float('inf')
    epochs_without_improvement = 0
    # Resume from checkpoint if provided
    if resume_checkpoint is not None:
        print(f"Loading checkpoint from {resume_checkpoint}")
        # Allow Vocabulary as a safe global so it can be unpickled
        torch.serialization.add_safe_globals([Vocabulary])
        # weights_only=False: checkpoint contains pickled objects (vocab, config).
        checkpoint = torch.load(resume_checkpoint, map_location=device, weights_only=False)
        encoder.load_state_dict(checkpoint['encoder'])
        decoder.load_state_dict(checkpoint['decoder'])
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            print("Warning: 'optimizer' state not found in checkpoint. Starting with fresh optimizer state.")
        start_epoch = checkpoint['epoch'] + 1
        best_val_loss = checkpoint.get('best_val_loss', float('inf'))
        epochs_without_improvement = checkpoint.get('epochs_without_improvement', 0)
        print(f"Resumed training from epoch {start_epoch}")
    for epoch in range(start_epoch, CONFIG['num_epochs']):
        if distributed:
            # Reshuffle shards each epoch so every rank sees a fresh ordering.
            train_sampler.set_epoch(epoch)
        encoder.train()
        decoder.train()
        total_loss = 0
        for imgs, caps in tqdm(train_loader):
            imgs = imgs.to(device)
            caps = caps.to(device)
            optimizer.zero_grad()
            features = encoder(imgs)
            outputs = decoder(features, caps)
            # Targets are the caption shifted left by one (drop '<start>').
            loss = criterion(outputs.view(-1, len(vocab.word2idx)),
                             caps[:, 1:].reshape(-1))
            loss.backward()
            if CONFIG['grad_clip'] is not None:
                nn.utils.clip_grad_norm_(decoder.parameters(), CONFIG['grad_clip'])
            optimizer.step()
            total_loss += loss.item()
        # Every 5th epoch: full metric suite; otherwise loss-only validation.
        if epoch % 5 == 0:
            val_loss, metrics = evaluate(encoder, decoder, val_loader, device, criterion, compute_extended=True)
            if local_rank == 0:
                print(f"Epoch {epoch+1}/{CONFIG['num_epochs']} | Train Loss: {total_loss/len(train_loader):.4f} | Val Loss: {val_loss:.4f}")
                with open("metrics_log_Resnet.txt", "a") as f:
                    f.write(f"Epoch {epoch+1}: {metrics}\n")
        else:
            val_loss = evaluate(encoder, decoder, val_loader, device, criterion, compute_extended=False)
            if local_rank == 0:
                print(f"Epoch {epoch+1}/{CONFIG['num_epochs']} | Train Loss: {total_loss/len(train_loader):.4f} | Val Loss: {val_loss:.4f}")
        # Checkpointing and early stopping are rank-0-only (see NOTE above).
        if local_rank == 0:
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_without_improvement = 0
                checkpoint_path = f'caption_model_best_epoch{epoch}.pth'
                torch.save({
                    'epoch': epoch,
                    'encoder': encoder.state_dict(),
                    'decoder': decoder.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_val_loss': best_val_loss,
                    'epochs_without_improvement': epochs_without_improvement,
                    'vocab': vocab,
                    'config': CONFIG
                }, checkpoint_path)
                #upload_files(epoch)
            else:
                epochs_without_improvement += 1
                if epochs_without_improvement >= 3:
                    print("Early stopping triggered.")
                    break
469
+
470
def upload_files(i):
    """Copy the epoch-`i` best checkpoint and the metrics log to OneDrive via rclone."""
    artifacts = [f"caption_model_best_epoch{i}.pth", "metrics_log_Resnet.txt"]
    for file in artifacts:
        # shell=False list form: filenames are passed as literal arguments.
        outcome = subprocess.run(
            ["rclone", "copy", file, "onedrive:/Computer_Viz/"],
            capture_output=True, text=True
        )
        if outcome.returncode != 0:
            print(f"Error during upload of {file}:", outcome.stderr)
        else:
            print(f"{file} uploaded successfully.")
481
+
482
if __name__ == '__main__':
    # CLI entry point: optional distributed (torchrun) training with resume support.
    parser = argparse.ArgumentParser()
    parser.add_argument("--distributed", action="store_true", help="Enable distributed training")
    parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint to resume training")
    args = parser.parse_args()
    if args.distributed:
        setup_distributed()
        # LOCAL_RANK is injected into the environment by torchrun.
        local_rank = int(os.environ['LOCAL_RANK'])
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        local_rank = 0
    train(distributed=args.distributed, local_rank=local_rank, device=device, resume_checkpoint=args.resume)
    if args.distributed:
        cleanup_distributed()
training/train_advanced.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Advanced Training Script with Best Practices
3
+ - Learning rate scheduling
4
+ - Mixed precision training
5
+ - Experiment tracking (W&B optional)
6
+ - Comprehensive evaluation
7
+ - Model checkpointing
8
+ """
9
+
10
+ import argparse
11
+ import os
12
+ import random
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.optim as optim
17
+ from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau, LambdaLR
18
+ import math
19
+ from efficient_train import (
20
+ create_dataloaders, Encoder, Decoder, ImageCaptioningModel,
21
+ train_epoch, validate, generate_caption
22
+ )
23
+ from datetime import datetime
24
+
25
+ # Optional: Weights & Biases
26
+ try:
27
+ import wandb
28
+ WANDB_AVAILABLE = True
29
+ except ImportError:
30
+ WANDB_AVAILABLE = False
31
+ print("W&B not available. Install with: pip install wandb")
32
+
33
+
34
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Build a LambdaLR: linear warmup from 0 to the base LR, then cosine decay to 0.

    The returned scheduler multiplies each param group's base LR by a factor in
    [0, 1] that ramps linearly over `num_warmup_steps` and then follows half a
    cosine over the remaining `num_training_steps - num_warmup_steps` steps.
    """
    def scale(step):
        # Linear ramp: step / warmup (guarded against warmup == 0).
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        # Cosine half-wave from 1 down to 0, clamped at 0 past the horizon.
        span = float(max(1, num_training_steps - num_warmup_steps))
        frac = float(step - num_warmup_steps) / span
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * frac)))

    return LambdaLR(optimizer, scale)
42
+
43
+
44
def train_advanced(args):
    """Advanced training with all best practices.

    Runs the full train/validate loop with seeding, optional W&B logging,
    per-component learning rates, a choice of LR schedulers, AMP, best-model
    and periodic checkpointing, and early stopping.

    NOTE(review): checkpoints save 'optimizer_state' and 'scheduler_state',
    but the resume path below restores only 'model_state' -- optimizer and
    scheduler restart fresh after a resume. Confirm whether that is intended.
    """

    # Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # GPU optimizations
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True  # Optimize for consistent input sizes
        torch.backends.cudnn.deterministic = False  # Faster, but non-deterministic
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    # Initialize W&B
    if args.use_wandb and WANDB_AVAILABLE:
        wandb.init(
            project=args.wandb_project,
            name=f"{args.model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            config=vars(args)
        )

    # Create dataloaders
    train_loader, val_loader, test_loader, tokenizer, train_set = create_dataloaders(args)

    # Initialize model
    encoder = Encoder(args.model_name, args.embed_dim)
    decoder = Decoder(
        # +2: room for the extra special tokens added on top of the GPT-2 vocab.
        vocab_size=tokenizer.vocab_size + 2,
        embed_dim=args.embed_dim,
        num_layers=args.num_layers,
        num_heads=args.num_heads,
        max_seq_length=64,
        dropout=args.dropout
    )
    model = ImageCaptioningModel(encoder, decoder).to(device)

    # Resume from checkpoint if provided
    start_epoch = 0
    best_val_loss = float('inf')
    best_metrics = {}  # NOTE(review): never written or read below.

    if args.resume_checkpoint:
        print(f"Loading checkpoint from {args.resume_checkpoint}")
        # Handle PyTorch 2.6+ security: allow tokenizer classes
        try:
            from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
            torch.serialization.add_safe_globals([GPT2TokenizerFast])
        except ImportError:
            pass

        # weights_only=False: checkpoint embeds the tokenizer and config objects.
        checkpoint = torch.load(args.resume_checkpoint, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state'])
        start_epoch = checkpoint.get('epoch', 0) + 1
        best_val_loss = checkpoint.get('val_loss', float('inf'))
        print(f"Resumed from epoch {start_epoch}, best val loss: {best_val_loss:.4f}")

    # Optimizer with different learning rates for encoder/decoder
    encoder_params = [p for n, p in model.named_parameters() if 'encoder' in n]
    decoder_params = [p for n, p in model.named_parameters() if 'decoder' in n]

    if args.different_lr:
        # Lower learning rate for encoder (fine-tuning)
        optimizer = optim.AdamW([
            {'params': encoder_params, 'lr': args.lr * 0.1},
            {'params': decoder_params, 'lr': args.lr}
        ], weight_decay=args.weight_decay)
    else:
        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Learning rate scheduler
    if args.scheduler == 'cosine':
        # Per-step annealing over the whole run.
        scheduler = CosineAnnealingLR(
            optimizer,
            T_max=args.epochs * len(train_loader),
            eta_min=args.min_lr
        )
    elif args.scheduler == 'plateau':
        # Per-epoch, driven by validation loss (stepped later with val_loss).
        scheduler = ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=args.patience
        )
    elif args.scheduler == 'warmup_cosine':
        num_training_steps = args.epochs * len(train_loader)
        num_warmup_steps = args.warmup_epochs * len(train_loader)
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps, num_training_steps
        )
    else:
        scheduler = None

    # Loss function
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    # Mixed precision training - Use new API for PyTorch 2.6+
    if hasattr(torch.amp, 'GradScaler'):
        scaler = torch.amp.GradScaler('cuda', enabled=args.use_amp)
    else:
        scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)

    # Create checkpoint directory
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    # Training loop
    patience_counter = 0

    for epoch in range(start_epoch, args.epochs):
        args.epoch = epoch  # Set epoch for train_epoch function
        print(f"\nEpoch {epoch+1}/{args.epochs}")
        print("-" * 60)

        # Train
        # Step-wise schedulers are advanced inside train_epoch; plateau is not.
        train_loss = train_epoch(
            model, train_loader, optimizer, criterion, scaler,
            scheduler if args.scheduler == 'cosine' or args.scheduler == 'warmup_cosine' else None,
            device, args
        )

        # Validate
        val_loss = validate(model, val_loader, criterion, device)

        # Update scheduler
        if args.scheduler == 'plateau':
            scheduler.step(val_loss)
        elif args.scheduler in ['cosine', 'warmup_cosine']:
            # Already updated in train_epoch
            pass

        current_lr = optimizer.param_groups[0]['lr']

        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | LR: {current_lr:.6f}")

        # Log to W&B
        log_dict = {
            'epoch': epoch,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'learning_rate': current_lr
        }

        if args.use_wandb and WANDB_AVAILABLE:
            wandb.log(log_dict)

        # Checkpointing
        is_best = val_loss < best_val_loss

        if is_best:
            best_val_loss = val_loss
            patience_counter = 0

            # Save best model
            checkpoint = {
                'epoch': epoch,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'scheduler_state': scheduler.state_dict() if scheduler else None,
                'val_loss': val_loss,
                'train_loss': train_loss,
                'tokenizer': tokenizer,
                'config': vars(args)
            }

            best_path = os.path.join(args.checkpoint_dir, 'best_model.pth')
            torch.save(checkpoint, best_path)
            print(f"✓ Saved best model (val_loss: {val_loss:.4f})")

        else:
            patience_counter += 1

        # Save periodic checkpoints
        if (epoch + 1) % args.save_every == 0:
            checkpoint = {
                'epoch': epoch,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'scheduler_state': scheduler.state_dict() if scheduler else None,
                'val_loss': val_loss,
                'train_loss': train_loss,
                'tokenizer': tokenizer,
                'config': vars(args)
            }
            checkpoint_path = os.path.join(
                args.checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth'
            )
            torch.save(checkpoint, checkpoint_path)
            print(f"✓ Saved periodic checkpoint (epoch {epoch+1})")

        # Early stopping
        if patience_counter >= args.early_stopping_patience:
            print(f"\nEarly stopping triggered after {args.early_stopping_patience} epochs without improvement")
            break

    print("\n" + "="*60)
    print("Training Complete!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Best model saved to: {os.path.join(args.checkpoint_dir, 'best_model.pth')}")
    print("="*60)

    if args.use_wandb and WANDB_AVAILABLE:
        wandb.finish()
246
+
247
def main():
    """Parse the CLI hyper-parameters and launch the advanced training loop."""
    cli = argparse.ArgumentParser(description='Advanced training with best practices')

    # Data arguments (all required COCO-style paths)
    for flag in ('--train_image_dir', '--train_ann_file', '--val_image_dir',
                 '--val_ann_file', '--test_image_dir'):
        cli.add_argument(flag, type=str, required=True)

    # Model arguments
    cli.add_argument('--model_name', type=str, default='efficientnet_b3')
    cli.add_argument('--embed_dim', type=int, default=512)
    cli.add_argument('--num_layers', type=int, default=8)
    cli.add_argument('--num_heads', type=int, default=8)
    cli.add_argument('--dropout', type=float, default=0.1)

    # Training arguments
    cli.add_argument('--batch_size', type=int, default=96)
    cli.add_argument('--lr', type=float, default=3e-4)
    cli.add_argument('--epochs', type=int, default=20)
    cli.add_argument('--seed', type=int, default=42)
    cli.add_argument('--use_amp', action='store_true', help='Use mixed precision')
    cli.add_argument('--grad_accum', type=int, default=1)
    cli.add_argument('--weight_decay', type=float, default=1e-4)
    cli.add_argument('--different_lr', action='store_true',
                     help='Use different LR for encoder/decoder')

    # Scheduler arguments
    cli.add_argument('--scheduler', type=str, default='plateau',
                     choices=['cosine', 'plateau', 'warmup_cosine', 'none'])
    cli.add_argument('--patience', type=int, default=3)
    cli.add_argument('--min_lr', type=float, default=1e-6)
    cli.add_argument('--warmup_epochs', type=int, default=2)

    # Checkpointing
    cli.add_argument('--checkpoint_dir', type=str, default='checkpoints')
    cli.add_argument('--resume_checkpoint', type=str, default=None)
    cli.add_argument('--save_every', type=int, default=5)
    cli.add_argument('--early_stopping_patience', type=int, default=5)

    # Experiment tracking
    cli.add_argument('--use_wandb', action='store_true', help='Use Weights & Biases')
    cli.add_argument('--wandb_project', type=str, default='image-captioning')

    # Additional args needed by create_dataloaders and train_epoch
    cli.add_argument('--distributed', action='store_true', help='Use distributed training')
    cli.add_argument('--local_rank', type=int, default=0, help='Local rank for distributed training')

    args = cli.parse_args()

    # Set epoch attribute (will be updated during training)
    args.epoch = 0

    train_advanced(args)
302
+
303
+
304
if __name__ == '__main__':
    # Script entry point: everything happens inside main().
    main()
306
+