skodan committed on
Commit
c33d894
·
1 Parent(s): 0f916ad

added project files

Browse files
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ .venv/
4
+ .ipynb_checkpoints/
5
+ .vscode/
6
+ *.pth
7
+ *.faiss
8
+ *.pkl
requirements.txt CHANGED
@@ -1,3 +1,12 @@
 
 
 
 
 
 
 
 
 
 
1
  altair
2
- pandas
3
- streamlit
 
1
+ streamlit>=1.38.0
2
+ fastapi>=0.115.0
3
+ uvicorn>=0.30.0
4
+ torch>=2.4.0
5
+ torchvision>=0.19.0
6
+ pillow>=10.0.0
7
+ huggingface_hub>=0.25.0
8
+ faiss-cpu>=1.8.0
9
+ pydantic>=2.0.0
10
+ numpy>=1.26.0
11
  altair
12
+ pandas
 
src/api.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # api.py
2
+ from fastapi import FastAPI, UploadFile, File, Form
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from PIL import Image
5
+ from typing import List
6
+ from pydantic import BaseModel
7
+
8
+ from model_registry import get_model
9
+ from models.resnet_lstm_attention.schemas import CaptionResult, ImageResult, TextQuery
10
+
# FastAPI application exposing the captioning / retrieval endpoints.
app = FastAPI(title="Multimodal Retrieval & Captioning API")

# Allow cross-origin requests (the Streamlit front-end runs on another port).
# NOTE(review): "*" origins/methods/headers are wide open — fine for a demo,
# tighten before any production deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
19
+
class InferenceRequest(BaseModel):
    # NOTE(review): currently unused — every endpoint below takes Form fields
    # instead of a JSON body. Kept for a possible future JSON-body API.
    model_name: str
    top_k: int = 5
+
#@app.post("/caption", response_model=CaptionResult)
@app.post("/caption")
async def caption_image(model_name: str = Form(...), file: UploadFile = File(...)):
    """Generate a caption for the uploaded image with the selected model."""
    img = Image.open(file.file).convert("RGB")
    selected = get_model(model_name)
    return {"caption": selected.generate_caption(img)}
31
+
#@app.post("/search/text2img", response_model=List[ImageResult])
@app.post("/search/text2img")
async def text_to_image(model_name: str = Form(...), query: str = Form(...), top_k: int = Form(5)):
    """Return the top_k images matching a free-text query."""
    selected = get_model(model_name)
    return selected.text_to_image(query, top_k)
38
+
@app.post("/search/img2text")
async def image_to_text(model_name: str = Form(...), file: UploadFile = File(...), top_k: int = Form(5)):
    """Return the top_k caption strings most similar to the uploaded image."""
    img = Image.open(file.file).convert("RGB")
    selected = get_model(model_name)
    return selected.image_to_text(img, top_k)
45
+
#@app.post("/search/img2img", response_model=List[ImageResult])
@app.post("/search/img2img")
async def image_to_image(model_name: str = Form(...), file: UploadFile = File(...), top_k: int = Form(5)):
    """Return the top_k images most similar to the uploaded image."""
    img = Image.open(file.file).convert("RGB")
    selected = get_model(model_name)
    return selected.image_to_image(img, top_k)
53
+
@app.get("/health")
def health_check():
    """Liveness probe for the front-end / deployment platform."""
    return {"status": "healthy"}
src/app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# app.py
"""Streamlit front-end for the multimodal retrieval & captioning demo.

Starts the FastAPI backend (api.py) as a child process once per Streamlit
server process, then talks to it over HTTP from each task tab.
"""
import streamlit as st
import requests
import subprocess
import time
from PIL import Image
import io
import base64  # For displaying retrieved images if needed

API_BASE = "http://localhost:8001"


