Spaces:
Sleeping
Sleeping
added project files
Browse files- .gitignore +8 -0
- requirements.txt +11 -2
- src/api.py +56 -0
- src/app.py +82 -0
- src/configs/caption_config.json +24 -0
- src/configs/preprocess_config.json +27 -0
- src/model_registry.py +26 -0
- src/models/resnet_lstm_attention/__init__.py +0 -0
- src/models/resnet_lstm_attention/cap_mod_defs.py +101 -0
- src/models/resnet_lstm_attention/captioning.py +48 -0
- src/models/resnet_lstm_attention/clip_loader.py +63 -0
- src/models/resnet_lstm_attention/loader.py +79 -0
- src/models/resnet_lstm_attention/model.py +130 -0
- src/models/resnet_lstm_attention/ret_mod_defs.py +57 -0
- src/models/resnet_lstm_attention/retrieval.py +59 -0
- src/models/resnet_lstm_attention/schemas.py +13 -0
- src/models/resnet_lstm_attention/utils.py +10 -0
- src/{streamlit_app.py → streamlit_app_old.py} +0 -0
- src/utils/interfaces.py +30 -0
.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
.venv/
|
| 4 |
+
.ipynb_checkpoints/
|
| 5 |
+
.vscode/
|
| 6 |
+
*.pth
|
| 7 |
+
*.faiss
|
| 8 |
+
*.pkl
|
requirements.txt
CHANGED
|
@@ -1,3 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
altair
|
| 2 |
-
pandas
|
| 3 |
-
streamlit
|
|
|
|
| 1 |
+
streamlit>=1.38.0
|
| 2 |
+
fastapi>=0.115.0
|
| 3 |
+
uvicorn>=0.30.0
|
| 4 |
+
torch>=2.4.0
|
| 5 |
+
torchvision>=0.19.0
|
| 6 |
+
pillow>=10.0.0
|
| 7 |
+
huggingface_hub>=0.25.0
|
| 8 |
+
faiss-cpu>=1.8.0
|
| 9 |
+
pydantic>=2.0.0
|
| 10 |
+
numpy>=1.26.0
|
| 11 |
altair
|
| 12 |
+
pandas
|
|
|
src/api.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# api.py
"""FastAPI service exposing captioning and cross-modal retrieval endpoints.

Every endpoint takes a ``model_name`` form field and dispatches to the
matching model instance via ``model_registry.get_model`` (models are cached
there after first load).
"""
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, UnidentifiedImageError
from typing import List
from pydantic import BaseModel

from model_registry import get_model
from models.resnet_lstm_attention.schemas import CaptionResult, ImageResult, TextQuery

app = FastAPI(title="Multimodal Retrieval & Captioning API")

# CORS is wide open because the Streamlit front-end runs on a different port.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

class InferenceRequest(BaseModel):
    # Kept for JSON clients; the form-based endpoints below do not use it.
    model_name: str
    top_k: int = 5


def _read_image(file: UploadFile) -> Image.Image:
    """Decode an uploaded file into an RGB PIL image.

    Raises HTTP 400 instead of letting ``UnidentifiedImageError`` bubble up
    as an opaque 500 when the upload is not a valid image.
    """
    try:
        return Image.open(file.file).convert("RGB")
    except UnidentifiedImageError:
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image")


@app.post("/caption")
async def caption_image(model_name: str = Form(...), file: UploadFile = File(...)):
    """Generate a natural-language caption for the uploaded image."""
    image = _read_image(file)
    model = get_model(model_name)
    caption = model.generate_caption(image)
    return {"caption": caption}


@app.post("/search/text2img")
async def text_to_image(model_name: str = Form(...), query: str = Form(...), top_k: int = Form(5)):
    """Retrieve the top_k images matching a free-text query."""
    model = get_model(model_name)
    results = model.text_to_image(query, top_k)
    return results


@app.post("/search/img2text")
async def image_to_text(model_name: str = Form(...), file: UploadFile = File(...), top_k: int = Form(5)):
    """Retrieve the top_k captions matching the uploaded image."""
    image = _read_image(file)
    model = get_model(model_name)
    results = model.image_to_text(image, top_k)
    return results


@app.post("/search/img2img")
async def image_to_image(model_name: str = Form(...), file: UploadFile = File(...), top_k: int = Form(5)):
    """Retrieve the top_k images most similar to the uploaded image."""
    image = _read_image(file)
    model = get_model(model_name)
    results = model.image_to_image(image, top_k)
    return results


@app.get("/health")
def health_check():
    """Liveness probe used by the front-end / platform."""
    return {"status": "healthy"}
|
src/app.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py
"""Streamlit front-end that proxies all requests to the FastAPI backend."""
import streamlit as st
import requests
import subprocess
import time
from PIL import Image
import io
import base64  # For displaying retrieved images if needed


