File size: 6,996 Bytes
c34d2af
 
daa0070
c34d2af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
daa0070
 
 
 
c34d2af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
daa0070
c34d2af
 
daa0070
c34d2af
 
daa0070
c34d2af
 
 
 
 
daa0070
 
 
c34d2af
daa0070
 
 
 
 
 
 
 
 
c34d2af
 
daa0070
c34d2af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
daa0070
c34d2af
 
 
 
daa0070
c34d2af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
daa0070
 
c34d2af
 
 
 
 
 
daa0070
c34d2af
 
 
 
 
 
 
daa0070
c34d2af
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import os
import base64
import requests
import gradio as gr
from openai import OpenAI
from duckduckgo_search import DDGS
from PIL import Image
from io import BytesIO

# Base URL handed to the OpenAI client as ``base_url`` when ranking images.
NVIDIA_BASE_URL = "https://integrate.api.nvidia.com/v1"
# Model identifier sent with every chat-completion request.
# NOTE(review): confirm this id matches the provider's model catalog.
MODEL_NAME = "nvidia/minimaxai/minimax-m2.7"

def encode_image_from_url(url):
    """Fetch an image over HTTP and return it as a base64 string.

    The downloaded bytes are parsed with PIL and re-serialised in the
    image's own format (PNG when PIL reports no format) before encoding.

    Args:
        url: Address of the image to download.

    Returns:
        The base64 text of the re-serialised image, or ``None`` if the
        download, parsing, or serialisation fails for any reason.
    """
    try:
        reply = requests.get(url, timeout=10)
        reply.raise_for_status()
        picture = Image.open(BytesIO(reply.content))
        output_format = picture.format or "PNG"
        sink = BytesIO()
        picture.save(sink, format=output_format)
        encoded_bytes = base64.b64encode(sink.getvalue())
        return encoded_bytes.decode("utf-8")
    except Exception:
        # Best-effort by design: callers skip images that come back as None.
        return None

def encode_image_from_file(file_obj):
    """Encode an uploaded image to a base64 string.

    The image is parsed with PIL and re-serialised in its own format (PNG
    when PIL reports no format) before being base64-encoded.

    Args:
        file_obj: A filesystem path (``str`` or ``os.PathLike``), an object
            exposing the path through a ``name`` attribute, or a readable
            binary file object.

    Returns:
        The base64 text of the re-serialised image, or ``None`` if the
        input cannot be opened or saved as an image.
    """
    try:
        if isinstance(file_obj, os.PathLike):
            # pathlib.Path.name is only the final path component, so a Path
            # must be opened as-is rather than through its ``name``.
            source = file_obj
        else:
            source = getattr(file_obj, "name", None) or file_obj
        # The context manager releases the underlying file handle promptly
        # instead of leaving it to garbage collection.
        with Image.open(source) as img:
            buffered = BytesIO()
            img.save(buffered, format=img.format or "PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
    except Exception:
        # Best-effort by design: callers skip images that come back as None.
        return None

def _detect_image_mime(image_data):
    """Guess the MIME type of a base64-encoded image from its magic bytes."""
    try:
        # 16 base64 characters decode to exactly 12 bytes, which is enough
        # for every signature checked below.
        head = base64.b64decode(image_data[:16])
    except (TypeError, ValueError):
        return "image/png"
    if head.startswith(b"\xff\xd8\xff"):
        return "image/jpeg"
    if head.startswith((b"GIF87a", b"GIF89a")):
        return "image/gif"
    if head[:4] == b"RIFF" and head[8:12] == b"WEBP":
        return "image/webp"
    if head.startswith(b"BM"):
        return "image/bmp"
    # Unknown signature: keep the label this code has always used.
    return "image/png"


def _parse_relevance_score(score_text):
    """Turn the model's textual reply into a score clamped to [0.0, 1.0]."""
    import math  # local import keeps this block self-contained

    if not score_text:
        return 0.0
    # Tolerate decoration such as quotes, backticks, bold markers, or a
    # trailing full stop around the number ("0.8.", "`0.8`", "**0.8**").
    cleaned = score_text.strip().strip("`'\"*").rstrip(".")
    try:
        score = float(cleaned)
    except ValueError:
        return 0.0
    if math.isnan(score):
        # NaN survives min/max clamping and would corrupt later sorting.
        return 0.0
    return min(max(score, 0.0), 1.0)


def get_minimax_relevance(question, image_data, client):
    """Ask the vision model how relevant an image is to a question.

    Args:
        question: The user's question, embedded in the prompt.
        image_data: Base64-encoded image bytes (no ``data:`` prefix).
        client: Object exposing ``chat.completions.create(...)`` in the
            style of the OpenAI client.

    Returns:
        A float in ``[0.0, 1.0]``. Any failure (request error, empty or
        non-numeric reply) yields ``0.0`` so one bad image cannot abort a
        whole ranking run.
    """
    try:
        # The encoders keep each image's original format, so label the data
        # URL with the real type instead of always claiming PNG.
        mime_type = _detect_image_mime(image_data)
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{image_data}"}},
                        {"type": "text", "text": f"Question: {question}\nAnalyze this image for relevance. Respond with only a number between 0.0 and 1.0 representing how relevant this image is to the question. 1.0 = highly relevant, 0.0 = not relevant. Response must be ONLY the number, no text."}
                    ]
                }
            ],
            temperature=0.1,
            max_tokens=10
        )
        return _parse_relevance_score(response.choices[0].message.content)
    except Exception:
        # Deliberate best-effort: treat any failure as "not relevant".
        return 0.0

