VIKRAM989 committed on
Commit
40243b5
·
1 Parent(s): 73ee6ca

Add application file

Browse files
Dockerfile ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
# Base image with Python 3.10 (matches the torch/torchvision stack used here).
FROM python:3.10

WORKDIR /app

# Install dependencies first so this (slow) layer is cached and only
# rebuilt when requirements.txt itself changes — not on every code edit.
COPY requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code after the dependency layer.
COPY . /app

# Port served by uvicorn in main.py.
EXPOSE 7860

CMD ["python", "main.py"]
__pycache__/main.cpython-312.pyc ADDED
Binary file (3.7 kB). View file
 
__pycache__/model.cpython-312.pyc ADDED
Binary file (10.3 kB). View file
 
debug_weights.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.models as models
4
+ import sys
5
+ import os
6
+ import pickle
7
+ import re
8
+ from collections import Counter
9
+
10
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+
12
+ EMBED_DIM = 512
13
+ HIDDEN_DIM = 512
14
+ MAX_LEN = 25
15
+
16
# Vocabulary class
class Vocabulary:
    """Word <-> index mapping with four reserved special tokens."""

    def __init__(self, freq_threshold=5):
        # ids 0-3 are fixed specials; real words start at index 4
        self.freq_threshold = freq_threshold
        self.itos = {0: "pad", 1: "startofseq", 2: "endofseq", 3: "unk"}
        self.stoi = {word: idx for idx, word in self.itos.items()}
        self.index = 4  # next id handed out by build_vocabulary

    def __len__(self):
        return len(self.itos)

    def tokenizer(self, text):
        # Lowercase and keep only \w+ runs; punctuation is discarded.
        return re.findall(r"\w+", text.lower())

    def build_vocabulary(self, sentence_list):
        # Count every token across the corpus, then register the ones
        # that clear the frequency threshold, in first-seen order.
        counts = Counter()
        for sentence in sentence_list:
            counts.update(self.tokenizer(sentence))
        for word, count in counts.items():
            if count >= self.freq_threshold:
                self.stoi[word] = self.index
                self.itos[self.index] = word
                self.index += 1

    def numericalize(self, text):
        # Unknown words map onto the reserved "unk" id.
        unk = self.stoi["unk"]
        return [self.stoi.get(token, unk) for token in self.tokenizer(text)]
53
+
54
+
55
class Encoder(nn.Module):
    """ResNet-50 image encoder used only for state_dict comparison here.

    Submodule names (backbone/fc/bn) determine the state_dict keys that
    this script compares against the checkpoint — do not rename them.
    """

    def __init__(self, embed_dim):
        super().__init__()
        # Pretrained ResNet-50 with the final classification fc stripped.
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        # Project pooled CNN features (resnet.fc.in_features) to embed_dim.
        self.fc = nn.Linear(resnet.fc.in_features, embed_dim)
        self.bn = nn.BatchNorm1d(embed_dim)

    def forward(self, x):
        # Backbone runs under no_grad — effectively frozen in this path.
        with torch.no_grad():
            features = self.backbone(x)
        features = features.reshape(features.size(0), -1)
        features = self.bn(self.fc(features))
        return features
69
+
70
+
71
class Decoder(nn.Module):
    """Token decoder: embedding -> single-layer LSTM -> vocab logits."""

    def __init__(self, embed_dim, hidden_dim, vocab_size):
        super().__init__()
        # Submodule names must match the checkpoint's state_dict keys.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, states=None):
        # Thread the (h, c) state through so stepwise decoding can resume.
        outputs, states = self.lstm(self.embedding(x), states)
        return self.fc(outputs), states
87
+
88
+
89
class CaptionModel(nn.Module):
    """Encoder + decoder wrapper matching the checkpoint's top-level layout.

    Exists so state_dict keys come out prefixed 'encoder.*' / 'decoder.*'
    for comparison against the checkpoint; no forward is defined because
    this script never runs the model.
    """

    def __init__(self, embed_dim, hidden_dim, vocab_size):
        super().__init__()
        self.encoder = Encoder(embed_dim)
        self.decoder = Decoder(embed_dim, hidden_dim, vocab_size)
94
+
95
+
96
# Main debug
# Flat diagnostic script: prints the checkpoint's state_dict keys/shapes,
# rebuilds a fresh model, prints its keys/shapes, diffs the two key sets,
# and finally attempts a real load_state_dict.
script_dir = os.path.dirname(os.path.abspath(__file__))
CHECKPOINT_PATH = os.path.join(script_dir, "best_checkpoint.pth")
VOCAB_PATH = os.path.join(script_dir, "vocab.pkl")

