|
|
from flask import Flask, request, jsonify |
|
|
from flask_cors import CORS |
|
|
import torch |
|
|
from PIL import Image |
|
|
import io |
|
|
import base64 |
|
|
import logging |
|
|
from model_loader import load_caption_model, load_action_model |
|
|
from inference import generate_caption, predict_action |
|
|
|
|
|
|
|
|
# Root logging at INFO level, plus a module-scoped logger for this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Flask application; CORS(app) applies flask_cors' default settings to every route.
app = Flask(__name__)
CORS(app)

# Model state shared by the request handlers. Everything stays None until
# initialize_models() assigns it.
caption_model = action_model = vocab = device = None
|
|
|
|
|
@app.route('/')
def home():
    """Describe the API: a name, a running status, and the available endpoint paths."""
    endpoint_paths = dict(
        health='/health',
        caption='/api/caption',
        action='/api/action',
        combined='/api/combined',
    )
    payload = {
        'message': 'Image Captioning & Action Recognition API',
        'status': 'running',
        'endpoints': endpoint_paths,
    }
    return jsonify(payload)
|
|
|
|
|
@app.route('/health')
def health():
    """Report which models are loaded and the device in use.

    NOTE: 'status' is always 'healthy', even when nothing is loaded. Callers
    that need readiness should inspect the 'models_loaded' flags. 'device'
    is the string 'None' until initialize_models() has run.
    """
    loaded_flags = {
        name: obj is not None
        for name, obj in (
            ('caption_model', caption_model),
            ('action_model', action_model),
            ('vocab', vocab),
        )
    }
    return jsonify({
        'status': 'healthy',
        'models_loaded': loaded_flags,
        'device': str(device),
    })
|
|
|
|
|
@app.route('/api/caption', methods=['POST'])
def caption_image():
    """Generate a caption for an uploaded image.

    Expected: multipart/form-data with an 'image' file field.

    Returns (JSON):
        200 {'success': True, 'caption': <caption>} on success.
        400 {'success': False, 'error': ...} when the 'image' field is missing
            or empty, or its bytes cannot be decoded as an image.
        503 when the caption model or vocabulary is not loaded, e.g. the app
            is being served without initialize_models() having run.
        500 with a generic message on any other failure. The full traceback
            goes to the server log instead of being sent to the client.
    """
    if caption_model is None or vocab is None:
        return jsonify({'success': False, 'error': 'Caption model not loaded'}), 503

    # Read the upload outside the broad try below so Werkzeug HTTP errors
    # raised while parsing the request keep their proper status codes.
    file = request.files.get('image')
    image_bytes = file.read() if file is not None else b''
    # A missing field and an empty part (form submitted with no file
    # selected) both count as "no image".
    if not image_bytes:
        return jsonify({'success': False, 'error': 'No image provided'}), 400

    try:
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except (OSError, Image.DecompressionBombError) as e:
            # Undecodable or truncated data (UnidentifiedImageError subclasses
            # OSError), or images beyond Pillow's pixel limit, are client errors.
            logger.warning(f"Rejected invalid image upload: {e}")
            return jsonify({'success': False, 'error': 'Invalid image file'}), 400

        caption = generate_caption(caption_model, image, vocab, device)
        logger.info(f"Caption generated: {caption}")
        return jsonify({
            'success': True,
            'caption': caption,
        })

    except Exception:
        # Request boundary: keep the traceback in the log and do not leak
        # internal exception text to clients.
        logger.exception("Error in caption generation")
        return jsonify({'success': False, 'error': 'Internal server error'}), 500
