AhsanAftab commited on
Commit
c8b4cd2
·
verified ·
1 Parent(s): 7cd9526

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +206 -0
  2. inference.py +121 -0
  3. model_loader.py +138 -0
app.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
Flask REST API for Image Captioning and Action Recognition
"""

from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
from PIL import Image
import io
import base64
import logging
from model_loader import load_caption_model, load_action_model, load_vocab
from inference import generate_caption, predict_action

# NOTE(review): `base64` and `load_vocab` are imported but not used anywhere
# in this file — confirm before removing.

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize Flask app
app = Flask(__name__)
CORS(app)  # Enable CORS for frontend communication

# Global variables for models — all populated by initialize_models() at startup.
caption_model = None   # image-captioning model; None until loaded
action_model = None    # action-recognition classifier; None until loaded
vocab = None           # vocabulary used to decode caption token indices
device = None          # torch.device chosen at startup (cuda if available, else cpu)
@app.route('/')
def home():
    """Root endpoint: report that the API is running and list its routes."""
    endpoints = {
        'health': '/health',
        'caption': '/api/caption',
        'action': '/api/action',
        'combined': '/api/combined',
    }
    payload = {
        'message': 'Image Captioning & Action Recognition API',
        'status': 'running',
        'endpoints': endpoints,
    }
    return jsonify(payload)
@app.route('/health')
def health():
    """Health-check endpoint: report model load state and the active device."""
    loaded = {
        'caption_model': caption_model is not None,
        'action_model': action_model is not None,
        'vocab': vocab is not None,
    }
    return jsonify({
        'status': 'healthy',
        'models_loaded': loaded,
        'device': str(device),
    })
@app.route('/api/caption', methods=['POST'])
def caption_image():
    """
    Generate a caption for an uploaded image.

    Expected: multipart/form-data with an 'image' file field.
    Returns: JSON {'success': True, 'caption': <str>} on success;
             400 for a missing, empty, or unreadable image;
             500 for internal errors.
    """
    try:
        # Validate the upload: the field must be present...
        if 'image' not in request.files:
            return jsonify({'error': 'No image provided'}), 400

        file = request.files['image']
        # ...and must carry a real file. Browsers submit an empty-filename
        # part when the user picked no file, which would otherwise fail
        # later with a confusing 500.
        if not file.filename:
            return jsonify({'error': 'No image selected'}), 400

        # Read image. A corrupt or non-image payload is a client error
        # (400), not a server failure (500).
        image_bytes = file.read()
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except Exception:
            return jsonify({'error': 'Invalid or unreadable image file'}), 400

        # Generate caption
        caption = generate_caption(caption_model, image, vocab, device)

        logger.info(f"Caption generated: {caption}")

        return jsonify({
            'success': True,
            'caption': caption
        })

    except Exception as e:
        logger.error(f"Error in caption generation: {str(e)}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
@app.route('/api/action', methods=['POST'])
def recognize_action():
    """
    Recognize the action shown in an uploaded image.

    Expected: multipart/form-data with an 'image' file field.
    Returns: JSON with the predicted action, its confidence, and the
    per-class probabilities.
    """
    try:
        # Reject requests that carry no image part.
        if 'image' not in request.files:
            return jsonify({'error': 'No image provided'}), 400

        upload = request.files['image']

        # Decode the upload into an RGB PIL image.
        raw = upload.read()
        image = Image.open(io.BytesIO(raw)).convert('RGB')

        # Run the classifier.
        result = predict_action(action_model, image, device)

        logger.info(f"Action predicted: {result['predicted_class']} ({result['confidence']:.2f}%)")

        response = {
            'success': True,
            'predicted_action': result['predicted_class'],
            'confidence': result['confidence'],
            'all_predictions': result['all_predictions'],
        }
        return jsonify(response)

    except Exception as e:
        logger.error(f"Error in action recognition: {str(e)}")
        return jsonify({'success': False, 'error': str(e)}), 500
@app.route('/api/combined', methods=['POST'])
def combined_inference():
    """
    Run both captioning and action recognition on one uploaded image.

    Expected: multipart/form-data with an 'image' file field.
    Returns: JSON with the generated caption plus the action prediction.
    """
    try:
        # Reject requests that carry no image part.
        if 'image' not in request.files:
            return jsonify({'error': 'No image provided'}), 400

        upload = request.files['image']

        # Decode the upload once; both models consume the same image.
        image = Image.open(io.BytesIO(upload.read())).convert('RGB')

        caption = generate_caption(caption_model, image, vocab, device)
        action_result = predict_action(action_model, image, device)

        logger.info(f"Combined - Caption: {caption}, Action: {action_result['predicted_class']}")

        return jsonify({
            'success': True,
            'caption': caption,
            'action': {
                'predicted_action': action_result['predicted_class'],
                'confidence': action_result['confidence'],
                'all_predictions': action_result['all_predictions'],
            },
        })

    except Exception as e:
        logger.error(f"Error in combined inference: {str(e)}")
        return jsonify({'success': False, 'error': str(e)}), 500
def initialize_models():
    """Populate the module-level model/vocab/device globals from disk."""
    global caption_model, action_model, vocab, device

    logger.info("Initializing models...")

    # Prefer the GPU when one is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    try:
        # Captioning model ships with its vocabulary.
        caption_model, vocab = load_caption_model(device)
        logger.info(" Caption model loaded")

        action_model = load_action_model(device)
        logger.info(" Action model loaded")

        logger.info("All models initialized successfully!")
    except Exception as e:
        # Re-raise: the server must not start without its models.
        logger.error(f"Error loading models: {str(e)}")
        raise
if __name__ == '__main__':
    # Load the models before accepting any traffic.
    initialize_models()

    # Serve on all interfaces; debug stays off for production use.
    app.run(host='0.0.0.0', port=5000, debug=False)
inference.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
from torchvision import transforms
from PIL import Image
import pickle
from pathlib import Path

# Image transformations: resize to the 224x224 input the models expect,
# convert to a tensor, then normalize with the standard ImageNet mean/std
# (matching the values used at training time).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Action class names, loaded lazily exactly once by get_action_class_names().
_action_class_names = None
def get_action_class_names():
    """Return the action class names, loading them from disk on first use."""
    global _action_class_names
    if _action_class_names is not None:
        return _action_class_names

    # First call: read the class list out of the saved model config,
    # then cache it in the module-level global for every later call.
    model_dir = Path(__file__).parent.parent / 'models'
    with open(model_dir / 'action_model_config.pkl', 'rb') as f:
        config = pickle.load(f)
    _action_class_names = config['class_names']
    return _action_class_names
def generate_caption(model, image, vocab, device, max_length=30):
    """
    Generate a natural-language caption for an image.

    Args:
        model: Trained caption model
        image: PIL Image
        vocab: Vocabulary object
        device: torch device
        max_length: Maximum caption length

    Returns:
        caption: Generated caption string
    """
    model.eval()

    # Preprocess: transform, add a batch dimension, move to the model's device.
    batch = transform(image).unsqueeze(0).to(device)

    # Greedy decoding; no gradients needed at inference time.
    with torch.no_grad():
        indices = model.generate_caption(batch, max_length)

    # Map the first (only) sequence of token indices back to words.
    tokens = vocab.decode(indices[0].cpu().numpy())

    # Strip special tokens: skip <start>, stop at <end> or at padding.
    words = []
    for token in tokens:
        if token == vocab.start_token:
            continue
        if token == vocab.end_token or token == vocab.pad_token:
            break
        words.append(token)

    text = ' '.join(words)

    # Sentence-case the result (empty captions pass through unchanged).
    return text[0].upper() + text[1:] if text else text
def predict_action(model, image, device):
    """
    Classify the action shown in an image.

    Args:
        model: Trained action model
        image: PIL Image
        device: torch device

    Returns:
        dict: Prediction results with class, confidence, and all predictions
    """
    model.eval()

    # Class index -> label mapping (loaded lazily, cached module-wide).
    class_names = get_action_class_names()

    # Preprocess: transform, add a batch dimension, move to device.
    batch = transform(image).unsqueeze(0).to(device)

    # Forward pass -> per-class probabilities.
    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.softmax(logits, dim=1)
        confidence, top_idx = probabilities.max(dim=1)

    # Per-class percentages, ranked descending by probability.
    percents = probabilities[0].cpu().numpy() * 100
    ranked = sorted(
        ({'class': name, 'probability': float(p)}
         for name, p in zip(class_names, percents)),
        key=lambda entry: entry['probability'],
        reverse=True,
    )

    return {
        'predicted_class': class_names[top_idx.item()],
        'confidence': float(confidence.item() * 100),
        'all_predictions': ranked[:5]  # Return top 5
    }
model_loader.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
import torch.nn as nn
from torchvision import models
import pickle
from pathlib import Path

# Model architecture classes (same as in training) — these definitions must
# mirror the training-time architectures exactly so the saved state_dicts load.
+ class EncoderCNN(nn.Module):
9
+ def __init__(self, embed_size):
10
+ super(EncoderCNN, self).__init__()
11
+ resnet = models.resnet50(pretrained=False)
12
+ modules = list(resnet.children())[:-1]
13
+ self.resnet = nn.Sequential(*modules)
14
+ self.fc = nn.Linear(resnet.fc.in_features, embed_size)
15
+ self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
16
+
17
+ def forward(self, images):
18
+ features = self.resnet(images)
19
+ features = features.view(features.size(0), -1)
20
+ features = self.fc(features)
21
+ features = self.bn(features)
22
+ return features
23
+
class DecoderLSTM(nn.Module):
    """LSTM decoder that turns an image embedding into a token sequence."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, dropout=0.5):
        super(DecoderLSTM, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        # Inter-layer dropout only applies to stacked LSTMs; torch warns otherwise.
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers,
                            batch_first=True, dropout=dropout if num_layers > 1 else 0)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        # Teacher forcing: prepend the image feature to the caption embeddings,
        # then score every position against the vocabulary.
        token_embeds = self.embed(captions)
        sequence = torch.cat((features.unsqueeze(1), token_embeds), dim=1)
        hiddens, _ = self.lstm(sequence)
        return self.fc(hiddens)

    def sample(self, features, max_length=50):
        """Greedy decoding: feed back the argmax token at each step."""
        chosen_tokens = []
        states = None
        step_input = features.unsqueeze(1)

        for _ in range(max_length):
            hiddens, states = self.lstm(step_input, states)
            logits = self.fc(hiddens.squeeze(1))
            chosen = logits.argmax(dim=1)
            chosen_tokens.append(chosen)
            step_input = self.embed(chosen).unsqueeze(1)

        # (batch, max_length) tensor of token indices.
        return torch.stack(chosen_tokens, dim=1)
