Mark8398 committed on
Commit
dfbfc84
·
verified ·
1 Parent(s): f634c61
Files changed (1) hide show
  1. app.py +149 -151
app.py CHANGED
@@ -1,152 +1,150 @@
1
- import gradio as gr
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- from transformers import ViTModel, AutoModel, AutoTokenizer
6
- from torchvision import transforms
7
- from datasets import load_dataset
8
- from PIL import Image
9
-
10
- # --- 1. MODEL ARCHITECTURE ---
11
- class MultiModalEngine(nn.Module):
12
- def __init__(self):
13
- super().__init__()
14
- self.image_model = ViTModel.from_pretrained("google/vit-base-patch16-224")
15
- self.text_model = AutoModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
16
- self.image_projection = nn.Linear(768, 256)
17
- self.text_projection = nn.Linear(768, 256)
18
- self.logit_scale = nn.Parameter(torch.ones([]) * 2.659)
19
-
20
- def encode_text(self, input_ids, attention_mask):
21
- text_out = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
22
- text_embeds = self.text_projection(self.mean_pooling(text_out, attention_mask))
23
- return F.normalize(text_embeds, dim=1)
24
-
25
- def encode_image(self, images):
26
- vision_out = self.image_model(pixel_values=images)
27
- image_embeds = self.image_projection(vision_out.last_hidden_state[:, 0, :])
28
- return F.normalize(image_embeds, dim=1)
29
-
30
- def mean_pooling(self, model_output, attention_mask):
31
- token_embeddings = model_output.last_hidden_state
32
- mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
33
- return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
34
-
35
- # --- 2. LOAD RESOURCES ---
36
- print("⏳ Loading resources...")
37
- device = "cpu"
38
-
39
- # Load Model
40
- model = MultiModalEngine()
41
- model.load_state_dict(torch.load("flickr8k_best_model_r1_27.pth", map_location=device))
42
- model.eval()
43
-
44
- # Load Index
45
- image_embeddings = torch.load("flickr8k_best_index.pt", map_location=device)
46
-
47
- # Load Tokenizer & Transforms
48
- tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2")
49
- val_transform = transforms.Compose([
50
- transforms.Resize((224, 224)),
51
- transforms.ToTensor(),
52
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
53
- ])
54
- # Load Dataset (Standard mode to fetch result images)
55
- print("⏳ Downloading dataset (this may take a minute)...")
56
- dataset = load_dataset("tsystems/flickr8k", split="train")
57
-
58
- print("✅ Server Ready!")
59
-
60
- # --- 3. SEARCH LOGIC ---
61
- def search_text(query):
62
- inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
63
- with torch.no_grad():
64
- text_emb = model.encode_text(inputs['input_ids'], inputs['attention_mask'])
65
-
66
- scores = text_emb @ image_embeddings.T
67
- scores = scores.squeeze()
68
- values, indices = torch.topk(scores, 3)
69
-
70
- return [dataset[int(idx)]['image'] for idx in indices]
71
-
72
- def search_image(query_img):
73
- if query_img is None: return []
74
- # Ensure it's a PIL Image (Gradio handles this, but good safety)
75
- if not isinstance(query_img, Image.Image):
76
- query_img = Image.fromarray(query_img)
77
-
78
- img_tensor = val_transform(query_img).unsqueeze(0)
79
- with torch.no_grad():
80
- img_emb = model.encode_image(img_tensor)
81
-
82
- scores = img_emb @ image_embeddings.T
83
- scores = scores.squeeze()
84
- values, indices = torch.topk(scores, 3)
85
-
86
- return [dataset[int(idx)]['image'] for idx in indices]
87
-
88
- # --- 4. UI WITH EXAMPLES ---
89
- with gr.Blocks(title="Flickr8k AI Search", theme=gr.themes.Soft()) as demo:
90
- gr.Markdown("# 🔍 AI Super-Search")
91
- gr.Markdown("Search for images using **Text** OR using another **Image**.")
92
-
93
- with gr.Tabs():
94
- # --- TAB 1: TEXT SEARCH ---
95
- with gr.TabItem("Search by Text"):
96
- with gr.Row():
97
- txt_input = gr.Textbox(label="Type your query", placeholder="e.g. A dog running...")
98
- txt_btn = gr.Button("Search", variant="primary")
99
-
100
- txt_gallery = gr.Gallery(label="Top Matches", columns=3, height=300)
101
-
102
- # CLICKABLE TEXT EXAMPLES
103
- gr.Examples(
104
- examples=[
105
- ["A dog running on grass"],
106
- ["Children playing in the water"],
107
- ["A girl in a pink dress"],
108
- ["A man climbing a rock"]
109
- ],
110
- inputs=txt_input, # Clicking populates this box
111
- outputs=txt_gallery, # Result appears here
112
- fn=search_text, # Function to run
113
- run_on_click=True, # Run immediately when clicked!
114
- label="Try these examples:"
115
- )
116
-
117
- txt_btn.click(search_text, inputs=txt_input, outputs=txt_gallery)
118
-
119
- # --- TAB 2: IMAGE SEARCH ---
120
- # --- TAB 2: IMAGE SEARCH ---
121
- with gr.TabItem("Search by Image"):
122
- # 1. Define components first (but don't draw them yet)
123
- # We set render=False so we can place them visually later
124
- img_input = gr.Image(type="pil", label="Upload Source Image", sources=['upload', 'clipboard'], render=False)
125
- img_gallery = gr.Gallery(label="Similar Images", columns=3, height=300, render=False)
126
-
127
- # 2. Draw Examples FIRST (So they appear at the very top)
128
- gr.Examples(
129
- examples=[
130
- ["examples/dog.jpg"],
131
- ["examples/beach.jpg"]
132
- ],
133
- inputs=img_input,
134
- outputs=img_gallery,
135
- fn=search_image,
136
- run_on_click=True,
137
- label="Click an image to test:"
138
- )
139
-
140
- # 3. Draw Input and Button (Visually below examples)
141
- with gr.Row():
142
- img_input.render() # <--- Now we actually draw the input box
143
- img_btn = gr.Button("Find Similar", variant="primary")
144
-
145
- # 4. Draw Gallery (Visually at the bottom)
146
- img_gallery.render()
147
-
148
- # 5. Connect the Button
149
- img_btn.click(search_image, inputs=img_input, outputs=img_gallery)
150
-
151
- if __name__ == "__main__":
152
  demo.launch()
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from transformers import ViTModel, AutoModel, AutoTokenizer
6
+ from torchvision import transforms
7
+ from datasets import load_dataset
8
+ from PIL import Image
9
+
10
class MultiModalEngine(nn.Module):
    """CLIP-style dual encoder: a ViT image tower and an MPNet text tower,
    each projected into a shared 256-d embedding space and L2-normalized so
    that dot products are cosine similarities."""

    def __init__(self):
        super().__init__()
        self.image_model = ViTModel.from_pretrained("google/vit-base-patch16-224")
        self.text_model = AutoModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
        self.image_projection = nn.Linear(768, 256)
        self.text_projection = nn.Linear(768, 256)
        # Learnable temperature; unused at inference here but kept so the
        # fine-tuned checkpoint's state_dict still loads strictly.
        self.logit_scale = nn.Parameter(torch.ones([]) * 2.659)

    def encode_text(self, input_ids, attention_mask):
        """Embed tokenized text; returns unit-norm (batch, 256) vectors."""
        backbone_out = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.mean_pooling(backbone_out, attention_mask)
        return F.normalize(self.text_projection(pooled), dim=1)

    def encode_image(self, images):
        """Embed pixel tensors; returns unit-norm (batch, 256) vectors."""
        backbone_out = self.image_model(pixel_values=images)
        cls_token = backbone_out.last_hidden_state[:, 0, :]  # ViT [CLS] token
        return F.normalize(self.image_projection(cls_token), dim=1)

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over real (non-padding) positions."""
        tokens = model_output.last_hidden_state
        weights = attention_mask.unsqueeze(-1).expand(tokens.size()).float()
        summed = torch.sum(tokens * weights, 1)
        # clamp avoids division by zero for an all-padding row
        counts = torch.clamp(weights.sum(1), min=1e-9)
        return summed / counts
33
+
34
+
35
# --- Load all heavyweight resources once, at import time (Gradio Spaces pattern) ---
print("⏳ Loading resources...")
device = "cpu"  # CPU-only inference; no GPU assumed on the hosting machine

# Load Model: build the architecture, then load fine-tuned weights from a
# local checkpoint that must sit next to app.py.
model = MultiModalEngine()
# NOTE(review): torch.load on a local trusted file; if the checkpoint source
# is ever untrusted, pass weights_only=True — confirm the torch version supports it.
model.load_state_dict(torch.load("flickr8k_best_model_r1_27.pth", map_location=device))
model.eval()  # inference mode: disables dropout etc.

# Load Index: precomputed, L2-normalized image embeddings for the whole gallery,
# searched by the query encoders via dot product.
image_embeddings = torch.load("flickr8k_best_index.pt", map_location=device)

# Load Tokenizer & Transforms (must match the text tower's pretrained model)
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2")
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ViT patch16-224 input size
    transforms.ToTensor(),
    # ImageNet normalization stats, matching the ViT backbone's pretraining
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Load Dataset (Standard mode to fetch result images)
print("Downloading dataset (this may take a minute)...")
# presumably the same split the index was built from — row order must match
# the rows of image_embeddings; verify against the indexing script.
dataset = load_dataset("tsystems/flickr8k", split="train")

print("Server Ready!")
58
+
59
+
60
def search_text(query, top_k=3):
    """Return the gallery images whose embeddings best match a text query.

    Args:
        query: Free-text search string.
        top_k: Number of results to return (default 3, capped at index size).

    Returns:
        List of PIL images from the dataset, best match first; empty list
        for an empty/blank query.
    """
    # Guard: a blank query would otherwise embed just the special tokens
    # and return arbitrary matches.
    if not query or not query.strip():
        return []
    inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        text_emb = model.encode_text(inputs['input_ids'], inputs['attention_mask'])

    # Both sides are L2-normalized, so the dot product is cosine similarity.
    scores = (text_emb @ image_embeddings.T).squeeze(0)
    # Cap k so a small index cannot make topk raise.
    k = min(top_k, scores.numel())
    _, indices = torch.topk(scores, k)

    return [dataset[int(idx)]['image'] for idx in indices]
70
+
71
def search_image(query_img, top_k=3):
    """Return the gallery images most visually similar to an uploaded image.

    Args:
        query_img: Query image (PIL Image, or numpy array from Gradio);
            None is treated as "no input".
        top_k: Number of results to return (default 3, capped at index size).

    Returns:
        List of PIL images from the dataset, best match first; empty list
        when no query image is provided.
    """
    if query_img is None:
        return []
    # Ensure it's a PIL Image (Gradio normally delivers one already)
    if not isinstance(query_img, Image.Image):
        query_img = Image.fromarray(query_img)
    # Normalize expects exactly 3 channels; grayscale or RGBA uploads would
    # otherwise crash inside val_transform.
    query_img = query_img.convert("RGB")

    img_tensor = val_transform(query_img).unsqueeze(0)
    with torch.no_grad():
        img_emb = model.encode_image(img_tensor)

    # Cosine similarity via dot product (embeddings are L2-normalized).
    scores = (img_emb @ image_embeddings.T).squeeze(0)
    k = min(top_k, scores.numel())
    _, indices = torch.topk(scores, k)

    return [dataset[int(idx)]['image'] for idx in indices]
86
+
87
+
88
# --- UI: two tabs, one for text-to-image search, one for image-to-image ---
# FIX: "Sytle" -> "Style" in the window title and page heading.
with gr.Blocks(title="CLIP Style MultiModal Search", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔍 CLIP Style MultiModal")
    gr.Markdown("Search for images using **Text** OR using another **Image**.")

    with gr.Tabs():
        # --- TAB 1: TEXT SEARCH ---
        with gr.TabItem("Search by Text"):
            with gr.Row():
                txt_input = gr.Textbox(label="Type your query", placeholder="e.g. A dog running...")
                txt_btn = gr.Button("Search", variant="primary")

            txt_gallery = gr.Gallery(label="Top Matches", columns=3, height=300)

            # Clickable text examples: clicking populates the box AND runs
            # the search immediately (run_on_click).
            gr.Examples(
                examples=[
                    ["A dog running on grass"],
                    ["Children playing in the water"],
                    ["A girl in a pink dress"],
                    ["A man climbing a rock"]
                ],
                inputs=txt_input,
                outputs=txt_gallery,
                fn=search_text,
                run_on_click=True,
                label="Try these examples:"
            )

            txt_btn.click(search_text, inputs=txt_input, outputs=txt_gallery)

        # --- TAB 2: IMAGE SEARCH ---
        with gr.TabItem("Search by Image"):
            # Define components with render=False so their on-screen order
            # can differ from their definition order: examples first, then
            # the input row, then the results gallery.
            img_input = gr.Image(type="pil", label="Upload Source Image", sources=['upload', 'clipboard'], render=False)
            img_gallery = gr.Gallery(label="Similar Images", columns=3, height=300, render=False)

            # NOTE(review): these paths must exist in an examples/ folder next
            # to app.py — confirm the files ship with the deployment.
            gr.Examples(
                examples=[
                    ["examples/dog.jpg"],
                    ["examples/beach.jpg"]
                ],
                inputs=img_input,
                outputs=img_gallery,
                fn=search_image,
                run_on_click=True,
                label="Click an image to test:"
            )

            # Draw input and button visually below the examples
            with gr.Row():
                img_input.render()
                img_btn = gr.Button("Find Similar", variant="primary")

            # Results gallery at the bottom
            img_gallery.render()

            img_btn.click(search_image, inputs=img_input, outputs=img_gallery)

if __name__ == "__main__":
    demo.launch()