def get_duckduckgo_context(question, image_description=""):
    """Run a DuckDuckGo text search and return the result snippets as one string.

    Args:
        question: The user's question; forms the start of the query.
        image_description: Optional extra text appended to the query.

    Returns:
        The ``body`` fields of up to three results joined by single spaces,
        or ``""`` when nothing is found or the search fails.
    """
    try:
        search_terms = f"{question} {image_description}".strip()
        with DDGS() as searcher:
            hits = list(searcher.text(search_terms, max_results=3))
        if not hits:
            return ""
        return " ".join(hit["body"] for hit in hits)
    except Exception:
        # Search is optional enrichment; a failure degrades to no context.
        return ""

def calculate_combined_score(minimax_score, search_context, question):
    """Blend the vision-model score with a keyword bonus from search text.

    Args:
        minimax_score: Relevance score produced by the vision model.
        search_context: Text gathered from web search; may be empty.
        question: The user's question; its whitespace-separated words are
            matched case-insensitively as substrings of ``search_context``.

    Returns:
        ``minimax_score`` unchanged when there is no search context.
        Otherwise ``0.7 * minimax_score + 0.3 * bonus``, where ``bonus`` is
        1.0 if any question word occurs in the context and 0.5 if none does.
    """
    if not search_context:
        return minimax_score
    # Lowercase the context once rather than once per question word.
    haystack = search_context.lower()
    keyword_hit = any(word in haystack for word in question.lower().split())
    context_bonus = 1.0 if keyword_hit else 0.5
    return 0.7 * minimax_score + 0.3 * context_bonus

def rank_images(question, images, image_urls, search_context, api_key):
    """Rank uploaded and URL-sourced images by relevance to a question.

    Args:
        question: The user's question.
        images: Uploaded files (paths or file-like objects), or ``None``.
        image_urls: Newline-separated image URLs, or an empty value.
        search_context: Previously fetched search text. When non-empty, a
            DuckDuckGo lookup is folded into each image's score.
        api_key: NVIDIA API key. When blank, the ``NVIDIA_API_KEY``
            environment variable is used instead.

    Returns:
        A ``(gallery_items, error)`` tuple. On success ``gallery_items`` is
        a list of PIL images ordered from most to least relevant and
        ``error`` is ``None``. On failure ``gallery_items`` is ``[]`` and
        ``error`` is a human-readable message.
    """
    # The error message below promises a secrets fallback, so honour it.
    api_key = (api_key or "").strip() or os.environ.get("NVIDIA_API_KEY", "").strip()
    if not api_key:
        return [], "Please provide NVIDIA API key in secrets (NVIDIA_API_KEY)"

    if not images and not image_urls:
        return [], "Please upload images or provide image URLs"

    if not (question or "").strip():
        return [], "Please enter a question"

    client = OpenAI(api_key=api_key, base_url=NVIDIA_BASE_URL)

    # Collect every image that loads; unreadable ones are skipped.
    encoded_images = []
    for file_obj in images or []:
        encoded = encode_image_from_file(file_obj)
        if encoded:
            encoded_images.append(encoded)
    for line in (image_urls or "").splitlines():
        url = line.strip()
        if not url:
            continue
        encoded = encode_image_from_url(url)
        if encoded:
            encoded_images.append(encoded)

    if not encoded_images:
        return [], "No valid images could be loaded"

    scored_images = []
    for idx, image_data in enumerate(encoded_images):
        minimax_score = get_minimax_relevance(question, image_data, client)

        # NOTE(review): the query only varies by the image's position, not
        # its content, and ``search_context`` itself is used only as an
        # on/off switch. Confirm whether the pre-fetched text was meant to
        # be scored directly instead of searching again per image.
        search_result = ""
        if search_context:
            search_result = get_duckduckgo_context(question, f"image {idx+1}")

        final_score = calculate_combined_score(minimax_score, search_result, question)
        scored_images.append((final_score, image_data))

    # The sort is stable, so tied images keep their input order.
    scored_images.sort(key=lambda item: item[0], reverse=True)

    # Return in-memory PIL images for uploads and URLs alike. This avoids
    # passing ``data:`` URI strings to the gallery and avoids writing to
    # fixed /tmp paths that concurrent sessions would overwrite.
    result_gallery = []
    for _score, image_data in scored_images:
        img = Image.open(BytesIO(base64.b64decode(image_data)))
        img.load()
        if img.mode not in ("1", "L", "LA", "P", "RGB", "RGBA"):
            # e.g. CMYK JPEGs: convert so the gallery can re-encode them.
            img = img.convert("RGB")
        result_gallery.append(img)

    return result_gallery, None