@st.cache_resource
def _launch_backend():
    """Start uvicorn exactly once per Streamlit server process.

    Bug fix: the original called subprocess.Popen unconditionally at module
    level, but Streamlit re-executes the whole script on every user
    interaction, spawning a fresh uvicorn process (plus a 2s sleep) on each
    rerun. st.cache_resource memoizes the process handle across reruns and
    sessions, so the server is started only once.
    """
    proc = subprocess.Popen(["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "8001"])
    time.sleep(2)  # crude readiness wait; polling /health would be more robust
    return proc


_launch_backend()

st.set_page_config(page_title="Multimodal Retrieval & Captioning", layout="wide")

st.title("Multimodal Retrieval & Captioning System")

# Model selection (add more later)
model_name = st.sidebar.selectbox("Select Model", ["resnet_lstm_attention", "vit_lstm_attention", "vit_transformer"], index=0)

# Common inputs shared by all task tabs.
input_method = st.sidebar.radio("Image Input", ["Upload", "Camera"])
image_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"]) if input_method == "Upload" else st.camera_input("Capture Image")
text_input = st.text_input("Text Input")
top_k = st.sidebar.slider("Top K", 1, 10, 5)

# One tab per task.
tab_caption, tab_text2img, tab_img2text, tab_img2img, tab_text2text = st.tabs([
    "Image → Caption",
    "Text → Image",
    "Image → Text",
    "Image → Image",
    "Text → Text"
])

with tab_caption:
    if image_file and st.button("Generate Caption"):
        files = {"file": image_file.getvalue()}
        data = {"model_name": model_name}
        resp = requests.post(f"{API_BASE}/caption", files=files, data=data)
        if resp.status_code == 200:
            st.write("Caption:", resp.json()["caption"])
        else:
            st.error("Error: " + resp.text)

with tab_text2img:
    if text_input and st.button("Search Images"):
        data = {"model_name": model_name, "query": text_input, "top_k": top_k}
        resp = requests.post(f"{API_BASE}/search/text2img", data=data)
        if resp.status_code == 200:
            results = resp.json()
            for res in results:
                st.image(res["image_path"], caption=f"Score: {res['score']:.3f}")
        else:
            st.error("Error: " + resp.text)

with tab_img2text:
    if image_file and st.button("Retrieve Text"):
        files = {"file": image_file.getvalue()}
        data = {"model_name": model_name, "top_k": top_k}
        resp = requests.post(f"{API_BASE}/search/img2text", files=files, data=data)
        if resp.status_code == 200:
            st.write("Retrieved Texts:", resp.json())
        else:
            st.error("Error: " + resp.text)

with tab_img2img:
    if image_file and st.button("Retrieve Similar Images"):
        files = {"file": image_file.getvalue()}
        data = {"model_name": model_name, "top_k": top_k}
        resp = requests.post(f"{API_BASE}/search/img2img", files=files, data=data)
        if resp.status_code == 200:
            results = resp.json()
            for res in results:
                st.image(res["image_path"], caption=f"Score: {res['score']:.3f}")
        else:
            st.error("Error: " + resp.text)

with tab_text2text:
    st.info("Text → Text not implemented yet. Add to model interface if needed.")
src/configs/caption_config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "max_caption_len": 25,
3
+ "bos_token": "<start>",
4
+ "eos_token": "<end>",
5
+ "pad_token": "<pad>",
6
+ "unk_token": "<unk>",
7
+ "bos_index": 2,
8
+ "eos_index": 3,
9
+ "pad_index": 0,
10
+ "unk_index": 1,
11
+ "vocab_size": 2541,
12
+ "trained_epochs": 15,
13
+ "image_size": 224,
14
+ "mean": [
15
+ 0.485,
16
+ 0.456,
17
+ 0.406
18
+ ],
19
+ "std": [
20
+ 0.229,
21
+ 0.224,
22
+ 0.225
23
+ ]
24
+ }
src/configs/preprocess_config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "image_size": 224,
3
+ "mean": [0.485, 0.456, 0.406],
4
+ "std": [0.229, 0.224, 0.225],
5
+ "embed_dim": 512,
6
+ "model_type": "resnet50_lstm_attention_retrieval",
7
+ "trained_epochs": 40,
8
+ "batch_size_train": 128,
9
+ "batch_size_eval": 256,
10
+ "learning_rate_image": 3e-5,
11
+ "learning_rate_text": 1e-4,
12
+ "weight_decay": 0.01,
13
+ "temperature": 0.07,
14
+ "scheduler": "ReduceLROnPlateau",
15
+ "scheduler_factor": 0.5,
16
+ "scheduler_patience": 3,
17
+ "unfrozen_layers": "layer3 and layer4",
18
+ "train_augmentation": [
19
+ "RandomResizedCrop(scale=(0.8,1.0))",
20
+ "RandomHorizontalFlip",
21
+ "ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2)"
22
+ ],
23
+ "val_test_augmentation": "None (only resize + normalize)",
24
+ "dataset": "jxie/flickr8k",
25
+ "date_trained": "January 18, 2026",
26
+ "note": "Final configuration with all recommended improvements applied"
27
+ }
src/model_registry.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from typing import Dict
from utils.interfaces import UnifiedModelInterface

