from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
from PIL import Image
import io
import base64
import logging
from model_loader import load_caption_model, load_action_model
from inference import generate_caption, predict_action
# Configure root logging at INFO level and create this module's logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Create the Flask application that the route decorators below attach to
app = Flask(__name__)
CORS(app)  # Enable CORS for frontend communication
# Module-level model state. Everything starts as None and is assigned by
# initialize_models(); the /health endpoint reports which ones are set.
caption_model = None
action_model = None
vocab = None
device = None
@app.route('/')
def home():
    """Describe the API: its name, running status and endpoint paths."""
    endpoint_paths = {
        'health': '/health',
        'caption': '/api/caption',
        'action': '/api/action',
        'combined': '/api/combined'
    }
    payload = {
        'message': 'Image Captioning & Action Recognition API',
        'status': 'running',
        'endpoints': endpoint_paths
    }
    return jsonify(payload)
@app.route('/health')
def health():
    """Report service health: which model globals are set and the device."""
    tracked = (
        ('caption_model', caption_model),
        ('action_model', action_model),
        ('vocab', vocab),
    )
    # A global counts as "loaded" as soon as it is anything other than None.
    loaded_flags = {label: obj is not None for label, obj in tracked}
    report = {
        'status': 'healthy',
        'models_loaded': loaded_flags,
        'device': str(device)
    }
    return jsonify(report)