# Custom stylesheet passed to gr.Blocks below; each selector targets the
# elem_id assigned to one of the components in the layout.
css = """
#title { text-align: center; font-size: 2em; font-weight: bold; margin-bottom: 1em; }
#question-input { margin-bottom: 1em; }
#image-section { margin-bottom: 1em; }
#button-row { margin-bottom: 1em; }
#error-box { color: red; margin-bottom: 1em; }
"""

# UI layout and event wiring: question and API key inputs, image sources,
# two action buttons, an error box, and the ranked-results gallery.
with gr.Blocks(css=css) as demo:
    gr.Markdown("## IMAGE RANKER", elem_id="title")

    with gr.Row():
        with gr.Column(scale=1):
            question = gr.Textbox(label="Question", placeholder="What are you looking for?", elem_id="question-input")
            api_key = gr.Textbox(label="NVIDIA API Key (or set in secrets)", type="password", visible=True)

    with gr.Column(elem_id="image-section"):
        images = gr.File(file_count="multiple", file_types=["image"], label="Upload Images (up to 5)")
        gr.Markdown("**OR**")
        image_urls = gr.Textbox(label="Image URLs (one per line)", placeholder="https://example.com/image1.png")

    with gr.Row(elem_id="button-row"):
        search_btn = gr.Button("Search Context (DuckDuckGo)", variant="secondary")
        rank_btn = gr.Button("Rank Images", variant="primary")

    # Hidden until there is something to report; see _error_update below.
    error_output = gr.Textbox(label="Error", visible=False, elem_id="error-box")
    gallery = gr.Gallery(label="Ranked Results", columns=3, object_fit="contain")

    # Holds the text fetched by the search button for the ranking step.
    search_context_state = gr.State("")

    def _error_update(message):
        """Show the error box with ``message``, or hide it when there is none."""
        # The box starts hidden, so writing a value alone would never reach
        # the user; visibility has to be toggled with the text.
        return gr.update(value=message or "", visible=bool(message))

    def search_context_handler(question):
        """Fetch DuckDuckGo results for the question and stash them in state.

        Returns an update for the error box and the new search-context text.
        """
        if not (question or "").strip():
            return _error_update("Please enter a question first"), ""
        try:
            with DDGS() as ddgs:
                results = list(ddgs.text(question, max_results=5))
            context = " | ".join([f"{r['title']}: {r['body'][:100]}" for r in results]) if results else ""
            return _error_update(""), context
        except Exception as e:
            return _error_update(f"Search error: {str(e)}"), ""

    def rank_handler(question, images, image_urls, search_context, api_key):
        """Run rank_images and surface its error message in the error box."""
        ranked, error = rank_images(question, images, image_urls, search_context, api_key)
        return ranked, _error_update(error)

    search_btn.click(
        fn=search_context_handler,
        inputs=[question],
        outputs=[error_output, search_context_state]
    )

    rank_btn.click(
        fn=rank_handler,
        inputs=[question, images, image_urls, search_context_state, api_key],
        outputs=[gallery, error_output]
    )

# Start the Gradio app defined above.
# NOTE(review): this runs at import time because there is no
# ``if __name__ == "__main__":`` guard — confirm that is intended for the
# deployment target before adding one.
demo.launch(debug=False, show_error=True)