micole66 committed on
Commit
0161cc6
·
verified ·
1 Parent(s): 1c4f001

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +183 -0
app.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import base64
3
+ import gradio as gr
4
+ from openai import OpenAI
5
+ from duckduckgo_search import DDGS
6
+ from PIL import Image
7
+ from io import BytesIO
8
+
9
+ NVIDIA_BASE_URL = "https://integrate.api.nvidia.com/v1"
10
+ MODEL_NAME = "nvidia/minimaxai/minimax-m2.7"
11
+
12
def encode_image_from_url(url):
    """Download an image over HTTP(S) and return it base64-encoded.

    Args:
        url: Address of the image. Only ``http`` and ``https`` URLs are fetched.

    Returns:
        The image bytes (re-saved in the image's own format, or PNG when the
        format is unknown) as a base64 ``str``, or ``None`` if the URL is not
        a web URL, the download fails, or the payload is not a readable image.
    """
    # Imported locally: the module's import block provides no HTTP client
    # (the previous body referenced an un-imported ``requests`` and therefore
    # always fell into the except branch).
    from urllib.parse import urlparse
    from urllib.request import Request, urlopen

    try:
        # The URL comes straight from user input; refuse file://, ftp://, etc.
        if urlparse(url).scheme.lower() not in ("http", "https"):
            return None

        # Some hosts reject the default urllib user agent.
        request = Request(url, headers={"User-Agent": "Mozilla/5.0 (image-ranker)"})
        with urlopen(request, timeout=10) as response:
            payload = response.read()

        with Image.open(BytesIO(payload)) as img:
            buffered = BytesIO()
            img.save(buffered, format=img.format or "PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
    except Exception:
        # Best effort: the caller skips any image that comes back as None.
        return None
23
+
24
def encode_image_from_file(file_obj):
    """Encode an uploaded image file to base64.

    Args:
        file_obj: Either a filesystem path, or an object exposing the path as
            its ``name`` attribute (e.g. a temporary-file wrapper).

    Returns:
        The image bytes (re-saved in the image's own format, or PNG when the
        format is unknown) as a base64 ``str``, or ``None`` if no path is
        available or the file cannot be read as an image.
    """
    # Accept both a plain path string and a file-like wrapper with ``.name``.
    path = getattr(file_obj, "name", file_obj)
    if not path:
        return None

    try:
        # The context manager releases the underlying file handle promptly.
        with Image.open(path) as img:
            buffered = BytesIO()
            img.save(buffered, format=img.format or "PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
    except Exception:
        # Best effort: the caller skips any image that comes back as None.
        return None
33
+
34
def _sniff_image_mime(image_data):
    """Guess the MIME type of a base64-encoded image from its leading characters."""
    # Base64 renderings of the PNG, JPEG, GIF and RIFF (WebP) magic numbers.
    signatures = (
        ("iVBORw0KGgo", "image/png"),
        ("/9j/", "image/jpeg"),
        ("R0lGOD", "image/gif"),
        ("UklGR", "image/webp"),
    )
    for prefix, mime in signatures:
        if image_data.startswith(prefix):
            return mime
    # Unknown signature: keep the label the data URL has always used.
    return "image/png"


def _parse_relevance_score(text):
    """Extract the first number from a model reply and clamp it to [0.0, 1.0]."""
    import re

    # Tolerates replies such as "0.8.", "Score: 0.7" or surrounding whitespace;
    # replies without any digits (including "nan") yield 0.0.
    match = re.search(r"[-+]?(?:\d+\.?\d*|\.\d+)", text or "")
    if match is None:
        return 0.0
    return min(max(float(match.group()), 0.0), 1.0)


def get_minimax_relevance(question, image_data, client):
    """Ask the vision model how relevant an image is to a question.

    Args:
        question: The user's question, embedded verbatim in the prompt.
        image_data: Base64-encoded image bytes (no ``data:`` prefix).
        client: Object exposing ``chat.completions.create`` in the OpenAI
            client style.

    Returns:
        A float in ``[0.0, 1.0]``. Any failure (request error, empty or
        non-numeric reply) yields ``0.0``.
    """
    try:
        mime = _sniff_image_mime(image_data)
        prompt = (
            f"Question: {question}\n"
            "Analyze this image for relevance. "
            "Respond with only a number between 0.0 and 1.0 representing how relevant this image is to the question. "
            "1.0 = highly relevant, 0.0 = not relevant. "
            "Response must be ONLY the number, no text."
        )
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": f"data:{mime};base64,{image_data}"}},
                        {"type": "text", "text": prompt},
                    ],
                }
            ],
            temperature=0.1,
            max_tokens=10,
        )
        return _parse_relevance_score(response.choices[0].message.content)
    except Exception:
        # Best effort: an unscorable image simply ranks last.
        return 0.0
56
+
57
def get_duckduckgo_context(question, image_description=""):
    """Return the bodies of the top DuckDuckGo text hits, joined by spaces.

    The query is the question followed by the optional image description.
    An empty string is returned when there are no hits or the search fails.
    """
    search_query = f"{question} {image_description}".strip()
    try:
        with DDGS() as ddgs:
            hits = list(ddgs.text(search_query, max_results=3))
    except Exception:
        return ""

    if not hits:
        return ""
    try:
        bodies = [hit["body"] for hit in hits]
        return " ".join(bodies)
    except Exception:
        return ""