# Process-wide cache: each model backend is loaded at most once.
_LOADED_MODELS: Dict[str, UnifiedModelInterface] = {}

def get_model(model_name: str) -> UnifiedModelInterface:
    """Return the backend for *model_name*, loading and caching it on first use.

    Raises NotImplementedError for declared-but-unimplemented backends and
    ValueError for unknown names.
    """
    cached = _LOADED_MODELS.get(model_name)
    if cached is not None:
        return cached

    # Imports are deferred so selecting one backend does not pull in the
    # (heavy) dependencies of the others.
    if model_name == "resnet_lstm_attention":
        from models.resnet_lstm_attention.model import ResNetLSTMAttentionModel
        model = ResNetLSTMAttentionModel()
    elif model_name == "vit_lstm_attention":
        # Add later: from models.vit_lstm_attention.model import VitLSTMAttentionModel
        # model = VitLSTMAttentionModel()
        raise NotImplementedError("ViT + LSTM Attention not implemented yet")
    elif model_name == "vit_transformer":
        # Add later: from models.vit_transformer.model import VitTransformerModel
        # model = VitTransformerModel()
        raise NotImplementedError("ViT + Transformer not implemented yet")
    else:
        raise ValueError(f"Unknown model: {model_name}")

    model.load()
    _LOADED_MODELS[model_name] = model
    return model
src/models/resnet_lstm_attention/__init__.py ADDED
File without changes
src/models/resnet_lstm_attention/cap_mod_defs.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import Dataset, DataLoader
5
+ import torchvision.transforms as transforms
6
+ import torchvision.models as models
7
+ from torchvision.models import resnet50, ResNet50_Weights
8
+
9
+
# Module-wide compute device: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
12
+
13
+
class EncoderCNN(nn.Module):
    """Frozen ResNet-50 trunk yielding a grid of spatial features.

    Output is (B, 49, 2048): the final 7x7 conv map flattened so the decoder
    attention can attend over 49 spatial locations.
    """

    def __init__(self):
        super().__init__()
        backbone = resnet50(weights=ResNet50_Weights.DEFAULT)
        # Drop avgpool + fc; keep only the convolutional trunk.
        self.features = nn.Sequential(*list(backbone.children())[:-2])
        # The extractor is frozen — only the decoder is trained.
        for p in self.features.parameters():
            p.requires_grad = False

    def forward(self, images):
        fmap = self.features(images)                # (B, 2048, 7, 7)
        fmap = fmap.permute(0, 2, 3, 1)             # (B, 7, 7, 2048)
        return fmap.view(fmap.size(0), -1, 2048)    # (B, 49, 2048)