@st.cache_resource
def _start_api_server():
    """Start the FastAPI backend exactly once per Streamlit server process.

    BUG FIX: Streamlit re-executes this script top-to-bottom on every widget
    interaction, so an unguarded ``subprocess.Popen`` spawned a new uvicorn
    process on each rerun, piling up duplicates fighting over port 8001.
    ``st.cache_resource`` makes this run once for the process lifetime.
    """
    proc = subprocess.Popen(["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "8001"])
    time.sleep(2)  # Wait for server to start
    return proc


API_BASE = "http://localhost:8001"

st.set_page_config(page_title="Multimodal Retrieval & Captioning", layout="wide")

_start_api_server()

st.title("Multimodal Retrieval & Captioning System")

# Model selection (add more later)
model_name = st.sidebar.selectbox("Select Model", ["resnet_lstm_attention", "vit_lstm_attention", "vit_transformer"], index=0)

# Common inputs
input_method = st.sidebar.radio("Image Input", ["Upload", "Camera"])
image_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"]) if input_method == "Upload" else st.camera_input("Capture Image")
text_input = st.text_input("Text Input")
top_k = st.sidebar.slider("Top K", 1, 10, 5)

# Tabs for tasks
tab_caption, tab_text2img, tab_img2text, tab_img2img, tab_text2text = st.tabs([
    "Image → Caption",
    "Text → Image",
    "Image → Text",
    "Image → Image",
    "Text → Text"
])

with tab_caption:
    if image_file and st.button("Generate Caption"):
        files = {"file": image_file.getvalue()}
        data = {"model_name": model_name}
        resp = requests.post(f"{API_BASE}/caption", files=files, data=data)
        if resp.status_code == 200:
            st.write("Caption:", resp.json()["caption"])
        else:
            st.error("Error: " + resp.text)

with tab_text2img:
    if text_input and st.button("Search Images"):
        data = {"model_name": model_name, "query": text_input, "top_k": top_k}
        resp = requests.post(f"{API_BASE}/search/text2img", data=data)
        if resp.status_code == 200:
            results = resp.json()
            for res in results:
                st.image(res["image_path"], caption=f"Score: {res['score']:.3f}")
        else:
            st.error("Error: " + resp.text)

with tab_img2text:
    if image_file and st.button("Retrieve Text"):
        files = {"file": image_file.getvalue()}
        data = {"model_name": model_name, "top_k": top_k}
        resp = requests.post(f"{API_BASE}/search/img2text", files=files, data=data)
        if resp.status_code == 200:
            st.write("Retrieved Texts:", resp.json())
        else:
            st.error("Error: " + resp.text)

with tab_img2img:
    if image_file and st.button("Retrieve Similar Images"):
        files = {"file": image_file.getvalue()}
        data = {"model_name": model_name, "top_k": top_k}
        resp = requests.post(f"{API_BASE}/search/img2img", files=files, data=data)
        if resp.status_code == 200:
            results = resp.json()
            for res in results:
                st.image(res["image_path"], caption=f"Score: {res['score']:.3f}")
        else:
            st.error("Error: " + resp.text)

with tab_text2text:
    st.info("Text → Text not implemented yet. Add to model interface if needed.")
|
src/configs/caption_config.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"max_caption_len": 25,
|
| 3 |
+
"bos_token": "<start>",
|
| 4 |
+
"eos_token": "<end>",
|
| 5 |
+
"pad_token": "<pad>",
|
| 6 |
+
"unk_token": "<unk>",
|
| 7 |
+
"bos_index": 2,
|
| 8 |
+
"eos_index": 3,
|
| 9 |
+
"pad_index": 0,
|
| 10 |
+
"unk_index": 1,
|
| 11 |
+
"vocab_size": 2541,
|
| 12 |
+
"trained_epochs": 15,
|
| 13 |
+
"image_size": 224,
|
| 14 |
+
"mean": [
|
| 15 |
+
0.485,
|
| 16 |
+
0.456,
|
| 17 |
+
0.406
|
| 18 |
+
],
|
| 19 |
+
"std": [
|
| 20 |
+
0.229,
|
| 21 |
+
0.224,
|
| 22 |
+
0.225
|
| 23 |
+
]
|
| 24 |
+
}
|
src/configs/preprocess_config.json
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"image_size": 224,
|
| 3 |
+
"mean": [0.485, 0.456, 0.406],
|
| 4 |
+
"std": [0.229, 0.224, 0.225],
|
| 5 |
+
"embed_dim": 512,
|
| 6 |
+
"model_type": "resnet50_lstm_attention_retrieval",
|
| 7 |
+
"trained_epochs": 40,
|
| 8 |
+
"batch_size_train": 128,
|
| 9 |
+
"batch_size_eval": 256,
|
| 10 |
+
"learning_rate_image": 3e-5,
|
| 11 |
+
"learning_rate_text": 1e-4,
|
| 12 |
+
"weight_decay": 0.01,
|
| 13 |
+
"temperature": 0.07,
|
| 14 |
+
"scheduler": "ReduceLROnPlateau",
|
| 15 |
+
"scheduler_factor": 0.5,
|
| 16 |
+
"scheduler_patience": 3,
|
| 17 |
+
"unfrozen_layers": "layer3 and layer4",
|
| 18 |
+
"train_augmentation": [
|
| 19 |
+
"RandomResizedCrop(scale=(0.8,1.0))",
|
| 20 |
+
"RandomHorizontalFlip",
|
| 21 |
+
"ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2)"
|
| 22 |
+
],
|
| 23 |
+
"val_test_augmentation": "None (only resize + normalize)",
|
| 24 |
+
"dataset": "jxie/flickr8k",
|
| 25 |
+
"date_trained": "January 18, 2026",
|
| 26 |
+
"note": "Final configuration with all recommended improvements applied"
|
| 27 |
+
}
|
src/model_registry.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict
from utils.interfaces import UnifiedModelInterface