banner = "=" * 80

print(banner)
print("LOADING CHECKPOINT")
print(banner)

ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
print(f"\nCheckpoint keys: {list(ckpt.keys())}")

print("\nCheckpoint model_state_dict keys:")
ckpt_state = ckpt["model_state_dict"]
checkpoint_keys = set(ckpt_state.keys())
for name in sorted(checkpoint_keys):
    print(f" {name}: {ckpt_state[name].shape}")

# Load vocab
with open(VOCAB_PATH, "rb") as f:
    vocab = pickle.load(f)

vocab_size = len(vocab)
print(f"\nVocab size: {vocab_size}")

# Create model
model = CaptionModel(
    EMBED_DIM,
    HIDDEN_DIM,
    vocab_size
).to(DEVICE)

print("\n" + banner)
print("MODEL STATE DICT KEYS")
print(banner)

model_state = model.state_dict()
model_keys = set(model_state.keys())
for name in sorted(model_keys):
    print(f" {name}: {model_state[name].shape}")

# Check differences
print("\n" + banner)
print("COMPARISON")
print(banner)

print("\nKeys in checkpoint but NOT in model:")
for name in sorted(checkpoint_keys - model_keys):
    print(f" {name}")

print("\nKeys in model but NOT in checkpoint:")
for name in sorted(model_keys - checkpoint_keys):
    print(f" {name}")

print("\nKeys in both but with different shapes:")
for name in sorted(checkpoint_keys & model_keys):
    if ckpt_state[name].shape != model_state[name].shape:
        print(f" {name}")
        print(f" Checkpoint: {ckpt_state[name].shape}")
        print(f" Model: {model_state[name].shape}")

print("\n" + banner)
print("ATTEMPTING TO LOAD WEIGHTS")
print(banner)

try:
    model.load_state_dict(ckpt["model_state_dict"])
    print("SUCCESS: Weights loaded successfully!")
except Exception as e:
    print(f"ERROR: {e}")
main.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from PIL import Image
4
+ import io
5
+ import torch
6
+ import pickle
7
+ import os
8
+ import uvicorn
9
+ # Import from model.py
10
+ from model import (
11
+ Vocabulary,
12
+ ResNetEncoder,
13
+ DecoderLSTM,
14
+ ImageCaptioningModel,
15
+ generate_caption,
16
+ transform,
17
+ EMBED_DIM,
18
+ HIDDEN_DIM,
19
+ )
20
+
21
app = FastAPI(title="Image Captioning API")

# -------------------------
# Enable CORS
# -------------------------
# Wide-open CORS so any front-end origin can call the API.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# very permissive — tighten the origin list if credentials are ever used.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Run inference on GPU when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------------
# Paths (relative to main.py)
# -------------------------
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Pickled Vocabulary and trained weights are expected next to this file.
VOCAB_PATH = os.path.join(BASE_DIR, "vocab.pkl")
CHECKPOINT_PATH = os.path.join(BASE_DIR, "best_checkpoint.pth")
43
+
44
# -------------------------
# Load Vocabulary
# -------------------------
class CustomUnpickler(pickle.Unpickler):
    """Unpickler that resolves the pickled ``Vocabulary`` class locally.

    The vocab was pickled under the training script's module path; any class
    named ``Vocabulary`` is remapped to the one imported from model.py.
    Every other lookup falls through to the default resolution.
    """

    def find_class(self, module, name):
        if name != "Vocabulary":
            return super().find_class(module, name)
        return Vocabulary
52
+
53
# Deserialize the vocabulary; CustomUnpickler maps its class to model.Vocabulary.
# NOTE(review): unpickling an untrusted file can execute arbitrary code — fine
# here because vocab.pkl ships with the app.
with open(VOCAB_PATH, "rb") as f:
    vocab = CustomUnpickler(f).load()

# The decoder's output layer size must match the training-time vocabulary.
vocab_size = len(vocab)

# -------------------------
# Build Model
# -------------------------
encoder = ResNetEncoder(EMBED_DIM)
decoder = DecoderLSTM(EMBED_DIM, HIDDEN_DIM, vocab_size)

model = ImageCaptioningModel(encoder, decoder).to(DEVICE)

# -------------------------
# Load Weights
# -------------------------
# map_location lets CUDA-trained weights load on a CPU-only host too.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])

# Inference mode: use running batch-norm statistics.
model.eval()