66
+
67
def calculate_combined_score(minimax_score, search_context, question):
    """Blend the model score with a bonus derived from the search context.

    With no search context the model score is returned unchanged. Otherwise
    the result is 70% model score plus 30% of a context bonus: 1.0 when any
    whitespace-separated word of the question occurs (case-insensitively, as
    a substring) in the context, 0.5 when none does.
    """
    if not search_context:
        return minimax_score

    terms = question.lower().split()
    haystack = search_context.lower()

    context_bonus = 0.5
    for term in terms:
        if term in haystack:
            context_bonus = 1.0
            break

    return 0.7 * minimax_score + 0.3 * context_bonus
72
+
73
def rank_images(question, images, image_urls, search_context, api_key):
    """Rank uploaded and linked images by relevance to a question.

    Args:
        question: The user's question.
        images: Uploaded files (a list, a single item, or ``None``).
        image_urls: Newline-separated image URLs (may be empty or ``None``).
        search_context: When truthy, each image's model score is blended with
            a DuckDuckGo lookup via ``calculate_combined_score``.
        api_key: NVIDIA API key. When blank, the ``NVIDIA_API_KEY``
            environment variable is used instead.

    Returns:
        ``(gallery_items, None)`` on success, where ``gallery_items`` is a
        list of ``(PIL image, caption)`` pairs ordered best-first, or
        ``(None, message)`` when input validation fails or no image loads.
    """
    # Honour the secret the error message below refers to.
    api_key = (api_key or "").strip() or os.environ.get("NVIDIA_API_KEY", "")
    if not api_key:
        return None, "Please provide NVIDIA API key in secrets (NVIDIA_API_KEY)"

    # Normalise inputs: the upload component yields None when nothing is
    # selected, and the URL box may hold only whitespace.
    if images is None:
        uploads = []
    elif isinstance(images, (list, tuple)):
        uploads = list(images)
    else:
        uploads = [images]
    urls = [line.strip() for line in (image_urls or "").splitlines() if line.strip()]

    if not uploads and not urls:
        return None, "Please upload images or provide image URLs"

    if not question or not question.strip():
        return None, "Please enter a question"

    client = OpenAI(api_key=api_key, base_url=NVIDIA_BASE_URL)

    # Uploads first, then URLs; anything that fails to load is skipped.
    encoded_images = []
    for file_obj in uploads:
        encoded = encode_image_from_file(file_obj)
        if encoded:
            encoded_images.append(encoded)
    for url in urls:
        encoded = encode_image_from_url(url)
        if encoded:
            encoded_images.append(encoded)

    if not encoded_images:
        return None, "No valid images could be loaded"

    scored = []
    for position, image_data in enumerate(encoded_images, start=1):
        minimax_score = get_minimax_relevance(question, image_data, client)

        search_result = ""
        if search_context:
            search_result = get_duckduckgo_context(question, f"image {position}")

        final_score = calculate_combined_score(minimax_score, search_result, question)
        scored.append((final_score, image_data))

    # Stable sort: equally scored images keep their input order.
    scored.sort(key=lambda item: item[0], reverse=True)

    # Hand the gallery in-memory images rather than data URIs or fixed /tmp
    # paths, which would collide between concurrent requests.
    gallery_items = []
    for rank, (score, image_data) in enumerate(scored, start=1):
        img = Image.open(BytesIO(base64.b64decode(image_data)))
        img.load()  # force decoding now, while the byte buffer is alive
        gallery_items.append((img, f"#{rank} (score {score:.2f})"))

    return gallery_items, None
129
+
130
# Styling hooks for the elem_id values used in the layout below.
css = """
#title { text-align: center; font-size: 2em; font-weight: bold; margin-bottom: 1em; }
#question-input { margin-bottom: 1em; }
#image-section { margin-bottom: 1em; }
#button-row { margin-bottom: 1em; }
#error-box { color: red; margin-bottom: 1em; }
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("## IMAGE RANKER", elem_id="title")

    # Holds the latest DuckDuckGo summary between button clicks. It must exist
    # before the click handlers that read and write it are registered.
    search_context_state = gr.State("")

    with gr.Row():
        with gr.Column(scale=1):
            question = gr.Textbox(label="Question", placeholder="What are you looking for?", elem_id="question-input")
            api_key = gr.Textbox(label="NVIDIA API Key (or set in secrets)", type="password", visible=True)

        with gr.Column(elem_id="image-section"):
            images = gr.File(file_count="multiple", file_types=["image"], label="Upload Images (up to 5)")
            gr.Markdown("**OR**")
            image_urls = gr.Textbox(label="Image URLs (one per line)", placeholder="https://example.com/image1.png")

    with gr.Row(elem_id="button-row"):
        search_btn = gr.Button("Search Context (DuckDuckGo)", variant="secondary")
        rank_btn = gr.Button("Rank Images", variant="primary")

    # Visible and read-only so validation and search errors actually reach the user.
    error_output = gr.Textbox(label="Error", visible=True, interactive=False, elem_id="error-box")
    gallery = gr.Gallery(label="Ranked Results", columns=3, object_fit="contain")

    def search_context_handler(question):
        """Run a DuckDuckGo text search and return ``(error_message, context)``."""
        if not question or not question.strip():
            return "Please enter a question first", ""
        try:
            with DDGS() as ddgs:
                results = list(ddgs.text(question, max_results=5))
            context = " | ".join([f"{r['title']}: {r['body'][:100]}" for r in results]) if results else "No results"
            return "", context
        except Exception as e:
            return f"Search error: {str(e)}", ""

    # Store the search summary in the shared state so rank_images receives it.
    search_btn.click(
        fn=search_context_handler,
        inputs=[question],
        outputs=[error_output, search_context_state],
    )

    rank_btn.click(
        fn=rank_images,
        inputs=[question, images, image_urls, search_context_state, api_key],
        outputs=[gallery, error_output],
    )

demo.launch(debug=False, show_error=True)