+
28
+
class Attention(nn.Module):
    """Additive (Bahdanau-style) attention over encoder feature locations."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.W_h = nn.Linear(hidden_dim, hidden_dim, bias=False)  # projects decoder hidden state
        self.W_e = nn.Linear(2048, hidden_dim, bias=False)        # projects encoder features
        self.v = nn.Linear(hidden_dim, 1, bias=False)             # scalar score per location

    def forward(self, hidden, encoder_outputs):
        """Return the context vector (B, 2048) for the current decoder state.

        hidden: (B, hidden_dim); encoder_outputs: (B, L, 2048).
        """
        proj_h = self.W_h(hidden).unsqueeze(1)      # (B, 1, H)
        proj_e = self.W_e(encoder_outputs)          # (B, L, H)

        # energy = tanh(W_h h + W_e e); score = v^T energy
        energy = torch.tanh(proj_h + proj_e)        # (B, L, H)
        scores = self.v(energy).squeeze(2)          # (B, L)

        weights = torch.softmax(scores, dim=1)      # (B, L)

        # Context: attention-weighted sum of encoder features.
        context = torch.bmm(weights.unsqueeze(1), encoder_outputs)
        context = context.squeeze(1)                # (B, 2048)

        return context
55
+
56
+
class DecoderRNN(nn.Module):
    """LSTM-cell decoder with additive attention over EncoderCNN features.

    Fixes vs. original: the hidden/cell states were hard-coded to size 512
    (breaking any hidden_size != 512), and their device came from a
    module-level global; both now follow the module's own parameters.
    Submodule names/order are unchanged, so existing checkpoints still load.
    """

    def __init__(self, vocab_size, embed_size=256, hidden_size=512):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(hidden_size)
        # Step input: word embedding concatenated with the 2048-d context.
        self.lstm = nn.LSTMCell(embed_size + 2048, hidden_size)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, captions, encoder_outputs):
        """Teacher-forced decoding; returns logits (B, seq_len-1, vocab_size)."""
        embedded = self.embedding(captions[:, :-1])  # last token is target-only

        dev = embedded.device
        h = torch.zeros(embedded.size(0), self.hidden_size, device=dev)
        c = torch.zeros(embedded.size(0), self.hidden_size, device=dev)

        outputs = []
        for t in range(embedded.size(1)):
            context = self.attention(h, encoder_outputs)
            step_in = torch.cat([embedded[:, t, :], context], dim=1)
            h, c = self.lstm(step_in, (h, c))
            outputs.append(self.fc(h))

        return torch.stack(outputs, dim=1)  # (B, seq_len-1, vocab_size)

    @torch.no_grad()
    def generate(self, encoder_outputs, vocab, inv_vocab, max_len=25):
        """Greedy decoding for one image; returns a list of word strings.

        Stops early on the '<end>' token; '<end>' is excluded from the output.
        """
        dev = self.embedding.weight.device
        h = torch.zeros(1, self.hidden_size, device=dev)
        c = torch.zeros(1, self.hidden_size, device=dev)

        token = torch.tensor([[vocab['<start>']]], device=dev)
        caption = []

        for _ in range(max_len):
            emb = self.embedding(token)                   # (1, 1, embed_size)
            context = self.attention(h, encoder_outputs)
            step_in = torch.cat([emb.squeeze(0), context], dim=1)
            h, c = self.lstm(step_in, (h, c))
            pred = self.fc(h).argmax(1)
            caption.append(pred.item())
            token = pred.unsqueeze(0)

            if pred.item() == vocab['<end>']:
                break

        return [inv_vocab.get(i, '<unk>') for i in caption if i != vocab['<end>']]
src/models/resnet_lstm_attention/captioning.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pickle
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ from .cap_mod_defs import EncoderCNN, DecoderRNN # reuse your exact classes
6
+
class CaptioningService:
    """Self-contained CPU captioning service.

    Loads the vocab and the encoder/decoder checkpoint and exposes
    generate_caption(PIL.Image) -> str.
    NOTE(review): not wired into model.py (which uses loader.py); kept as an
    alternative entry point.
    """

    def __init__(self, model_path, vocab_path, config):
        self.device = torch.device("cpu")

        # Vocabulary: token -> id, plus the inverse map for decoding.
        with open(vocab_path, "rb") as f:
            self.vocab = pickle.load(f)
        self.inv_vocab = {v: k for k, v in self.vocab.items()}

        # Checkpoint carries both state dicts and the vocab size.
        ckpt = torch.load(model_path, map_location=self.device)

        self.encoder = EncoderCNN().to(self.device)
        self.encoder.load_state_dict(ckpt["encoder_state_dict"])
        self.encoder.eval()

        self.decoder = DecoderRNN(vocab_size=ckpt["vocab_size"]).to(self.device)
        self.decoder.load_state_dict(ckpt["decoder_state_dict"])
        self.decoder.eval()

        # Bug fix: configs/caption_config.json (and loader.py) use the key
        # "max_caption_len", but this class only read "max_length" and would
        # raise KeyError with the shipped config. Accept either key; the
        # original "max_length" still takes precedence when present.
        if "max_length" in config:
            self.max_len = config["max_length"]
        else:
            self.max_len = config["max_caption_len"]

        # Same preprocessing as training-time evaluation.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    @torch.no_grad()
    def generate_caption(self, image: Image.Image) -> str:
        """Preprocess the image, greedily decode, and join tokens with spaces."""
        image = self.transform(image).unsqueeze(0).to(self.device)
        features = self.encoder(image)
        tokens = self.decoder.generate(
            features,
            vocab=self.vocab,
            inv_vocab=self.inv_vocab,
            max_len=self.max_len
        )
        return " ".join(tokens)
src/models/resnet_lstm_attention/clip_loader.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Bug fix: removed "from matplotlib import text" — the name was never used in
# this module and matplotlib is not in requirements.txt, so importing this
# file would fail with ImportError at runtime.
import torch
from .ret_mod_defs import ImageEncoder, TextEncoder
from .utils import simple_tokenize


print("LOADED clip_loader.py - UPDATED VERSION with encode_text(text: str)")
class CLIPRetrievalModel:
    """
    Wrapper exposing a CLIP-like interface over the dual encoders:
      - encode_image(images)
      - encode_text(text)  # raw string; tokenization handled internally
    """

    def __init__(self, image_encoder, text_encoder, vocab, max_caption_len=30):
        print("CLIPRetrievalModel initialized - single-arg encode_text version")
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        self.vocab = vocab
        self.max_caption_len = max_caption_len

    @torch.no_grad()
    def encode_image(self, images):
        """Embed a batch of preprocessed image tensors."""
        return self.image_encoder(images)

    @torch.no_grad()
    def encode_text(self, text: str):
        """Tokenize, pad, and embed a raw text string.

        Uses the same tokenizer as training; callers no longer pass lengths.
        """
        unk = self.vocab['<unk>']
        ids = [self.vocab.get(tok, unk) for tok in simple_tokenize(text.lower())]
        ids = ids[:self.max_caption_len]
        ids = [self.vocab['<start>']] + ids + [self.vocab['<end>']]
        n_tokens = len(ids)
        # Pad to the fixed training-time width: max_caption_len + BOS + EOS.
        ids = ids + [self.vocab['<pad>']] * (self.max_caption_len + 2 - n_tokens)

        dev = self.text_encoder.embedding.weight.device
        caption_batch = torch.tensor([ids], dtype=torch.long).to(dev)
        length_batch = torch.tensor([n_tokens], dtype=torch.long).to(dev)

        return self.text_encoder(caption_batch, length_batch)
+
45
+
def load_clip_model(model_path: str, vocab: dict, device: torch.device = torch.device("cpu")):
    """Rebuild both retrieval encoders from *model_path* and wrap them.

    Returns a CLIPRetrievalModel ready for inference (encoders in eval mode).
    """
    ckpt = torch.load(model_path, map_location=device)

    img_enc = ImageEncoder(embed_dim=512).to(device)
    # IMPORTANT: vocab_size must come from the checkpoint, not from len(vocab).
    txt_enc = TextEncoder(vocab_size=ckpt["vocab_size"]).to(device)

    img_enc.load_state_dict(ckpt["image_encoder_state_dict"])
    txt_enc.load_state_dict(ckpt["text_encoder_state_dict"])

    for enc in (img_enc, txt_enc):
        enc.eval()

    return CLIPRetrievalModel(img_enc, txt_enc, vocab)
src/models/resnet_lstm_attention/loader.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pickle
3
+ import json
4
+ from torchvision import transforms
5
+
6
+ from .cap_mod_defs import EncoderCNN, DecoderRNN
7
+
8
+
def load_captioning_model(
    model_path: str,
    vocab_path: str,
    config_path: str,
    device: torch.device = torch.device("cpu")
):
    """Load everything needed for caption inference and return it as a dict.

    Returned keys: encoder, decoder, vocab, inv_vocab, max_len, transform.
    """
    # Config — only max_caption_len is consumed here.
    with open(config_path, "r") as f:
        max_len = json.load(f)["max_caption_len"]

    # Vocab may be stored either as the raw token->id dict or wrapped in a
    # {"vocab": ..., "inv_vocab": ...} container; handle both.
    with open(vocab_path, "rb") as f:
        raw = pickle.load(f)

    if isinstance(raw, dict) and "vocab" in raw:
        vocab = raw["vocab"]
        inv_vocab = raw.get("inv_vocab")
        if inv_vocab is None:
            inv_vocab = {v: k for k, v in vocab.items()}
    else:
        vocab = raw
        inv_vocab = {v: k for k, v in vocab.items()}

    # Checkpoint carries both state dicts and the vocab size.
    ckpt = torch.load(model_path, map_location=device)

    encoder = EncoderCNN().to(device)
    decoder = DecoderRNN(vocab_size=ckpt["vocab_size"]).to(device)

    encoder.load_state_dict(ckpt["encoder_state_dict"])
    decoder.load_state_dict(ckpt["decoder_state_dict"])

    encoder.eval()
    decoder.eval()

    # Same preprocessing as training-time evaluation.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    return {
        "encoder": encoder,
        "decoder": decoder,
        "vocab": vocab,
        "inv_vocab": inv_vocab,
        "max_len": max_len,
        "transform": transform
    }
src/models/resnet_lstm_attention/model.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ from huggingface_hub import hf_hub_download
5
+ from PIL import Image
6
+ import numpy as np
7
+ from typing import List, Dict, Any
8
+
9
+ from .loader import load_captioning_model
10
+ from .retrieval import RetrievalService
11
+ from .clip_loader import load_clip_model
12
+ from .captioning import CaptioningService # Not directly used, but for reference
13
+ from utils.interfaces import UnifiedModelInterface # Adjust path if needed
14
+
class ResNetLSTMAttentionModel(UnifiedModelInterface):
    """ResNet50 + LSTM(+attention) backend implementing UnifiedModelInterface.

    load() assembles two sub-systems:
      - captioning: EncoderCNN/DecoderRNN via load_captioning_model
      - retrieval:  dual-encoder CLIP-style model + FAISS indexes wrapped in
        RetrievalService
    Weights/indexes are downloaded from the Hugging Face Hub; JSON configs are
    read from the local src/configs directory.
    """

    def __init__(self):
        # Populated lazily by load(); None means "not loaded yet".
        self.caption_bundle = None
        self.retrieval_service = None
        self.device = torch.device("cpu")
        #self.model_repo = "skodan/resnet-lstm-attention-weights"

    def load(self) -> None:
        """Download artifacts and build both services. Idempotent."""
        if self.caption_bundle is not None and self.retrieval_service is not None:
            return

        MODEL_REPO = "skodan/resnet-lstm-attention-weights"

        # Artifacts hosted on the HF Hub (too large to commit to this repo).
        files_to_download = [
            "caption_model.pth",
            "flickr8k_retrieval_model.pth",
            "image_embeddings.faiss",
            "text_embeddings.faiss",
            "image_id_map.pkl",
            "text_id_map.pkl",
            "vocab.pkl"
        ]

        downloaded_paths = {}
        for fname in files_to_download:
            try:
                path = hf_hub_download(
                    repo_id=MODEL_REPO,
                    filename=fname,
                    repo_type="model",
                )
                downloaded_paths[fname] = path
            except Exception as e:
                # Fail fast with context rather than continuing half-loaded.
                raise RuntimeError(f"Failed to download {fname} from {MODEL_REPO}: {e}")

        # Local cache paths of the downloaded artifacts.
        caption_pth = downloaded_paths["caption_model.pth"]
        retrieval_pth = downloaded_paths["flickr8k_retrieval_model.pth"]
        image_index_faiss = downloaded_paths["image_embeddings.faiss"]
        text_index_faiss = downloaded_paths["text_embeddings.faiss"]
        image_map_pkl = downloaded_paths["image_id_map.pkl"]
        text_map_pkl = downloaded_paths["text_id_map.pkl"]
        vocab_pkl = downloaded_paths["vocab.pkl"]

        # Load configs (small, committed to the repo). Three dirname() calls
        # from .../src/models/resnet_lstm_attention/model.py land on src/ —
        # where configs/ lives — not the repo root.
        base_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
        config_path = os.path.join(base_dir, "configs", "caption_config.json")
        preprocess_cfg_path = os.path.join(base_dir, "configs", "preprocess_config.json")

        with open(config_path, "r") as f:
            caption_config = json.load(f)
        # NOTE(review): caption_config is re-read inside load_captioning_model;
        # kept here for parity with the original code.

        with open(preprocess_cfg_path, "r") as f:
            preprocess_cfg = json.load(f)

        # Captioning: encoder/decoder + vocab + preprocessing transform.
        self.caption_bundle = load_captioning_model(
            model_path=caption_pth,
            vocab_path=vocab_pkl,
            config_path=config_path,
            device=self.device
        )

        # Retrieval: dual encoders sharing the captioning vocab.
        clip_model = load_clip_model(
            model_path=retrieval_pth,
            vocab=self.caption_bundle["vocab"],
            device=self.device
        )

        self.retrieval_service = RetrievalService(
            clip_model=clip_model,
            image_index_path=image_index_faiss,
            text_index_path=text_index_faiss,
            image_map_path=image_map_pkl,
            text_map_path=text_map_pkl,
            preprocess=preprocess_cfg
        )

        print("Model components loaded successfully.")

    @torch.no_grad()
    def generate_caption(self, image: Image.Image) -> str:
        """Greedy-decode a caption for a PIL image; returns the joined tokens."""
        encoder = self.caption_bundle["encoder"]
        decoder = self.caption_bundle["decoder"]
        vocab = self.caption_bundle["vocab"]
        inv_vocab = self.caption_bundle["inv_vocab"]
        max_len = self.caption_bundle["max_len"]
        transform = self.caption_bundle["transform"]

        image_tensor = transform(image).unsqueeze(0).to(self.device)
        features = encoder(image_tensor)
        tokens = decoder.generate(
            features,
            vocab=vocab,
            inv_vocab=inv_vocab,
            max_len=max_len
        )
        return " ".join(tokens)

    def text_to_image(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Delegate text→image retrieval to the retrieval service."""
        return self.retrieval_service.text_to_image(text, top_k)

    def image_to_text(self, image: Image.Image, top_k: int = 5) -> List[str]:
        """Delegate image→text retrieval to the retrieval service."""
        return self.retrieval_service.image_to_text(image, top_k)

    def image_to_image(self, image: Image.Image, top_k: int = 5) -> List[Dict[str, Any]]:
        """Embed the image and search the FAISS image index directly
        (RetrievalService has no image_to_image helper of its own)."""
        image_tensor = self.retrieval_service.image_transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            emb = self.retrieval_service.clip_model.encode_image(image_tensor).cpu().numpy()
        emb = self.retrieval_service._normalize(emb)
        scores, idxs = self.retrieval_service.image_index.search(emb, top_k)
        return [
            {"image_path": self.retrieval_service.image_id_map[i], "score": float(scores[0][j])}
            for j, i in enumerate(idxs[0])
        ]