print("✅ Model Loaded Successfully")
75
+
76
# -------------------------
# Health Check
# -------------------------
# Simple liveness probe for the service.
@app.get("/")
def root():
    return {"message": "Image Captioning API Running"}
82
+
83
# -------------------------
# Caption Endpoint
# -------------------------
@app.post("/caption")
async def caption_image(file: UploadFile = File(...)):
    """Generate a caption for an uploaded image.

    Reads the upload, decodes it as RGB, applies the shared preprocessing
    transform and runs greedy caption generation.

    Returns: ``{"caption": <generated text>}``.
    Raises: HTTP 400 when the upload cannot be decoded as an image
    (previously this surfaced as an unhandled 500).
    """
    # Local import keeps the module-level import block untouched.
    from fastapi import HTTPException

    contents = await file.read()

    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as exc:
        # PIL raises UnidentifiedImageError (among others) on junk input;
        # answer with a clean client error instead of a server crash.
        raise HTTPException(status_code=400, detail="Invalid image file") from exc

    image = transform(image)

    caption = generate_caption(model, image, vocab)

    return {
        "caption": caption
    }
100
+
101
if __name__ == "__main__":

    # Bind on all interfaces; 7860 matches the port EXPOSEd in the Dockerfile.
    uvicorn.run("main:app", host="0.0.0.0", port=7860)
model.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.transforms as transforms
4
+ import torchvision.models as models
5
+ from PIL import Image
6
+ import pickle
7
+ import sys
8
+ import os
9
+ import re
10
+ from collections import Counter
11
+
12
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+
14
+ EMBED_DIM = 512
15
+ HIDDEN_DIM = 512
16
+ MAX_LEN = 25
17
+
18
+
19
# -----------------------
# Vocabulary
# -----------------------
class Vocabulary:
    """Maps words to integer ids (and back), with four reserved specials."""

    def __init__(self, freq_threshold=5):
        self.freq_threshold = freq_threshold
        # ids 0-3 are fixed special tokens; corpus words begin at index 4
        self.itos = {0: "pad", 1: "startofseq", 2: "endofseq", 3: "unk"}
        self.stoi = {tok: i for i, tok in self.itos.items()}
        self.index = 4

    def __len__(self):
        return len(self.itos)

    def tokenizer(self, text):
        """Lowercase *text* and return its word-character tokens."""
        return re.findall(r"\w+", text.lower())

    def build_vocabulary(self, sentence_list):
        """Assign ids to every token seen at least ``freq_threshold`` times."""
        freq = Counter(
            token for sentence in sentence_list for token in self.tokenizer(sentence)
        )
        for word, count in freq.items():
            if count >= self.freq_threshold:
                self.stoi[word] = self.index
                self.itos[self.index] = word
                self.index += 1

    def numericalize(self, text):
        """Turn *text* into a list of ids, falling back to 'unk'."""
        fallback = self.stoi["unk"]
        return [self.stoi.get(tok, fallback) for tok in self.tokenizer(text)]
58
+
59
+
60
# -----------------------
# Encoder
# -----------------------
class ResNetEncoder(nn.Module):
    """Encode an image batch into one embed_dim feature vector per image.

    Pretrained ResNet-50 (classification head stripped) followed by a linear
    projection and batch norm. Submodule names (resnet/fc/batch_norm) must
    stay as-is: the checkpoint's state_dict keys depend on them.
    """

    def __init__(self, embed_dim):
        super().__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        # NOTE(review): requires_grad=True is set here, but forward() runs the
        # backbone under torch.no_grad(), so the backbone is effectively
        # frozen in this code path — confirm which behavior was intended.
        for param in resnet.parameters():
            param.requires_grad = True
        # Drop the final fc layer; keep the conv stages + global avgpool.
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)

        self.fc = nn.Linear(resnet.fc.in_features, embed_dim)
        self.batch_norm = nn.BatchNorm1d(embed_dim, momentum=0.01)

    def forward(self, images):
        with torch.no_grad():
            features = self.resnet(images)  # (batch_size, 2048, 1, 1)
        features = features.view(features.size(0), -1)  # flatten to (batch, 2048)
        features = self.fc(features)
        features = self.batch_norm(features)
        return features
82
+
83
+
84
# -----------------------
# Decoder
# -----------------------
class DecoderLSTM(nn.Module):
    """Training-time decoder: image feature + shifted caption -> logits."""

    def __init__(self, embed_dim, hidden_dim, vocab_size, num_layers=1):
        super().__init__()
        # Names embedding/lstm/fc are fixed by the checkpoint's state_dict keys.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, features, captions):
        # Feed the caption minus its last token, with the image feature
        # prepended as the first slot of the input sequence.
        token_emb = self.embedding(captions[:, :-1])
        sequence = torch.cat((features.unsqueeze(1), token_emb), dim=1)
        hidden_states, _ = self.lstm(sequence)
        return self.fc(hidden_states)
