Spaces:
Runtime error
Runtime error
import gradio as gr  # was "import app as gr" — gr.Blocks/gr.update/gr.File are gradio API
import torch
from PIL import Image

import chromadb
from scripts.qwen3_vl_embedding import Qwen3VLEmbedder
from scripts.qwen3_vl_reranker import Qwen3VLReranker

# Configuration
VDB_PATH = "./VDB"  # ChromaDB will be in the Space
TERM_COUNT = 10  # number of candidate AAT terms retrieved per image

# Load ChromaDB
print("Loading ChromaDB collection...")
chroma_client = chromadb.PersistentClient(path=VDB_PATH)
collection = chroma_client.get_collection(name="aat_terms")

# Load models
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on GPU; float32 on CPU (fp16 CPU inference is slow
# and unsupported by some ops).
model_dtype = torch.float16 if device == "cuda" else torch.float32

print(f"Loading Qwen embedding model on {device}...")
embedding_model = Qwen3VLEmbedder(
    model_name_or_path="Qwen/Qwen3-VL-Embedding-2B",
    dtype=model_dtype,
)

print(f"Loading Qwen reranking model on {device}...")
# Use the same dtype policy as the embedder; the original hard-coded
# float16 here regardless of device.
reranking_model = Qwen3VLReranker(
    model_name_or_path="Qwen/Qwen3-VL-Reranker-2B", dtype=model_dtype
)
print("Models loaded successfully!")
def process_multiple_images(images, state):
    """Run retrieval + reranking over every uploaded image.

    Args:
        images: list of uploaded file paths from ``gr.File`` with
            ``file_count="multiple"``; may be None or empty.
        state: session-state dict; it is reset here, so previous results
            are discarded on every new upload.

    Returns:
        ``(image, status_text, checkbox_update, export_text, state)``
        matching the outputs wired to ``process_btn``.
    """
    # Fresh session: rebind rather than mutate so stale entries vanish.
    state = {"all_results": {}, "current_index": 0}
    if not images:
        return None, "No images uploaded", gr.update(choices=[], value=[]), "", state

    print(f"Processing {len(images)} images...")
    for idx, image_path in enumerate(images):
        try:
            image = Image.open(image_path)

            # Embed the image (empty text => pure image query).
            art_query = {"image": image, "text": ""}
            image_features = embedding_model.process([art_query])
            image_features = torch.nn.functional.normalize(image_features, p=2, dim=1)

            # Retrieve candidate AAT terms from ChromaDB.
            results = collection.query(
                query_embeddings=image_features.cpu().float().tolist(),
                n_results=TERM_COUNT,  # was a hard-coded 10 duplicating the config
            )

            labels = []
            input_docs = []
            if results["documents"]:
                for doc, metadata in zip(
                    results["documents"][0], results["metadatas"][0]
                ):
                    input_docs.append({"text": doc})
                    labels.append(metadata["term_label"])

            # Only rerank when there are candidates; the original called
            # the reranker even with an empty document list.
            if input_docs:
                rerank_inputs = {
                    "instruction": "Retrieve Art & Architecture Thesaurus terms relevant to the given image.",
                    "query": art_query,
                    "documents": input_docs,
                    "fps": 1.0,
                }
                scores = reranking_model.process(rerank_inputs)
                sorted_results = sorted(zip(scores, labels), reverse=True)
            else:
                sorted_results = []

            # All keywords start selected; the user unchecks to discard.
            state["all_results"][idx] = {
                "image": image,
                "keywords": [label for _, label in sorted_results],
                "scores": [score for score, _ in sorted_results],
                "selected": [True] * len(sorted_results),
            }
            print(f"Processed image {idx + 1}/{len(images)}")
        except Exception as e:
            print(f"Error processing image {idx}: {e}")
            # Best-effort placeholder so navigation indices stay dense.
            # Opening the file may be exactly what failed above, so guard
            # the retry instead of letting it re-raise out of the handler.
            try:
                fallback = (
                    Image.open(image_path)
                    if isinstance(image_path, str)
                    else image_path
                )
            except Exception:
                fallback = None
            state["all_results"][idx] = {
                "image": fallback,
                "keywords": [],
                "scores": [],
                "selected": [],
            }

    # Show the first image immediately.
    img, status, checkbox_update = show_image(0, state)
    return img, status, checkbox_update, "", state
def show_image(index, state):
    """Render image *index*: return (PIL image, status line, CheckboxGroup update).

    Side effect: records *index* as the session's current image.
    """
    all_results = state["all_results"]
    if index not in all_results:
        return None, f"No image at index {index}", gr.update(choices=[], value=[])

    state["current_index"] = index
    entry = all_results[index]

    # "label (score%)" display strings shown next to each checkbox.
    choices = [
        f"{kw} ({score * 100:.1f}%)"
        for kw, score in zip(entry["keywords"], entry["scores"])
    ]
    # Pre-tick the boxes that are still marked as kept.
    checked = [c for c, keep in zip(choices, entry["selected"]) if keep]

    status = f"Image {index + 1} of {len(all_results)}"
    return entry["image"], status, gr.update(choices=choices, value=checked)
def update_selections(selected_keywords, state):
    """Record which checkbox entries are ticked for the current image.

    *selected_keywords* holds the display strings ("label (score%)") of
    the boxes currently checked; each keyword's flag is recomputed from
    membership in that list.
    """
    results = state["all_results"]
    idx = state["current_index"]
    if idx not in results:
        return state

    entry = results[idx]
    for i, (kw, score) in enumerate(zip(entry["keywords"], entry["scores"])):
        entry["selected"][i] = f"{kw} ({score * 100:.1f}%)" in selected_keywords
    return state
def next_image(state):
    """Advance to the following image; stays put when already at the end."""
    idx = state["current_index"]
    target = idx + 1 if idx + 1 < len(state["all_results"]) else idx
    img, status, cb = show_image(target, state)
    return img, status, cb, state
def previous_image(state):
    """Step back to the preceding image; stays put when already at index 0."""
    idx = state["current_index"]
    target = idx - 1 if idx > 0 else idx
    img, status, cb = show_image(target, state)
    return img, status, cb, state
def export_results(state):
    """Build the export text: one "Image N: kw1, kw2" line per image,
    joined by blank lines, listing only the keywords still selected."""
    all_results = state["all_results"]
    lines = []
    for idx in sorted(all_results):
        entry = all_results[idx]
        kept = [kw for kw, keep in zip(entry["keywords"], entry["selected"]) if keep]
        lines.append(f"Image {idx + 1}: {', '.join(kept)}")
    return "\n\n".join(lines)
# Create Gradio interface
# Layout: upload/process controls on top, then the image viewer (left)
# beside the keyword checkboxes (right), then the export row.
with gr.Blocks() as interface:
    gr.Markdown("# MCAM Art Keyword Generator")
    gr.Markdown(
        "Upload multiple images, review keywords for each, and export selected keywords"
    )
    # Session state: per-session dict holding every image's results and
    # the index of the image currently displayed.
    state = gr.State({"all_results": {}, "current_index": 0})
    with gr.Row():
        with gr.Column():
            upload_input = gr.File(
                file_count="multiple",
                file_types=["image"],
                label="Upload Images (multiple files)",
            )
            process_btn = gr.Button("Process All Images", variant="primary")
        with gr.Column():
            status_text = gr.Textbox(label="Status", interactive=False)
    with gr.Row():
        with gr.Column(scale=1):
            current_image = gr.Image(label="Current Image", type="pil")
            with gr.Row():
                prev_btn = gr.Button("← Previous")
                next_btn = gr.Button("Next →")
        with gr.Column(scale=1):
            gr.Markdown("### Select Keywords to Keep")
            keyword_checkboxes = gr.CheckboxGroup(
                choices=[],
                label="Keywords (check to keep, uncheck to remove)",
                interactive=True,
            )
    with gr.Row():
        export_btn = gr.Button("Export Selected Keywords", variant="primary")
        export_output = gr.Textbox(label="Final Keywords for All Images", lines=10)
    # Wire up the interface
    # Process: runs the full pipeline, then shows image 0 and clears the export box.
    process_btn.click(
        fn=process_multiple_images,
        inputs=[upload_input, state],
        outputs=[current_image, status_text, keyword_checkboxes, export_output, state],
    )
    # Checkbox changes persist the user's keep/discard choices into state.
    keyword_checkboxes.change(
        fn=update_selections,
        inputs=[keyword_checkboxes, state],
        outputs=[state],
    )
    # Prev/Next navigate the processed images, refreshing the checkbox list.
    next_btn.click(
        fn=next_image,
        inputs=[state],
        outputs=[current_image, status_text, keyword_checkboxes, state],
    )
    prev_btn.click(
        fn=previous_image,
        inputs=[state],
        outputs=[current_image, status_text, keyword_checkboxes, state],
    )
    # Export renders the final keyword list for every image.
    export_btn.click(
        fn=export_results,
        inputs=[state],
        outputs=[export_output],
    )
if __name__ == "__main__":
    interface.launch()