src/models/resnet_lstm_attention/ret_mod_defs.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import torch.nn.functional as F
5
+ from torch.utils.data import Dataset, DataLoader
6
+ import torchvision.transforms as transforms
7
+ import torchvision.models as models
8
+ from torchvision.models import resnet50, ResNet50_Weights
9
+
10
+
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+ print(f"Using device: {device}")
13
+
14
+
class ImageEncoder(nn.Module):
    """ResNet-50 image embedder producing L2-normalized embeddings.

    Backbone = resnet children minus avgpool/fc; global average pooling then a
    linear projection to embed_dim.
    """

    def __init__(self, embed_dim=512):
        super().__init__()
        resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(resnet.fc.in_features, embed_dim)

        # Freeze the first 7 of the backbone's 8 children
        # (conv1, bn1, relu, maxpool, layer1, layer2, layer3): only layer4
        # stays trainable. Irrelevant for inference since a trained checkpoint
        # is loaded afterwards, but the original log line wrongly claimed
        # layer3 was also trainable — corrected below.
        freeze_until = 7
        for child in list(self.backbone.children())[:freeze_until]:
            for p in child.parameters():
                p.requires_grad = False

        print("ImageEncoder: backbone children 0-6 frozen, only layer4 trainable.")

    def forward(self, x):
        feat = self.backbone(x)
        feat = self.pool(feat)
        feat = torch.flatten(feat, 1)
        emb = self.fc(feat)
        # Normalize so cosine similarity reduces to a dot product.
        return F.normalize(emb, p=2, dim=1)