# Singleton cache: registry name -> constructed-and-loaded model instance.
_LOADED_MODELS: Dict[str, UnifiedModelInterface] = {}


def get_model(model_name: str) -> UnifiedModelInterface:
    """Return the model registered under *model_name*, loading it on first use.

    Raises NotImplementedError for registered-but-unbuilt architectures and
    ValueError for unknown names.
    """
    cached = _LOADED_MODELS.get(model_name)
    if cached is not None:
        return cached

    if model_name == "resnet_lstm_attention":
        # Imported lazily so importing this module stays cheap.
        from models.resnet_lstm_attention.model import ResNetLSTMAttentionModel
        model = ResNetLSTMAttentionModel()
    elif model_name == "vit_lstm_attention":
        # Placeholder: from models.vit_lstm_attention.model import VitLSTMAttentionModel
        raise NotImplementedError("ViT + LSTM Attention not implemented yet")
    elif model_name == "vit_transformer":
        # Placeholder: from models.vit_transformer.model import VitTransformerModel
        raise NotImplementedError("ViT + Transformer not implemented yet")
    else:
        raise ValueError(f"Unknown model: {model_name}")

    model.load()
    _LOADED_MODELS[model_name] = model
    return model
|
src/models/resnet_lstm_attention/__init__.py
ADDED
|
File without changes
|
src/models/resnet_lstm_attention/cap_mod_defs.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.optim as optim
|
| 4 |
+
from torch.utils.data import Dataset, DataLoader
|
| 5 |
+
import torchvision.transforms as transforms
|
| 6 |
+
import torchvision.models as models
|
| 7 |
+
from torchvision.models import resnet50, ResNet50_Weights
|
| 8 |
+
|
| 9 |
+
|
# Module-level device selection: every model defined in this file is moved to
# the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class EncoderCNN(nn.Module):
    """Frozen ResNet-50 trunk producing a 7x7 grid of 2048-d region features."""

    def __init__(self):
        super().__init__()
        resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        # Drop avgpool + fc: keep only the convolutional feature extractor.
        self.features = nn.Sequential(*list(resnet.children())[:-2])
        # The encoder is used as a fixed feature extractor; only the decoder trains.
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, images):
        """Map (B, 3, 224, 224) images to (B, 49, 2048) flattened region features."""
        features = self.features(images)          # (B, 2048, 7, 7)
        features = features.permute(0, 2, 3, 1)   # (B, 7, 7, 2048)
        # BUG FIX: .view() raises ("view size is not compatible ...") on the
        # non-contiguous permuted tensor; .reshape() copies when necessary.
        return features.reshape(features.size(0), -1, 2048)  # (B, 49, 2048)
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class Attention(nn.Module):
    """Additive (Bahdanau-style) attention over the encoder's region features."""

    def __init__(self, hidden_dim):
        super().__init__()
        # Learned projections: decoder hidden state and 2048-d encoder features.
        self.W_h = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.W_e = nn.Linear(2048, hidden_dim, bias=False)
        # Scores the combined energy down to one scalar per region.
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Return a (B, 2048) context vector attending over *encoder_outputs*.

        hidden:          (B, hidden_dim) current decoder state
        encoder_outputs: (B, seq_len, 2048) region features
        """
        # energy = tanh(W_h h + W_e e), broadcast over the region axis.
        energy = torch.tanh(
            self.W_h(hidden).unsqueeze(1) + self.W_e(encoder_outputs)
        )  # (B, seq_len, hidden_dim)

        # Per-region scores -> softmax attention weights.
        weights = torch.softmax(self.v(energy).squeeze(2), dim=1)  # (B, seq_len)

        # Weighted sum of the raw encoder features.
        return torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class DecoderRNN(nn.Module):
    """LSTM decoder with additive attention over ResNet region features."""

    def __init__(self, vocab_size, embed_size=256, hidden_size=512):
        super().__init__()
        # BUG FIX: the hidden size was hard-coded as 512 in forward()/generate(),
        # silently breaking any hidden_size != 512. Store and use it instead.
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(hidden_size)
        # LSTM input = word embedding concatenated with the 2048-d context.
        self.lstm = nn.LSTMCell(embed_size + 2048, hidden_size)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, captions, encoder_outputs):
        """Teacher-forced decoding.

        captions:        (B, seq_len) token ids (last token is dropped as input)
        encoder_outputs: (B, 49, 2048) region features
        Returns (B, seq_len-1, vocab_size) logits.
        """
        embedded = self.embedding(captions[:, :-1])  # (B, seq_len-1, embed_size)

        # Derive the device from the inputs rather than the module-level global,
        # so the model works wherever its tensors live.
        dev = encoder_outputs.device
        h = torch.zeros(embedded.size(0), self.hidden_size, device=dev)
        c = torch.zeros(embedded.size(0), self.hidden_size, device=dev)

        outputs = []
        for t in range(embedded.size(1)):
            context = self.attention(h, encoder_outputs)
            inp = torch.cat([embedded[:, t, :], context], dim=1)
            h, c = self.lstm(inp, (h, c))
            outputs.append(self.fc(h))

        return torch.stack(outputs, dim=1)  # (B, seq_len-1, vocab_size)

    @torch.no_grad()
    def generate(self, encoder_outputs, vocab, inv_vocab, max_len=25):
        """Greedy decoding for a single image (batch size 1).

        Returns the caption as a list of word strings, excluding '<end>'.
        """
        dev = encoder_outputs.device
        h = torch.zeros(1, self.hidden_size, device=dev)
        c = torch.zeros(1, self.hidden_size, device=dev)

        token = torch.tensor([[vocab['<start>']]], device=dev)
        caption = []

        for _ in range(max_len):
            emb = self.embedding(token)                      # (1, 1, embed_size)
            context = self.attention(h, encoder_outputs)     # (1, 2048)
            inp = torch.cat([emb.squeeze(0), context], dim=1)
            h, c = self.lstm(inp, (h, c))
            pred = self.fc(h).argmax(1)
            caption.append(pred.item())
            token = pred.unsqueeze(0)

            if pred.item() == vocab['<end>']:
                break

        return [inv_vocab.get(i, '<unk>') for i in caption if i != vocab['<end>']]
|
src/models/resnet_lstm_attention/captioning.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import pickle
|
| 3 |
+
from torchvision import transforms
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from .cap_mod_defs import EncoderCNN, DecoderRNN # reuse your exact classes
|
| 6 |
+
|
| 7 |
+
class CaptioningService:
    """Loads the caption encoder/decoder from a checkpoint and generates captions.

    Parameters
    ----------
    model_path : path to a checkpoint holding encoder/decoder state dicts and vocab_size
    vocab_path : path to the pickled token -> index vocabulary
    config     : dict loaded from caption_config.json
    """

    def __init__(self, model_path, vocab_path, config):
        self.device = torch.device("cpu")

        # Vocabulary (token -> index) plus the inverse mapping for decoding.
        with open(vocab_path, "rb") as f:
            self.vocab = pickle.load(f)
        self.inv_vocab = {v: k for k, v in self.vocab.items()}

        # Checkpoint stores both state dicts and the training-time vocab size.
        ckpt = torch.load(model_path, map_location=self.device)

        self.encoder = EncoderCNN().to(self.device)
        self.encoder.load_state_dict(ckpt["encoder_state_dict"])
        self.encoder.eval()

        self.decoder = DecoderRNN(vocab_size=ckpt["vocab_size"]).to(self.device)
        self.decoder.load_state_dict(ckpt["decoder_state_dict"])
        self.decoder.eval()

        # BUG FIX: caption_config.json stores the limit under "max_caption_len";
        # the old config["max_length"] lookup raised KeyError. Accept either
        # key, falling back to the config file's documented default of 25.
        self.max_len = config.get("max_caption_len", config.get("max_length", 25))

        # ImageNet normalisation, matching the training-time preprocessing.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    @torch.no_grad()
    def generate_caption(self, image: Image.Image) -> str:
        """Encode *image* and greedily decode a caption string."""
        image = self.transform(image).unsqueeze(0).to(self.device)
        features = self.encoder(image)
        tokens = self.decoder.generate(
            features,
            vocab=self.vocab,
            inv_vocab=self.inv_vocab,
            max_len=self.max_len
        )
        return " ".join(tokens)
|
src/models/resnet_lstm_attention/clip_loader.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch

from .ret_mod_defs import ImageEncoder, TextEncoder
from .utils import simple_tokenize

# BUG FIX: removed unused `from matplotlib import text` — matplotlib is not
# listed in requirements.txt, so the stray import crashed the service with
# ModuleNotFoundError at startup in deployment.

print("LOADED clip_loader.py - UPDATED VERSION with encode_text(text: str)")
|
| 8 |
+
|
| 9 |
+
class CLIPRetrievalModel:
    """
    CLIP-style facade over the trained encoders:
    - encode_image(images): embed a batch of preprocessed image tensors
    - encode_text(text):    embed a single raw string (tokenized internally)
    """

    def __init__(self, image_encoder, text_encoder, vocab, max_caption_len=30):
        print("CLIPRetrievalModel initialized - single-arg encode_text version")
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        self.vocab = vocab
        self.max_caption_len = max_caption_len

    @torch.no_grad()
    def encode_image(self, images):
        """Embed a batch of preprocessed image tensors."""
        return self.image_encoder(images)

    @torch.no_grad()
    def encode_text(self, text: str):
        """
        Accept raw text string → tokenize → pad → encode
        No need to pass 'lengths' from outside anymore
        """
        # Tokenize with the same logic used at training time.
        unk = self.vocab['<unk>']
        ids = [self.vocab.get(tok, unk) for tok in simple_tokenize(text.lower())]

        # Truncate, then wrap with the start/end markers.
        ids = [self.vocab['<start>']] + ids[:self.max_caption_len] + [self.vocab['<end>']]
        seq_len = len(ids)

        # Pad out to the fixed training length (max_caption_len + 2 markers).
        ids = ids + [self.vocab['<pad>']] * (self.max_caption_len + 2 - seq_len)

        target = self.text_encoder.embedding.weight.device
        captions = torch.tensor([ids], dtype=torch.long).to(target)
        lengths = torch.tensor([seq_len], dtype=torch.long).to(target)

        return self.text_encoder(captions, lengths)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def load_clip_model(model_path: str, vocab: dict, device: torch.device = torch.device("cpu")):
    """Rebuild both retrieval encoders from *model_path* and wrap them.

    The checkpoint supplies the vocab size the text encoder was trained with,
    so the architecture always matches the stored weights.
    """
    ckpt = torch.load(model_path, map_location=device)

    img_enc = ImageEncoder(embed_dim=512).to(device)
    # IMPORTANT: vocab_size must come from the checkpoint, not len(vocab).
    txt_enc = TextEncoder(vocab_size=ckpt["vocab_size"]).to(device)

    img_enc.load_state_dict(ckpt["image_encoder_state_dict"])
    txt_enc.load_state_dict(ckpt["text_encoder_state_dict"])

    img_enc.eval()
    txt_enc.eval()

    return CLIPRetrievalModel(img_enc, txt_enc, vocab)
|
| 63 |
+
|
src/models/resnet_lstm_attention/loader.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import pickle
|
| 3 |
+
import json
|
| 4 |
+
from torchvision import transforms
|
| 5 |
+
|
| 6 |
+
from .cap_mod_defs import EncoderCNN, DecoderRNN
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def load_captioning_model(
    model_path: str,
    vocab_path: str,
    config_path: str,
    device: torch.device = torch.device("cpu")
):
    """Assemble the captioning stack from its saved artifacts.

    Returns a dict bundle with keys: encoder, decoder, vocab, inv_vocab,
    max_len, transform.
    """
    # Config (JSON) — only the caption-length limit is needed here.
    with open(config_path, "r") as f:
        config = json.load(f)
    max_len = config["max_caption_len"]

    # Vocabulary (pickle) — tolerate both the bare-dict and wrapped formats.
    with open(vocab_path, "rb") as f:
        vocab_data = pickle.load(f)

    if isinstance(vocab_data, dict) and "vocab" in vocab_data:
        vocab = vocab_data["vocab"]
        inv_vocab = vocab_data.get("inv_vocab") or {v: k for k, v in vocab.items()}
    else:
        vocab = vocab_data
        inv_vocab = {v: k for k, v in vocab.items()}

    # Checkpoint: rebuild the architecture, then restore the weights.
    checkpoint = torch.load(model_path, map_location=device)

    encoder = EncoderCNN().to(device)
    decoder = DecoderRNN(vocab_size=checkpoint["vocab_size"]).to(device)

    encoder.load_state_dict(checkpoint["encoder_state_dict"])
    decoder.load_state_dict(checkpoint["decoder_state_dict"])

    encoder.eval()
    decoder.eval()

    # Inference-time preprocessing: resize + ImageNet normalisation.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    return {
        "encoder": encoder,
        "decoder": decoder,
        "vocab": vocab,
        "inv_vocab": inv_vocab,
        "max_len": max_len,
        "transform": transform
    }
|
src/models/resnet_lstm_attention/model.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import torch
|
| 4 |
+
from huggingface_hub import hf_hub_download
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import numpy as np
|
| 7 |
+
from typing import List, Dict, Any
|
| 8 |
+
|
| 9 |
+
from .loader import load_captioning_model
|
| 10 |
+
from .retrieval import RetrievalService
|
| 11 |
+
from .clip_loader import load_clip_model
|
| 12 |
+
from .captioning import CaptioningService # Not directly used, but for reference
|
| 13 |
+
from utils.interfaces import UnifiedModelInterface # Adjust path if needed
|
| 14 |
+
|
| 15 |
+
class ResNetLSTMAttentionModel(UnifiedModelInterface):
    """Unified wrapper bundling the captioning stack and the retrieval stack.

    Weights and FAISS indexes are fetched from the Hugging Face Hub on first
    load(); small JSON configs are read from the repo's src/configs directory.
    """

    def __init__(self):
        # Populated lazily by load(); None means "not loaded yet".
        self.caption_bundle = None
        self.retrieval_service = None
        self.device = torch.device("cpu")
        #self.model_repo = "skodan/resnet-lstm-attention-weights"

    def load(self) -> None:
        """Download all artifacts and build the caption + retrieval services.

        Idempotent: returns immediately if both services already exist.
        Raises RuntimeError when any Hub download fails.
        """
        if self.caption_bundle is not None and self.retrieval_service is not None:
            return

        MODEL_REPO = "skodan/resnet-lstm-attention-weights"

        files_to_download = [
            "caption_model.pth",
            "flickr8k_retrieval_model.pth",
            "image_embeddings.faiss",
            "text_embeddings.faiss",
            "image_id_map.pkl",
            "text_id_map.pkl",
            "vocab.pkl"
        ]

        downloaded_paths = {}
        for fname in files_to_download:
            try:
                path = hf_hub_download(
                    repo_id=MODEL_REPO,
                    filename=fname,
                    repo_type="model",
                )
                downloaded_paths[fname] = path
            except Exception as e:
                # Fail fast with the offending filename; a partial download
                # set would leave the services half-constructed.
                raise RuntimeError(f"Failed to download {fname} from {MODEL_REPO}: {e}")

        # Download large files from HF Hub
        caption_pth = downloaded_paths["caption_model.pth"]
        retrieval_pth = downloaded_paths["flickr8k_retrieval_model.pth"]
        image_index_faiss = downloaded_paths["image_embeddings.faiss"]
        text_index_faiss = downloaded_paths["text_embeddings.faiss"]
        image_map_pkl = downloaded_paths["image_id_map.pkl"]
        text_map_pkl = downloaded_paths["text_id_map.pkl"]
        vocab_pkl = downloaded_paths["vocab.pkl"]

        # Load configs (assume small, committed to repo).
        # NOTE: three dirname() calls from .../src/models/resnet_lstm_attention/model.py
        # resolve to the src/ directory, where configs/ lives.
        base_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
        config_path = os.path.join(base_dir, "configs", "caption_config.json")
        preprocess_cfg_path = os.path.join(base_dir, "configs", "preprocess_config.json")

        with open(config_path, "r") as f:
            caption_config = json.load(f)

        with open(preprocess_cfg_path, "r") as f:
            preprocess_cfg = json.load(f)

        # Captioning stack: encoder/decoder/vocab/transform bundle.
        self.caption_bundle = load_captioning_model(
            model_path=caption_pth,
            vocab_path=vocab_pkl,
            config_path=config_path,
            device=self.device
        )

        # Retrieval stack: CLIP-style dual encoder sharing the caption vocab.
        clip_model = load_clip_model(
            model_path=retrieval_pth,
            vocab=self.caption_bundle["vocab"],
            device=self.device
        )

        self.retrieval_service = RetrievalService(
            clip_model=clip_model,
            image_index_path=image_index_faiss,
            text_index_path=text_index_faiss,
            image_map_path=image_map_pkl,
            text_map_path=text_map_pkl,
            preprocess=preprocess_cfg
        )

        print("Model components loaded successfully.")

    @torch.no_grad()
    def generate_caption(self, image: Image.Image) -> str:
        """Greedily decode a caption for a PIL image; requires load() first."""
        encoder = self.caption_bundle["encoder"]
        decoder = self.caption_bundle["decoder"]
        vocab = self.caption_bundle["vocab"]
        inv_vocab = self.caption_bundle["inv_vocab"]
        max_len = self.caption_bundle["max_len"]
        transform = self.caption_bundle["transform"]

        image_tensor = transform(image).unsqueeze(0).to(self.device)
        features = encoder(image_tensor)
        tokens = decoder.generate(
            features,
            vocab=vocab,
            inv_vocab=inv_vocab,
            max_len=max_len
        )
        return " ".join(tokens)

    def text_to_image(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Delegate text→image retrieval to the retrieval service."""
        return self.retrieval_service.text_to_image(text, top_k)

    def image_to_text(self, image: Image.Image, top_k: int = 5) -> List[str]:
        """Delegate image→text retrieval to the retrieval service."""
        return self.retrieval_service.image_to_text(image, top_k)

    def image_to_image(self, image: Image.Image, top_k: int = 5) -> List[Dict[str, Any]]:
        """Embed the query image and search the FAISS image index directly.

        NOTE(review): relies on RetrievalService exposing image_transform,
        _normalize, image_index and image_id_map — confirm against retrieval.py.
        """
        image_tensor = self.retrieval_service.image_transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            emb = self.retrieval_service.clip_model.encode_image(image_tensor).cpu().numpy()
        emb = self.retrieval_service._normalize(emb)
        scores, idxs = self.retrieval_service.image_index.search(emb, top_k)
        return [
            {"image_path": self.retrieval_service.image_id_map[i], "score": float(scores[0][j])}
            for j, i in enumerate(idxs[0])
        ]