|
|
|
|
|
@app.route('/api/action', methods=['POST'])
def recognize_action():
    """Recognize the action shown in an uploaded image.

    Expected: multipart/form-data with an 'image' file field.

    Returns (JSON):
        200 {'success': True, 'predicted_action', 'confidence',
            'all_predictions'} on success. The values come from
            predict_action()'s 'predicted_class', 'confidence' and
            'all_predictions' keys.
        400 {'success': False, 'error': ...} when the 'image' field is missing
            or empty, or its bytes cannot be decoded as an image.
        503 when the action model is not loaded, e.g. the app is being
            served without initialize_models() having run.
        500 with a generic message on any other failure. The full traceback
            goes to the server log instead of being sent to the client.
    """
    if action_model is None:
        return jsonify({'success': False, 'error': 'Action model not loaded'}), 503

    # Read the upload outside the broad try below so Werkzeug HTTP errors
    # raised while parsing the request keep their proper status codes.
    file = request.files.get('image')
    image_bytes = file.read() if file is not None else b''
    # A missing field and an empty part both count as "no image".
    if not image_bytes:
        return jsonify({'success': False, 'error': 'No image provided'}), 400

    try:
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except (OSError, Image.DecompressionBombError) as e:
            # Undecodable or truncated data (UnidentifiedImageError subclasses
            # OSError), or images beyond Pillow's pixel limit, are client errors.
            logger.warning(f"Rejected invalid image upload: {e}")
            return jsonify({'success': False, 'error': 'Invalid image file'}), 400

        result = predict_action(action_model, image, device)
        logger.info(f"Action predicted: {result['predicted_class']} ({result['confidence']:.2f}%)")
        return jsonify({
            'success': True,
            'predicted_action': result['predicted_class'],
            'confidence': result['confidence'],
            'all_predictions': result['all_predictions'],
        })

    except Exception:
        # Request boundary: keep the traceback in the log and do not leak
        # internal exception text to clients.
        logger.exception("Error in action recognition")
        return jsonify({'success': False, 'error': 'Internal server error'}), 500
|
|
|
|
|
@app.route('/api/combined', methods=['POST'])
def combined_inference():
    """Run both captioning and action recognition on one uploaded image.

    Expected: multipart/form-data with an 'image' file field.

    Returns (JSON):
        200 {'success': True, 'caption': ..., 'action': {'predicted_action',
            'confidence', 'all_predictions'}} on success.
        400 {'success': False, 'error': ...} when the 'image' field is missing
            or empty, or its bytes cannot be decoded as an image.
        503 when either model (or the caption vocabulary) is not loaded.
        500 with a generic message on any other failure. The full traceback
            goes to the server log instead of being sent to the client.
    """
    if caption_model is None or vocab is None or action_model is None:
        return jsonify({'success': False, 'error': 'Models not loaded'}), 503

    # Read the upload outside the broad try below so Werkzeug HTTP errors
    # raised while parsing the request keep their proper status codes.
    file = request.files.get('image')
    image_bytes = file.read() if file is not None else b''
    # A missing field and an empty part both count as "no image".
    if not image_bytes:
        return jsonify({'success': False, 'error': 'No image provided'}), 400

    try:
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except (OSError, Image.DecompressionBombError) as e:
            # Undecodable or truncated data (UnidentifiedImageError subclasses
            # OSError), or images beyond Pillow's pixel limit, are client errors.
            logger.warning(f"Rejected invalid image upload: {e}")
            return jsonify({'success': False, 'error': 'Invalid image file'}), 400

        # The same decoded image is passed to both models.
        caption = generate_caption(caption_model, image, vocab, device)
        action_result = predict_action(action_model, image, device)
        logger.info(f"Combined - Caption: {caption}, Action: {action_result['predicted_class']}")
        return jsonify({
            'success': True,
            'caption': caption,
            'action': {
                'predicted_action': action_result['predicted_class'],
                'confidence': action_result['confidence'],
                'all_predictions': action_result['all_predictions'],
            },
        })

    except Exception:
        # Request boundary: keep the traceback in the log and do not leak
        # internal exception text to clients.
        logger.exception("Error in combined inference")
        return jsonify({'success': False, 'error': 'Internal server error'}), 500
|
|
|
|
|
def initialize_models():
    """Choose a torch device and load both models onto it.

    Assigns the module-level globals ``device``, ``caption_model``, ``vocab``
    and ``action_model``. A failure while loading either model is logged
    and re-raised to the caller.
    """
    global caption_model, action_model, vocab, device

    logger.info("Initializing models...")

    # Use CUDA only when PyTorch reports it as available.
    has_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if has_cuda else 'cpu')
    logger.info(f"Using device: {device}")

    try:
        caption_model, vocab = load_caption_model(device)
        logger.info(" Caption model loaded")

        action_model = load_action_model(device)
        logger.info(" Action model loaded")
    except Exception as load_error:
        logger.error(f"Error loading models: {str(load_error)}")
        raise
    else:
        logger.info("All models initialized successfully!")
|
|
|
|
|
if __name__ == '__main__':
    # The request handlers read the module-level model globals directly, so
    # they must be populated before the server starts accepting requests.
    initialize_models()

    server_options = {
        'host': '0.0.0.0',
        'port': 7860,
        'debug': False,
    }
    app.run(**server_options)