+
38
+
class TextEncoder(nn.Module):
    """Caption encoder: embedding -> LSTM -> attention pooling -> projection.

    Returns L2-normalized embeddings so cosine similarity is a dot product.
    """

    def __init__(self, vocab_size, embed_dim=300, hidden_dim=512, out_dim=512):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.attn = nn.Linear(hidden_dim, 1)
        self.fc = nn.Linear(hidden_dim, out_dim)

    def forward(self, captions, lengths):
        """captions: (B, T) int64 token ids; lengths: (B,) true lengths."""
        emb = self.embedding(captions)
        # Pack so the LSTM skips padding positions.
        packed = nn.utils.rnn.pack_padded_sequence(
            emb, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        out, _ = self.lstm(packed)
        out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)

        # Attention-weighted pooling over time steps.
        weights = F.softmax(self.attn(out), dim=1)
        pooled = torch.sum(weights * out, dim=1)

        return F.normalize(self.fc(pooled), p=2, dim=1)
src/models/resnet_lstm_attention/retrieval.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import faiss
2
+ import pickle
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+
class RetrievalService:
    """FAISS-backed cross-modal retrieval on top of a CLIP-style model.

    Indexes and id maps are built offline; this class only embeds queries and
    searches. Fix vs. original: removed a leftover DEBUG print of raw results
    in image_to_text.
    """

    def __init__(self, clip_model, image_index_path, text_index_path,
                 image_map_path, text_map_path, preprocess):

        self.device = torch.device("cpu")
        self.clip_model = clip_model

        # Pre-built FAISS indexes over image / caption embeddings.
        self.image_index = faiss.read_index(image_index_path)
        self.text_index = faiss.read_index(text_index_path)

        # Row id -> image path / caption text.
        with open(image_map_path, "rb") as f:
            self.image_id_map = pickle.load(f)

        with open(text_map_path, "rb") as f:
            self.text_id_map = pickle.load(f)

        # Eval-time preprocessing; mean/std come from preprocess_config.json.
        self.image_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=preprocess["mean"],
                std=preprocess["std"]
            )
        ])

    def _normalize(self, x):
        """L2-normalize rows so inner-product search equals cosine similarity."""
        return x / np.linalg.norm(x, axis=1, keepdims=True)

    def text_to_image(self, text, top_k=5):
        """Return top_k {'image_path', 'score'} dicts for a text query."""
        with torch.no_grad():
            emb = self.clip_model.encode_text(text).cpu().numpy()
        emb = self._normalize(emb)

        scores, idxs = self.image_index.search(emb, top_k)
        return [
            {
                "image_path": self.image_id_map[i],
                "score": float(scores[0][j])
            }
            for j, i in enumerate(idxs[0])
        ]

    def image_to_text(self, image: Image.Image, top_k=5):
        """Return the top_k caption strings most similar to the image."""
        image = self.image_transform(image).unsqueeze(0)
        with torch.no_grad():
            emb = self.clip_model.encode_image(image).cpu().numpy()
        emb = self._normalize(emb)

        scores, idxs = self.text_index.search(emb, top_k)
        return [self.text_id_map[i] for i in idxs[0]]
