Spaces:
Running
Running
| """ | |
| Application routes. | |
| """ | |
| import os | |
| import logging | |
| import time | |
| from datetime import datetime | |
| from flask import Blueprint, render_template, request, jsonify | |
| from werkzeug.utils import secure_filename | |
| from torchvision import transforms | |
| from PIL import Image | |
| import torch | |
| from app.utils.model_cache import model_cache | |
| from app.config import MAX_FILE_SIZE, ALLOWED_EXTENSIONS | |
# Import the EfficientNet caption generation function.
# The packaged location is tried first; the flat module name is kept as a
# fallback for backward compatibility (before reorganization).
try:
    from training.efficient_train import generate_caption
except ImportError:
    try:
        from efficient_train import generate_caption
    except ImportError as exc:
        # Chain explicitly so the traceback names the underlying failure as
        # the direct cause of this summary error.
        raise ImportError(
            "Could not import generate_caption from training.efficient_train or efficient_train"
        ) from exc
# Module-level logger and the 'main' blueprint object.
logger = logging.getLogger(__name__)
bp = Blueprint('main', __name__)

# Image transformation for EfficientNet: resize the shorter side to 224,
# center-crop to 224x224, convert to a tensor, then normalize each channel
# with the mean / std values below.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]
_PREPROCESS_STEPS = [
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
]
efficientnet_transform = transforms.Compose(_PREPROCESS_STEPS)
def allowed_file(filename):
    """Return True when *filename* ends in a permitted extension.

    The extension is the text after the last dot, compared
    case-insensitively against ``ALLOWED_EXTENSIONS``. A name that contains
    no dot at all is rejected.
    """
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[-1]
    return extension.lower() in ALLOWED_EXTENSIONS
def validate_file_type(file_path):
    """Check that the file's contents are an image, not just its extension.

    Args:
        file_path: Path of the file to inspect.

    Returns:
        True when ``Image.open`` and ``verify()`` both succeed, False when
        either of them raises.
    """
    try:
        # The context manager closes the underlying file handle even when
        # verify() raises; previously the handle stayed open until garbage
        # collection, which can also block a later os.remove on some
        # platforms.
        with Image.open(file_path) as img:
            img.verify()
    except Exception:
        # Deliberately broad: any failure to open or verify means the file
        # is treated as "not an image".
        return False
    return True
def before_request():
    """Stamp the current request with the wall-clock time it started.

    The value is stored as ``request.start_time`` (seconds from
    ``time.time()``) so a later hook can compute the request duration.
    """
    started_at = time.time()
    request.start_time = started_at
def after_request(response):
    """Add security headers to *response* and log the request outcome.

    Args:
        response: The outgoing response object; its headers are modified
            in place.

    Returns:
        The same response object.
    """
    # Security headers
    response.headers['X-Content-Type-Options'] = 'nosniff'
    response.headers['X-Frame-Options'] = 'DENY'
    response.headers['X-XSS-Protection'] = '1; mode=block'

    # Log request. ``start_time`` is only present when the timing hook ran
    # for this request; if it did not, log without a duration rather than
    # failing the response with an AttributeError.
    started_at = getattr(request, 'start_time', None)
    if started_at is None:
        logger.info("%s %s - %s", request.method, request.path, response.status_code)
    else:
        duration = time.time() - started_at
        logger.info(
            "%s %s - %s - %.3fs",
            request.method,
            request.path,
            response.status_code,
            duration,
        )
    return response
def index():
    """Render and return the main page template (``index.html``)."""
    page = render_template('index.html')
    return page
def health_check():
    """Health check endpoint for load balancers.

    Returns:
        A ``(json_response, 200)`` pair whose body reports a fixed
        ``'healthy'`` status, the current UTC time as a naive ISO-8601
        string, and whether the EfficientNet model is loaded.
    """
    from datetime import timezone

    # datetime.utcnow() is deprecated; take an aware "now" and drop tzinfo so
    # the emitted string keeps its previous offset-free format.
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
    return jsonify({
        'status': 'healthy',
        'timestamp': now_utc.isoformat(),
        'models_loaded': {
            'efficientnet': model_cache.is_efficientnet_loaded()
        }
    }), 200
