# Mark8398's picture
# init
# dfbfc84 verified
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTModel, AutoModel, AutoTokenizer
from torchvision import transforms
from datasets import load_dataset
from PIL import Image
class MultiModalEngine(nn.Module):
    """Dual encoder that maps images and text into one shared 256-d space.

    Images go through a pretrained ViT and text through a pretrained MPNet
    sentence-transformer. Each 768-d encoder output is passed through its own
    linear projection and L2-normalised, so a dot product between an image
    embedding and a text embedding is their cosine similarity.
    """

    def __init__(self):
        super().__init__()
        # Pretrained backbones; attribute names double as checkpoint keys.
        self.image_model = ViTModel.from_pretrained("google/vit-base-patch16-224")
        self.text_model = AutoModel.from_pretrained(
            "sentence-transformers/all-mpnet-base-v2"
        )
        # Separate 768 -> 256 projection heads, one per modality.
        self.image_projection = nn.Linear(768, 256)
        self.text_projection = nn.Linear(768, 256)
        # Scalar parameter; not read by any method of this class, but kept so
        # the parameter set matches a checkpoint that contains it.
        self.logit_scale = nn.Parameter(torch.ones([]) * 2.659)

    def encode_text(self, input_ids, attention_mask):
        """Return unit-length text embeddings for a tokenised batch.

        Args:
            input_ids: Token ids, forwarded to the text encoder.
            attention_mask: Mask forwarded to the text encoder and reused to
                weight the mean pooling.

        Returns:
            Projected, mean-pooled embeddings normalised along dim 1.
        """
        encoder_out = self.text_model(
            input_ids=input_ids, attention_mask=attention_mask
        )
        pooled = self.mean_pooling(encoder_out, attention_mask)
        projected = self.text_projection(pooled)
        return F.normalize(projected, dim=1)

    def encode_image(self, images):
        """Return unit-length image embeddings for a batch of pixel tensors.

        Args:
            images: Tensor passed to the ViT as ``pixel_values``.

        Returns:
            Projection of the first output token of every image, normalised
            along dim 1.
        """
        encoder_out = self.image_model(pixel_values=images)
        # Sequence position 0 is used as the whole-image summary.
        first_token = encoder_out.last_hidden_state[:, 0, :]
        projected = self.image_projection(first_token)
        return F.normalize(projected, dim=1)

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over dim 1, weighted by the attention mask.

        Args:
            model_output: Object exposing ``last_hidden_state``.
            attention_mask: Mask with one entry per token; entries equal to 0
                contribute nothing to the average.

        Returns:
            The mask-weighted mean of ``last_hidden_state`` along dim 1.
        """
        hidden = model_output.last_hidden_state
        weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
        weighted_sum = (hidden * weights).sum(1)
        # Clamp the denominator so a fully masked row cannot divide by zero.
        token_count = weights.sum(1).clamp(min=1e-9)
        return weighted_sum / token_count
# ---------------------------------------------------------------------------
# One-time startup: load the trained model, the precomputed image index, the
# tokenizer / image transform, and the dataset the result images come from.
# Everything below is module-level state read by the search functions.
# ---------------------------------------------------------------------------
print("⏳ Loading resources...")
device = "cpu"

# Load Model
model = MultiModalEngine()
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while loading.
model.load_state_dict(
    torch.load(
        "flickr8k_best_model_r1_27.pth",
        map_location=device,
        weights_only=True,
    )
)
model.eval()  # inference only: switch off dropout and similar train-time layers

# Load Index
# Precomputed image embeddings; the search functions use this as a matrix
# (``image_embeddings.T``), so a plain tensor is all that needs unpickling.
image_embeddings = torch.load(
    "flickr8k_best_index.pt",
    map_location=device,
    weights_only=True,
)

# Load Tokenizer & Transforms
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2")
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load Dataset (Standard mode to fetch result images)
# NOTE(review): index row i is assumed to correspond to dataset[i] -- confirm
# the index was built from this same split in this order.
print("Downloading dataset (this may take a minute)...")
dataset = load_dataset("tsystems/flickr8k", split="train")
print("Server Ready!")
def search_text(query):
    """Return the dataset images that best match a text query.

    The query is tokenised, embedded with the text encoder, and scored against
    every row of the precomputed ``image_embeddings`` index by dot product.

    Args:
        query: Free-text search string (the Gradio textbox value).

    Returns:
        A list with the ``'image'`` field of the highest-scoring dataset rows,
        best match first. At most three entries; fewer if the index holds
        fewer than three vectors. An empty list for a missing or blank query.
    """
    # Nothing meaningful to embed: return no results, matching what
    # search_image does when it receives no image.
    if not query or not query.strip():
        return []

    inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        text_emb = model.encode_text(inputs["input_ids"], inputs["attention_mask"])
        # Drop only the query (batch) dimension. A bare squeeze() would also
        # collapse a one-image index into a 0-d tensor and break topk.
        # assumes image_embeddings is (num_images, dim) -- TODO confirm
        scores = (text_emb @ image_embeddings.T).squeeze(0)

    # topk raises if k exceeds the number of candidates, so clamp it.
    top_k = min(3, scores.numel())
    indices = torch.topk(scores, top_k).indices
    return [dataset[int(idx)]["image"] for idx in indices]
def search_image(query_img):
    """Return the dataset images most similar to a query image.

    The query is preprocessed with ``val_transform``, embedded with the image
    encoder, and scored against every row of the precomputed
    ``image_embeddings`` index by dot product.

    Args:
        query_img: A PIL image, an array accepted by ``Image.fromarray``, or
            ``None`` when nothing has been uploaded.

    Returns:
        A list with the ``'image'`` field of the highest-scoring dataset rows,
        best match first. At most three entries; fewer if the index holds
        fewer than three vectors. An empty list when ``query_img`` is ``None``.
    """
    if query_img is None:
        return []

    # Ensure it's a PIL Image
    if not isinstance(query_img, Image.Image):
        query_img = Image.fromarray(query_img)
    # val_transform normalises with three per-channel statistics, so RGBA,
    # grayscale or palette inputs must be brought to three channels first.
    query_img = query_img.convert("RGB")

    img_tensor = val_transform(query_img).unsqueeze(0)
    with torch.no_grad():
        img_emb = model.encode_image(img_tensor)
        # Drop only the query (batch) dimension. A bare squeeze() would also
        # collapse a one-image index into a 0-d tensor and break topk.
        # assumes image_embeddings is (num_images, dim) -- TODO confirm
        scores = (img_emb @ image_embeddings.T).squeeze(0)

    # topk raises if k exceeds the number of candidates, so clamp it.
    top_k = min(3, scores.numel())
    indices = torch.topk(scores, top_k).indices
    return [dataset[int(idx)]["image"] for idx in indices]
# Gradio UI: two tabs (text query / image query), each showing its top
# matches in a gallery and offering clickable examples.
with gr.Blocks(title="CLIP Style MultiModal Search", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔍CLIP Style MultiModal")
    gr.Markdown("Search for images using **Text** OR using another **Image**.")

    with gr.Tabs():
        # --- TAB 1: TEXT SEARCH ---
        with gr.TabItem("Search by Text"):
            with gr.Row():
                txt_input = gr.Textbox(
                    label="Type your query",
                    placeholder="e.g. A dog running...",
                )
                txt_btn = gr.Button("Search", variant="primary")
            txt_gallery = gr.Gallery(label="Top Matches", columns=3, height=300)

            # CLICKABLE TEXT EXAMPLES
            gr.Examples(
                examples=[
                    ["A dog running on grass"],
                    ["Children playing in the water"],
                    ["A girl in a pink dress"],
                    ["A man climbing a rock"],
                ],
                inputs=txt_input,      # Clicking populates this box
                outputs=txt_gallery,   # Result appears here
                fn=search_text,        # Function to run
                run_on_click=True,     # Run immediately when clicked!
                label="Try these examples:",
            )
            txt_btn.click(search_text, inputs=txt_input, outputs=txt_gallery)

        # --- TAB 2: IMAGE SEARCH ---
        with gr.TabItem("Search by Image"):
            # Define components first (but don't draw them yet)
            # We set render=False so we can place them visually later
            img_input = gr.Image(
                type="pil",
                label="Upload Source Image",
                sources=['upload', 'clipboard'],
                render=False,
            )
            img_gallery = gr.Gallery(
                label="Similar Images", columns=3, height=300, render=False
            )

            # Draw Examples FIRST (So they appear at the very top)
            gr.Examples(
                examples=[
                    ["examples/dog.jpg"],
                    ["examples/beach.jpg"],
                ],
                inputs=img_input,
                outputs=img_gallery,
                fn=search_image,
                run_on_click=True,
                label="Click an image to test:",
            )

            # Draw Input and Button (Visually below examples)
            with gr.Row():
                img_input.render()
                img_btn = gr.Button("Find Similar", variant="primary")

            # Draw Gallery (Visually at the bottom)
            img_gallery.render()

            # Connect the Button
            img_btn.click(search_image, inputs=img_input, outputs=img_gallery)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()