src/models/resnet_lstm_attention/schemas.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Pydantic request/response schemas shared by the API layer.
from pydantic import BaseModel
from typing import List

class TextQuery(BaseModel):
    # Free-text retrieval query plus desired result count.
    query: str
    top_k: int = 5

class ImageResult(BaseModel):
    # One retrieved image: its path and similarity score.
    image_path: str
    score: float

class CaptionResult(BaseModel):
    # Generated caption for an uploaded image.
    caption: str
src/models/resnet_lstm_attention/utils.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
# backend/utils.py
import re

def simple_tokenize(text: str) -> list[str]:
    """Lowercase, strip punctuation, and split on whitespace.

    Mirrors the tokenizer used at training time so inference-time token ids
    line up with the training vocabulary.
    """
    cleaned = re.sub(r'[^\w\s]', '', text.lower())
    collapsed = re.sub(r'\s+', ' ', cleaned).strip()
    return collapsed.split()
src/{streamlit_app.py → streamlit_app_old.py} RENAMED
File without changes
src/utils/interfaces.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from abc import ABC, abstractmethod
from typing import List, Dict, Any
from PIL import Image

class UnifiedModelInterface(ABC):
    """Common contract every captioning/retrieval backend implements.

    model_registry.get_model returns instances of this interface; the API
    layer calls only these methods.
    """

    @abstractmethod
    def load(self) -> None:
        """Lazy load all required components (models, indices, vocab, etc.)"""
        pass

    @abstractmethod
    def generate_caption(self, image: Image.Image) -> str:
        """Return a single caption string for *image*."""
        pass

    @abstractmethod
    def text_to_image(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Returns [{'image_path': str, 'score': float}, ...]"""
        pass

    @abstractmethod
    def image_to_text(self, image: Image.Image, top_k: int = 5) -> List[str]:
        """Returns list of caption strings"""
        pass

    @abstractmethod
    def image_to_image(self, image: Image.Image, top_k: int = 5) -> List[Dict[str, Any]]:
        """Return similar images as [{'image_path': str, 'score': float}, ...]."""
        pass

    # Optional capability: backends may override; default signals unsupported.
    def text_to_text(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
        raise NotImplementedError("Text-to-text not supported by this model")