def readiness_check():
    """Report readiness based on whether the EfficientNet model is loaded.

    Returns a 200 JSON response with status ``'ready'`` when the model cache
    reports the model as loaded, otherwise a 503 JSON response carrying the
    reason.
    """
    if model_cache.is_efficientnet_loaded():
        return jsonify({'status': 'ready'}), 200
    not_ready = {'status': 'not ready', 'reason': 'EfficientNet model not loaded'}
    return jsonify(not_ready), 503
def _remove_upload(filepath):
    """Delete a temporary upload, tolerating a file that is already gone."""
    try:
        os.remove(filepath)
    except FileNotFoundError:
        # Nothing to clean up (e.g. the save itself failed).
        pass
    except OSError as exc:
        logger.warning("Could not remove temporary upload %s: %s", filepath, exc)


def upload_file():
    """Handle image upload and generate caption.

    Reads the ``image`` file field and the optional ``model`` form field
    (default ``'efficientnet'``) from the current request, stores the upload
    temporarily in the app's ``UPLOAD_FOLDER``, validates it, and runs
    caption generation.

    Returns:
        A JSON response, or a ``(json_response, status)`` pair:

        * success: ``success``, ``caption``, ``model`` and
          ``inference_time`` keys;
        * 400: missing/empty/disallowed file, file larger than
          ``MAX_FILE_SIZE``, file that is not a valid image, or a model
          choice other than ``'efficientnet'``;
        * 503: the EfficientNet model is not loaded;
        * 500: any unexpected error while saving or captioning.

    The temporary file is removed on every path once it may have been
    written.
    """
    import uuid
    # Get upload folder from current app (set in __init__.py)
    from flask import current_app

    if 'image' not in request.files:
        logger.warning("Upload request missing 'image' field")
        return jsonify({'error': 'No file part'}), 400

    file = request.files['image']
    model_choice = request.form.get('model', 'efficientnet')  # Default to EfficientNet

    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400
    if not file or not allowed_file(file.filename):
        return jsonify({'error': 'Invalid file type. Only PNG, JPG, JPEG allowed.'}), 400

    upload_folder = current_app.config['UPLOAD_FOLDER']

    # Prefix the sanitized name with a random token so two requests that
    # upload the same filename never share (and delete) each other's file.
    filename = f"{uuid.uuid4().hex}_{secure_filename(file.filename)}"
    filepath = os.path.join(upload_folder, filename)

    try:
        file.save(filepath)

        # Validate file size
        if os.path.getsize(filepath) > MAX_FILE_SIZE:
            return jsonify({'error': f'File too large. Maximum size: {MAX_FILE_SIZE / 1024 / 1024}MB'}), 400

        # Validate file is actually an image
        if not validate_file_type(filepath):
            return jsonify({'error': 'Invalid image file'}), 400

        # Generate caption based on model choice
        start_time = time.time()
        if model_choice != 'efficientnet':
            # resnet50 (not supported - EfficientNet only)
            return jsonify({'error': 'ResNet model is not available. Please use EfficientNet.'}), 400
        if not model_cache.is_efficientnet_loaded():
            return jsonify({'error': 'EfficientNet model not available'}), 503

        model, tokenizer = model_cache.get_efficientnet_model()

        # Load and preprocess image. convert() returns a new image, so the
        # file handle can be closed as soon as the conversion is done.
        with Image.open(filepath) as raw_image:
            image = raw_image.convert('RGB')
        image_tensor = efficientnet_transform(image).to(model_cache._device)

        # Generate caption
        with torch.no_grad():
            caption = generate_caption(
                model,
                image_tensor,
                tokenizer,
                model_cache._device,
                max_length=64
            )

        inference_time = time.time() - start_time
        logger.info("Caption generated in %.3fs using %s", inference_time, model_choice)

        return jsonify({
            'success': True,
            'caption': caption,
            'model': model_choice,
            'inference_time': round(inference_time, 3)
        })
    except Exception as e:
        # Top-level boundary for this handler: log with traceback and return
        # a generic error instead of leaking details to the client.
        logger.exception("Error generating caption: %s", e)
        return jsonify({'error': 'Failed to generate caption. Please try again.'}), 500
    finally:
        # Runs for every return above, including the early 400/503 exits that
        # previously left the temporary file on disk.
        _remove_upload(filepath)