@app.route('/api/caption', methods=['POST'])
def caption_image():
    """
    Generate a caption for an uploaded image.

    Expected: multipart/form-data with an 'image' file part.

    Returns a JSON response:
        200 -- {'success': True, 'caption': <caption>}
        400 -- no 'image' part, or the upload cannot be decoded as an image
        503 -- the caption model or vocabulary has not been loaded
        500 -- any other failure; 'error' carries the exception message
    """
    try:
        # Reject requests that do not carry the expected file part.
        if 'image' not in request.files:
            return jsonify({
                'success': False,
                'error': 'No image provided'
            }), 400

        # The model globals are only assigned by initialize_models(); they
        # stay None when this module is imported rather than run as a script.
        if caption_model is None or vocab is None:
            return jsonify({
                'success': False,
                'error': 'Caption model is not loaded'
            }), 503

        image_bytes = request.files['image'].read()

        # A corrupt or non-image upload is a client error, not a server fault.
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except (OSError, ValueError) as e:
            logger.warning("Undecodable image in caption request: %s", e)
            return jsonify({
                'success': False,
                'error': 'Invalid image file'
            }), 400

        caption = generate_caption(caption_model, image, vocab, device)
        logger.info("Caption generated: %s", caption)
        return jsonify({
            'success': True,
            'caption': caption
        })
    except Exception as e:
        # Top-level boundary: record the traceback and answer with JSON.
        # NOTE(review): the raw exception text is still returned to the
        # client, as before -- consider a generic message in production.
        logger.exception("Error in caption generation: %s", e)
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
@app.route('/api/action', methods=['POST'])
def recognize_action():
    """
    Recognize the action shown in an uploaded image.

    Expected: multipart/form-data with an 'image' file part.

    Returns a JSON response:
        200 -- {'success': True, 'predicted_action': ..., 'confidence': ...,
                'all_predictions': ...}
        400 -- no 'image' part, or the upload cannot be decoded as an image
        503 -- the action model has not been loaded
        500 -- any other failure; 'error' carries the exception message
    """
    try:
        # Reject requests that do not carry the expected file part.
        if 'image' not in request.files:
            return jsonify({
                'success': False,
                'error': 'No image provided'
            }), 400

        # The model global is only assigned by initialize_models(); it stays
        # None when this module is imported rather than run as a script.
        if action_model is None:
            return jsonify({
                'success': False,
                'error': 'Action model is not loaded'
            }), 503

        image_bytes = request.files['image'].read()

        # A corrupt or non-image upload is a client error, not a server fault.
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except (OSError, ValueError) as e:
            logger.warning("Undecodable image in action request: %s", e)
            return jsonify({
                'success': False,
                'error': 'Invalid image file'
            }), 400

        result = predict_action(action_model, image, device)
        logger.info(
            "Action predicted: %s (%.2f%%)",
            result['predicted_class'],
            result['confidence'],
        )
        return jsonify({
            'success': True,
            'predicted_action': result['predicted_class'],
            'confidence': result['confidence'],
            'all_predictions': result['all_predictions']
        })
    except Exception as e:
        # Top-level boundary: record the traceback and answer with JSON.
        # NOTE(review): the raw exception text is still returned to the
        # client, as before -- consider a generic message in production.
        logger.exception("Error in action recognition: %s", e)
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
@app.route('/api/combined', methods=['POST'])
def combined_inference():
    """
    Run both captioning and action recognition on one uploaded image.

    Expected: multipart/form-data with an 'image' file part.

    Returns a JSON response:
        200 -- {'success': True, 'caption': ..., 'action': {
                'predicted_action': ..., 'confidence': ...,
                'all_predictions': ...}}
        400 -- no 'image' part, or the upload cannot be decoded as an image
        503 -- the caption model, vocabulary or action model is not loaded
        500 -- any other failure; 'error' carries the exception message
    """
    try:
        # Reject requests that do not carry the expected file part.
        if 'image' not in request.files:
            return jsonify({
                'success': False,
                'error': 'No image provided'
            }), 400

        # The model globals are only assigned by initialize_models(); they
        # stay None when this module is imported rather than run as a script.
        if caption_model is None or vocab is None or action_model is None:
            return jsonify({
                'success': False,
                'error': 'Models are not loaded'
            }), 503

        image_bytes = request.files['image'].read()

        # A corrupt or non-image upload is a client error, not a server fault.
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except (OSError, ValueError) as e:
            logger.warning("Undecodable image in combined request: %s", e)
            return jsonify({
                'success': False,
                'error': 'Invalid image file'
            }), 400

        # The same decoded image feeds both models.
        caption = generate_caption(caption_model, image, vocab, device)
        action_result = predict_action(action_model, image, device)
        logger.info(
            "Combined - Caption: %s, Action: %s",
            caption,
            action_result['predicted_class'],
        )
        return jsonify({
            'success': True,
            'caption': caption,
            'action': {
                'predicted_action': action_result['predicted_class'],
                'confidence': action_result['confidence'],
                'all_predictions': action_result['all_predictions']
            }
        })
    except Exception as e:
        # Top-level boundary: record the traceback and answer with JSON.
        # NOTE(review): the raw exception text is still returned to the
        # client, as before -- consider a generic message in production.
        logger.exception("Error in combined inference: %s", e)
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
def initialize_models():
    """
    Select the compute device and load both models into the module globals.

    Assigns `device`, then `caption_model` and `vocab` (from
    load_caption_model) and `action_model` (from load_action_model).
    Any exception raised while loading is logged and re-raised.
    """
    global caption_model, action_model, vocab, device

    logger.info("Initializing models...")

    # Use CUDA when torch reports it as available, otherwise the CPU.
    backend = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(backend)
    logger.info(f"Using device: {device}")

    try:
        caption_bundle = load_caption_model(device)
        caption_model, vocab = caption_bundle
        logger.info(" Caption model loaded")

        action_model = load_action_model(device)
        logger.info(" Action model loaded")
    except Exception as e:
        logger.error(f"Error loading models: {str(e)}")
        raise
    else:
        logger.info("All models initialized successfully!")
if __name__ == '__main__':
    # Load the models before the server starts handling requests.
    initialize_models()
    # Serve on all interfaces, port 7860, with Flask debug mode off.
    server_options = {'host': '0.0.0.0', 'port': 7860, 'debug': False}
    app.run(**server_options)