Spaces:
Sleeping
Sleeping
| import os | |
| import base64 | |
| import requests | |
| import gradio as gr | |
| from openai import OpenAI | |
| from duckduckgo_search import DDGS | |
| from PIL import Image | |
| from io import BytesIO | |
# Base URL of NVIDIA's OpenAI-compatible inference API; rank_images points
# its `OpenAI` client at this endpoint.
NVIDIA_BASE_URL: str = "https://integrate.api.nvidia.com/v1"
# Model identifier sent with every relevance-scoring request.
MODEL_NAME: str = "nvidia/minimaxai/minimax-m2.7"
def _image_to_png_base64(img):
    """Serialize a PIL image as PNG and return it base64-encoded (UTF-8 str).

    The model request labels this payload ``data:image/png``, so the bytes
    must really be PNG rather than whatever format the source file used.
    """
    # PNG cannot store some modes (e.g. CMYK JPEGs); convert those first.
    if img.mode not in ("1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"):
        img = img.convert("RGBA" if "A" in img.getbands() else "RGB")
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def encode_image_from_url(url):
    """Download the image at *url* and return it as base64-encoded PNG.

    Args:
        url: HTTP(S) address of the image.

    Returns:
        The base64 string, or ``None`` if the download fails, the response
        is not a decodable image, or it cannot be re-encoded.
    """
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        with Image.open(BytesIO(response.content)) as img:
            return _image_to_png_base64(img)
    except Exception:  # best effort: the caller skips images that return None
        return None


def encode_image_from_file(file_obj):
    """Return an uploaded image as base64-encoded PNG.

    Args:
        file_obj: Either an object with a non-empty ``.name`` path attribute
            (e.g. a temp-file wrapper) or anything ``PIL.Image.open``
            accepts directly (a path string or a binary file object).

    Returns:
        The base64 string, or ``None`` if the file cannot be read as an image.
    """
    try:
        source = getattr(file_obj, "name", None) or file_obj
        # `with` closes the file handle PIL opens when given a path.
        with Image.open(source) as img:
            return _image_to_png_base64(img)
    except Exception:  # best effort: the caller skips images that return None
        return None
def _parse_relevance_score(text):
    """Extract a relevance score in [0.0, 1.0] from a model reply.

    Any complete ``<think>...</think>`` section is ignored. The first number
    in what remains is clamped to [0.0, 1.0]. Returns 0.0 when *text* is
    empty, holds no number, or ends inside an unfinished ``<think>`` section.
    """
    import re

    if not text:
        return 0.0
    # Some reasoning models prepend their reasoning wrapped in
    # <think>...</think>; numbers in there are not the answer.
    answer = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
    if "<think>" in answer:
        # Reasoning was cut off (e.g. by max_tokens) before any answer.
        return 0.0
    # Tolerates surrounding text/punctuation ("Score: 0.8.") and, unlike
    # float(), never yields NaN, which would break sorting downstream.
    match = re.search(r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?", answer)
    if match is None:
        return 0.0
    return min(max(float(match.group()), 0.0), 1.0)


def get_minimax_relevance(question, image_data, client):
    """Ask the vision model how relevant an image is to *question*.

    Args:
        question: The user's question, embedded verbatim in the prompt.
        image_data: Base64-encoded image bytes, sent as a
            ``data:image/png;base64,...`` URL.
        client: OpenAI-compatible client exposing
            ``client.chat.completions.create``.

    Returns:
        A float in [0.0, 1.0]; 0.0 if the request fails or the reply holds
        no parsable score.
    """
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_data}"}},
                        {"type": "text", "text": f"Question: {question}\nAnalyze this image for relevance. Respond with only a number between 0.0 and 1.0 representing how relevant this image is to the question. 1.0 = highly relevant, 0.0 = not relevant. Response must be ONLY the number, no text."},
                    ],
                }
            ],
            temperature=0.1,
            # NOTE(review): 10 tokens leaves no room for a reasoning preamble;
            # if the model emits one, the reply is truncated and scores 0.0.
            max_tokens=10,
        )
        reply = response.choices[0].message.content
    except Exception:  # best effort: a failed request counts as "not relevant"
        return 0.0
    return _parse_relevance_score(reply)
def get_duckduckgo_context(question, image_description=""):
    """Return the bodies of the top 3 DuckDuckGo text hits, space-joined.

    The query is *question* followed by *image_description*, with outer
    whitespace trimmed. Any failure (network error, rate limit, a hit
    without a ``body`` key) yields an empty string instead of raising.
    """
    try:
        search_terms = "{} {}".format(question, image_description).strip()
        with DDGS() as engine:
            hits = list(engine.text(search_terms, max_results=3))
        snippets = [hit["body"] for hit in hits]
        return " ".join(snippets)
    except Exception:
        return ""
def calculate_combined_score(minimax_score, search_context, question):
    """Blend the model's relevance score with a keyword check on search text.

    Without *search_context* the model score is returned unchanged.
    Otherwise the result is ``0.7 * minimax_score + 0.3 * bonus``. Here
    ``bonus`` is 1.0 if any word of *question* also appears as a whole word
    in *search_context* (case-insensitive), and 0.5 otherwise.
    """
    import re

    if not search_context:
        return minimax_score
    # Compare whole words: a plain substring test lets "cat" match inside
    # "category", while punctuation ("cat?") would prevent any match.
    context_words = set(re.findall(r"\w+", search_context.lower()))
    question_words = re.findall(r"\w+", question.lower())
    bonus = 1.0 if any(word in context_words for word in question_words) else 0.5
    return 0.7 * minimax_score + 0.3 * bonus