class ImageCaptioningModel(nn.Module):
    """End-to-end captioner: CNN encoder feeding an LSTM decoder."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, dropout=0.5):
        super(ImageCaptioningModel, self).__init__()
        self.encoder = EncoderCNN(embed_size)
        self.decoder = DecoderLSTM(embed_size, hidden_size, vocab_size, num_layers, dropout)

    def forward(self, images, captions):
        # Training path: teacher-forced decoding over the ground-truth captions.
        return self.decoder(self.encoder(images), captions)

    def generate_caption(self, images, max_length=50):
        # Inference path: greedy sampling from the encoded image.
        return self.decoder.sample(self.encoder(images), max_length)
class ActionRecognitionModel(nn.Module):
    """ResNet-50 classifier with a custom two-layer head for action labels."""

    def __init__(self, num_classes, dropout=0.5):
        super(ActionRecognitionModel, self).__init__()
        # Checkpoint supplies the weights, so skip the pretrained download.
        self.backbone = models.resnet50(pretrained=False)

        # Replace the stock FC layer with a regularized classification head.
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        return self.backbone(x)
def load_caption_model(device, model_dir='../models'):
    """Rebuild the captioning model from its saved config and weights.

    Args:
        device: torch device to place the model on
        model_dir: directory holding the config/vocab/weight files

    Returns:
        (model, vocab): the model in eval mode on *device*, plus the vocabulary.
    """
    model_dir = Path(model_dir)

    # Hyperparameters recorded at training time.
    with open(model_dir / 'caption_model_config.pkl', 'rb') as f:
        config = pickle.load(f)

    # Vocabulary for decoding generated token indices.
    with open(model_dir / 'vocab.pkl', 'rb') as f:
        vocab = pickle.load(f)

    # Recreate the architecture, then restore the trained weights.
    model = ImageCaptioningModel(
        embed_size=config['embed_size'],
        hidden_size=config['hidden_size'],
        vocab_size=config['vocab_size'],
        num_layers=config['num_layers'],
        dropout=config['dropout']
    )
    state = torch.load(model_dir / 'caption_model_final.pth',
                       map_location=device)
    model.load_state_dict(state)
    model.to(device)
    model.eval()

    return model, vocab
def load_action_model(device, model_dir='../models'):
    """Rebuild the action-recognition model from its saved config and weights.

    Args:
        device: torch device to place the model on
        model_dir: directory holding the config/weight files

    Returns:
        The model in eval mode on *device*.
    """
    model_dir = Path(model_dir)

    # Hyperparameters recorded at training time.
    with open(model_dir / 'action_model_config.pkl', 'rb') as f:
        config = pickle.load(f)

    # Recreate the architecture, then restore the trained weights.
    model = ActionRecognitionModel(
        num_classes=config['num_classes'],
        dropout=config['dropout']
    )
    state = torch.load(model_dir / 'action_model_final.pth',
                       map_location=device)
    model.load_state_dict(state)
    model.to(device)
    model.eval()

    return model