|
src/models/resnet_lstm_attention/ret_mod_defs.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.optim as optim
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch.utils.data import Dataset, DataLoader
|
| 6 |
+
import torchvision.transforms as transforms
|
| 7 |
+
import torchvision.models as models
|
| 8 |
+
from torchvision.models import resnet50, ResNet50_Weights
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Resolve the compute device once at import time: CUDA when available, else CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _use_cuda else "cpu")
print(f"Using device: {device}")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ImageEncoder(nn.Module):
    """ResNet-50 image tower projecting pooled features into a joint embedding.

    Outputs L2-normalized embeddings of size ``embed_dim`` so cosine
    similarity reduces to a dot product against the text tower.
    """

    def __init__(self, embed_dim=512):
        super().__init__()
        resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        # Drop the final avgpool + fc; keep the convolutional trunk.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(resnet.fc.in_features, embed_dim)

        # Freeze the first 7 children of the trunk: conv1, bn1, relu, maxpool,
        # layer1, layer2 and layer3.  Only layer4 (and the projection head)
        # remains trainable.  NOTE(review): the original comment and print
        # claimed layer3 was also trainable, which contradicted freeze_until=7.
        freeze_until = 7
        for child in list(self.backbone.children())[:freeze_until]:
            for p in child.parameters():
                p.requires_grad = False

        print("ImageEncoder: layers 0-6 frozen, layer4 trainable.")

    def forward(self, x):
        """Map a batch of images (B, 3, H, W) to L2-normalized embeddings (B, embed_dim)."""
        feat = self.backbone(x)
        feat = self.pool(feat)
        feat = torch.flatten(feat, 1)
        emb = self.fc(feat)
        return F.normalize(emb, p=2, dim=1)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class TextEncoder(nn.Module):
    """Caption tower: embedding -> LSTM -> additive attention pooling -> projection.

    Returns L2-normalized embeddings of size ``out_dim`` matching the image tower.
    """

    def __init__(self, vocab_size, embed_dim=300, hidden_dim=512, out_dim=512):
        super().__init__()
        # Module creation order is kept stable for reproducible initialization.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.attn = nn.Linear(hidden_dim, 1)
        self.fc = nn.Linear(hidden_dim, out_dim)

    def forward(self, captions, lengths):
        """Encode padded token id batches (B, T) with per-sample lengths (B,)."""
        embedded_seq = self.embedding(captions)
        packed_seq = nn.utils.rnn.pack_padded_sequence(
            embedded_seq, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        hidden_seq, _ = self.lstm(packed_seq)
        hidden_seq, _ = nn.utils.rnn.pad_packed_sequence(hidden_seq, batch_first=True)

        # Attention weights over timesteps.
        # NOTE(review): the softmax also covers padded positions (whose LSTM
        # outputs are zero after pad_packed_sequence, but exp(0) still
        # contributes weight) — confirm this matches how the weights were trained.
        attn_weights = F.softmax(self.attn(hidden_seq), dim=1)
        pooled = torch.sum(attn_weights * hidden_seq, dim=1)

        return F.normalize(self.fc(pooled), p=2, dim=1)
|
src/models/resnet_lstm_attention/retrieval.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import faiss
|
| 2 |
+
import pickle
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from torchvision import transforms
|
| 7 |
+
|
| 8 |
+
class RetrievalService:
    """Cross-modal retrieval over prebuilt FAISS indices.

    Wraps a CLIP-style model, two FAISS indices (image and text) and the
    pickled row-id maps that translate index hits back into image paths
    and caption strings.
    """

    def __init__(self, clip_model, image_index_path, text_index_path,
                 image_map_path, text_map_path, preprocess):
        """Load indices and id maps from disk; ``preprocess`` supplies mean/std."""
        self.device = torch.device("cpu")
        self.clip_model = clip_model

        self.image_index = faiss.read_index(image_index_path)
        self.text_index = faiss.read_index(text_index_path)

        # Row-id -> item maps produced alongside the FAISS indices.
        with open(image_map_path, "rb") as f:
            self.image_id_map = pickle.load(f)

        with open(text_map_path, "rb") as f:
            self.text_id_map = pickle.load(f)

        self.image_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=preprocess["mean"],
                std=preprocess["std"]
            )
        ])

    def _normalize(self, x):
        """Row-wise L2 normalization.

        Guards against all-zero rows, which previously produced NaN via a
        division by zero.
        """
        norms = np.linalg.norm(x, axis=1, keepdims=True)
        return x / np.maximum(norms, 1e-12)

    def text_to_image(self, text, top_k=5):
        """Return the top_k most similar images for a text query.

        Each result is ``{"image_path": str, "score": float}``.
        """
        with torch.no_grad():
            emb = self.clip_model.encode_text(text).cpu().numpy()
        emb = self._normalize(emb)

        scores, idxs = self.image_index.search(emb, top_k)
        return [
            {
                "image_path": self.image_id_map[i],
                "score": float(scores[0][j])
            }
            for j, i in enumerate(idxs[0])
        ]

    # Annotation kept as a string so the class can be defined without PIL
    # imported at definition time; runtime behavior is unchanged.
    def image_to_text(self, image: "Image.Image", top_k=5):
        """Return the top_k most similar captions for a query image."""
        image = self.image_transform(image).unsqueeze(0)
        with torch.no_grad():
            emb = self.clip_model.encode_image(image).cpu().numpy()
        emb = self._normalize(emb)

        scores, idxs = self.text_index.search(emb, top_k)
        # Leftover DEBUG print removed.
        return [self.text_id_map[i] for i in idxs[0]]