103
+
104
+
105
# -----------------------
# Caption Model
# -----------------------
class ImageCaptioningModel(nn.Module):
    """Thin wrapper chaining the image encoder into the caption decoder."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, images, captions):
        # encoder: images -> features; decoder: (features, captions) -> logits
        return self.decoder(self.encoder(images), captions)
118
+
119
+
120
# -----------------------
# Caption generator
# -----------------------
def generate_caption(model, image, vocab):
    """Greedily decode a caption for a single preprocessed image.

    Args:
        model: ImageCaptioningModel exposing .encoder and .decoder.
        image: transformed image tensor; batched internally via unsqueeze(0).
        vocab: Vocabulary providing stoi/itos lookups.

    Returns:
        Space-joined caption string (at most MAX_LEN words, specials excluded).
    """
    model.eval()

    image = image.unsqueeze(0).to(DEVICE)  # add batch dimension

    with torch.no_grad():
        # Get image features
        features = model.encoder(image)  # (1, embed_dim)

        # Start with the start token
        word_idx = vocab.stoi["startofseq"]
        sentence = []

        # Initialize hidden state for LSTM
        h = None

        for _ in range(MAX_LEN):
            # Embed the previously generated token
            word_tensor = torch.tensor([word_idx]).to(DEVICE)
            emb = model.decoder.embedding(word_tensor)  # (1, embed_dim)

            if h is None:
                # First step: prepend the image feature, mirroring how
                # DecoderLSTM.forward concatenates features before embeddings
                lstm_input = torch.cat([features.unsqueeze(1), emb.unsqueeze(1)], dim=1)  # (1, 2, embed_dim)
            else:
                lstm_input = emb.unsqueeze(1)  # (1, 1, embed_dim)

            # Forward through LSTM, carrying (h, c) between steps
            output, h_new = model.decoder.lstm(lstm_input, h)
            h = h_new

            # Greedy choice: highest-logit token at the last time step
            logits = model.decoder.fc(output[:, -1, :])  # (1, vocab_size)
            predicted = logits.argmax(1).item()

            # Get token from vocab
            token = vocab.itos[predicted]

            if token == "endofseq":
                break

            sentence.append(token)
            word_idx = predicted

    return " ".join(sentence)
169
+
170
+
171
# -----------------------
# Image transform
# -----------------------
# Preprocessing applied to every input image before encoding: resize to the
# ResNet input size, convert to a float tensor, then normalize with the
# standard ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485,0.456,0.406],
        std=[0.229,0.224,0.225]
    )
])
182
+
183
+
184
# -----------------------
# Main
# -----------------------
def main():
    """CLI entry point: caption the image given as the first argument.

    Usage: python model.py <image_path>

    Exits with status 1 and a usage message when no path is supplied,
    instead of dying with an IndexError on sys.argv[1].
    """
    if len(sys.argv) < 2:
        print("Usage: python model.py <image_path>", file=sys.stderr)
        sys.exit(1)

    image_path = sys.argv[1]

    # Checkpoint and vocab live next to this script, regardless of CWD.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    CHECKPOINT_PATH = os.path.join(script_dir, "best_checkpoint.pth")
    VOCAB_PATH = os.path.join(script_dir, "vocab.pkl")

    # load vocab (pickled with this module's Vocabulary class)
    with open(VOCAB_PATH, "rb") as f:
        vocab = pickle.load(f)

    vocab_size = len(vocab)

    # rebuild model with the same dimensions used at training time
    encoder = ResNetEncoder(EMBED_DIM)
    decoder = DecoderLSTM(EMBED_DIM, HIDDEN_DIM, vocab_size)
    model = ImageCaptioningModel(encoder, decoder).to(DEVICE)

    # load checkpoint (map_location allows CPU-only machines)
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    # load image and apply the shared preprocessing transform
    img = Image.open(image_path).convert("RGB")
    img = transform(img)

    caption = generate_caption(model, img, vocab)

    print("\nCaption:", caption)


if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ torch
4
+ torchvision
5
+ pillow
6
+ python-multipart
vocab.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3878a91256421ba64776cf69d22693c0a37e49d0303d84d8853c1c5ca937452
3
+ size 174488