def _load_encoded_images(images, image_urls):
    """Encode uploads, then URLs, in input order, skipping ones that fail."""
    encoded = [encode_image_from_file(file_obj) for file_obj in (images or [])]
    if image_urls:
        urls = [line.strip() for line in image_urls.splitlines()]
        encoded += [encode_image_from_url(url) for url in urls if url]
    return [data for data in encoded if data]


def _decode_image(image_data):
    """Turn base64 image bytes back into a PIL image for the result gallery."""
    img = Image.open(BytesIO(base64.b64decode(image_data)))
    img.load()  # decode now so any error surfaces here, not in the UI layer
    return img


def rank_images(question, images, image_urls, search_context, api_key):
    """Score each image's relevance to *question* and return them best-first.

    Args:
        question: Free-text question the images are ranked against.
        images: Iterable of uploaded files (anything encode_image_from_file
            accepts), or a falsy value.
        image_urls: Newline-separated image URLs, or a falsy value.
        search_context: Web-search text previously fetched for the question.
            When non-empty, it is blended into every score via
            calculate_combined_score.
        api_key: NVIDIA API key. When blank, the ``NVIDIA_API_KEY``
            environment variable (e.g. a Space secret) is used instead.

    Returns:
        ``(gallery_items, error)``. On success, ``gallery_items`` is a list
        of ``(PIL image, caption)`` pairs sorted by descending score and
        ``error`` is ``None``. On failure, the result is ``([], message)``.
    """
    # The UI and the error message both promise the secret as a fallback.
    api_key = (api_key or "").strip() or os.environ.get("NVIDIA_API_KEY", "")
    if not api_key:
        return [], "Please provide NVIDIA API key in secrets (NVIDIA_API_KEY)"
    if not images and not image_urls:
        return [], "Please upload images or provide image URLs"
    if not (question or "").strip():
        return [], "Please enter a question"

    encoded_images = _load_encoded_images(images, image_urls)
    if not encoded_images:
        return [], "No valid images could be loaded"

    client = OpenAI(api_key=api_key, base_url=NVIDIA_BASE_URL)
    scored = []
    for image_data in encoded_images:
        minimax_score = get_minimax_relevance(question, image_data, client)
        # The fetched context does not depend on the image, so it is reused
        # for every image instead of issuing a fresh web search per image.
        final_score = calculate_combined_score(minimax_score, search_context, question)
        scored.append((final_score, image_data))
    scored.sort(key=lambda item: item[0], reverse=True)

    # Returning in-memory images (not fixed /tmp paths) keeps concurrent
    # users from overwriting each other's results.
    gallery_items = [(_decode_image(data), f"score {score:.2f}") for score, data in scored]
    return gallery_items, None
# Page-level CSS handed to gr.Blocks. The selectors target the elem_id values
# given to the title, question box, image column, button row and error box.
css = "\n".join(
    [
        "",
        "#title { text-align: center; font-size: 2em; font-weight: bold; margin-bottom: 1em; }",
        "#question-input { margin-bottom: 1em; }",
        "#image-section { margin-bottom: 1em; }",
        "#button-row { margin-bottom: 1em; }",
        "#error-box { color: red; margin-bottom: 1em; }",
        "",
    ]
)
with gr.Blocks(css=css) as demo:
    gr.Markdown("## IMAGE RANKER", elem_id="title")
    with gr.Row():
        with gr.Column(scale=1):
            question = gr.Textbox(label="Question", placeholder="What are you looking for?", elem_id="question-input")
            api_key = gr.Textbox(label="NVIDIA API Key (or set in secrets)", type="password", visible=True)
        with gr.Column(elem_id="image-section"):
            images = gr.File(file_count="multiple", file_types=["image"], label="Upload Images (up to 5)")
            gr.Markdown("**OR**")
            image_urls = gr.Textbox(label="Image URLs (one per line)", placeholder="https://example.com/image1.png")
    with gr.Row(elem_id="button-row"):
        search_btn = gr.Button("Search Context (DuckDuckGo)", variant="secondary")
        rank_btn = gr.Button("Rank Images", variant="primary")
    # Starts hidden; the handlers below reveal it only when there is a message.
    error_output = gr.Textbox(label="Error", visible=False, elem_id="error-box")
    gallery = gr.Gallery(label="Ranked Results", columns=3, object_fit="contain")
    # Web-search text fetched by search_btn and passed to rank_images.
    search_context_state = gr.State("")

    def _error_update(message):
        """Show *message* in the error box, or hide the box when it is empty."""
        return gr.update(value=message or "", visible=bool(message))

    def search_context_handler(question):
        """Fetch DuckDuckGo snippets for *question* into the shared state.

        Returns an (error-box update, context) pair. The context joins up to
        five "title: first 100 chars of body" entries with " | ", and is ""
        on error.
        """
        if not (question or "").strip():
            return _error_update("Please enter a question first"), ""
        try:
            with DDGS() as ddgs:
                results = list(ddgs.text(question, max_results=5))
            context = " | ".join(f"{r.get('title', '')}: {r.get('body', '')[:100]}" for r in results)
        except Exception as e:
            return _error_update(f"Search error: {str(e)}"), ""
        return _error_update(""), context

    def rank_images_handler(question, images, image_urls, search_context, api_key):
        """Run rank_images and route its error message to the error box."""
        gallery_items, error = rank_images(question, images, image_urls, search_context, api_key)
        return gallery_items, _error_update(error)

    search_btn.click(
        fn=search_context_handler,
        inputs=[question],
        outputs=[error_output, search_context_state],
    )
    rank_btn.click(
        fn=rank_images_handler,
        inputs=[question, images, image_urls, search_context_state, api_key],
        outputs=[gallery, error_output],
    )
# Launch only when run as a script, so importing this module (e.g. from tests
# or tooling) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch(debug=False, show_error=True)