|
src/models/resnet_lstm_attention/schemas.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
class TextQuery(BaseModel):
    """Request body for text-based retrieval endpoints."""

    # Free-text search query.
    query: str
    # Number of results to return.
    top_k: int = 5
|
| 7 |
+
|
| 8 |
+
class ImageResult(BaseModel):
    """A single retrieved image and its similarity score."""

    # Path (or identifier) of the matched image.
    image_path: str
    # Similarity score reported by the retrieval backend.
    score: float
|
| 11 |
+
|
| 12 |
+
class CaptionResult(BaseModel):
    """Response body carrying one generated caption."""

    # Generated caption text.
    caption: str
|
src/models/resnet_lstm_attention/utils.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# backend/utils.py
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
def simple_tokenize(text: str) -> list[str]:
    """Lowercase *text*, strip punctuation, and split into whitespace tokens.

    Mirrors the tokenizer used during training so inference-time queries map
    onto the same vocabulary.
    """
    # \w keeps letters, digits and underscores; all other punctuation is dropped.
    text = re.sub(r'[^\w\s]', '', text.lower())
    # str.split() with no argument already collapses whitespace runs and
    # ignores leading/trailing space, so the former r'\s+' pass was redundant.
    return text.split()
|
src/{streamlit_app.py → streamlit_app_old.py}
RENAMED
|
File without changes
|
src/utils/interfaces.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import List, Dict, Any
|
| 3 |
+
from PIL import Image
|
| 4 |
+
|
| 5 |
+
class UnifiedModelInterface(ABC):
    """Common contract every model backend must implement.

    Bundles captioning and cross-modal retrieval behind one object so the
    serving layer can swap model implementations without code changes.
    """

    @abstractmethod
    def load(self) -> None:
        """Lazy load all required components (models, indices, vocab, etc.)"""
        pass

    @abstractmethod
    def generate_caption(self, image: Image.Image) -> str:
        """Return a single caption string for the given PIL image."""
        pass

    @abstractmethod
    def text_to_image(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Returns [{'image_path': str, 'score': float}, ...]"""
        pass

    @abstractmethod
    def image_to_text(self, image: Image.Image, top_k: int = 5) -> List[str]:
        """Returns list of caption strings"""
        pass

    @abstractmethod
    def image_to_image(self, image: Image.Image, top_k: int = 5) -> List[Dict[str, Any]]:
        """Returns [{'image_path': str, 'score': float}, ...] of similar images."""
        pass

    def text_to_text(self, text: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Optional capability; the default implementation signals unsupported."""
        raise NotImplementedError("Text-to-text not supported by this model")
|