Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torchvision | |
| from torchvision.models.detection import fasterrcnn_resnet50_fpn | |
| from torchvision import transforms | |
| from transformers import AutoProcessor, LlavaForConditionalGeneration | |
| from PIL import Image | |
| import numpy as np | |
| import faiss | |
| import pickle | |
| import os | |
| from pathlib import Path | |
| import gc | |
| # Configuration | |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| print(f"Using device: {device}") | |
| # Set base directory for images (can be configured) | |
| IMAGE_BASE_DIR = os.getenv('IMAGE_BASE_DIR', './images') | |
| # Use LLaVA Phi-3-Mini (lightweight ~4GB model) | |
| USE_LLAVA = True | |
| LLAVA_MODEL = "xtuner/llava-phi-3-mini-hf" # Much lighter than llava-1.5-7b | |
| print(f"LLaVA Phi-3-Mini (lightweight) enabled: {USE_LLAVA}") | |
| # Load models with memory optimization | |
| def load_models(): | |
| print("Loading models with memory optimization...") | |
| print(f"Device: {device}") | |
| # 1. Load Faster R-CNN | |
| print("Loading Faster R-CNN...") | |
| rcnn_model = fasterrcnn_resnet50_fpn(pretrained=True) | |
| rcnn_model = rcnn_model.to(device) | |
| rcnn_model.eval() | |
| rcnn_backbone = rcnn_model.backbone | |
| print("β Faster R-CNN loaded") | |
| # 2. Load LLaVA Phi-3-Mini | |
| llava_model = None | |
| llava_processor = None | |
| if USE_LLAVA: | |
| print(f"Loading LLaVA Phi-3-Mini from {LLAVA_MODEL}...") | |
| # Load Processor | |
| from transformers import LlavaProcessor | |
| llava_processor = LlavaProcessor.from_pretrained(LLAVA_MODEL) | |
| # --- CRITICAL FIX START --- | |
| # Explicitly set patch_size to prevent: 'int' and 'NoneType' error | |
| if hasattr(llava_processor, 'image_processor'): | |
| llava_processor.image_processor.patch_size = 14 | |
| llava_processor.patch_size = 14 | |
| # --- CRITICAL FIX END --- | |
| # Load Model with memory-efficient settings | |
| llava_model = LlavaForConditionalGeneration.from_pretrained( | |
| LLAVA_MODEL, | |
| torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, | |
| low_cpu_mem_usage=True, | |
| device_map="auto" if torch.cuda.is_available() else None | |
| ) | |
| if not torch.cuda.is_available(): | |
| llava_model = llava_model.to(device) | |
| llava_model.eval() | |
| print("β LLaVA Phi-3-Mini loaded and patch_size configured") | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| return rcnn_backbone, llava_model, llava_processor | |
| # Load FAISS index | |
| def load_faiss_index(): | |
| print("Loading FAISS index...") | |
| index = faiss.read_index('faiss_index.bin') | |
| with open('image_paths.pkl', 'rb') as f: | |
| image_paths = pickle.load(f) | |
| print(f"β FAISS index loaded with {index.ntotal} images") | |
| return index, image_paths | |
| # Extract features from query image (memory optimized) | |
| def extract_features(image, rcnn_backbone, llava_model, llava_processor): | |
| # Prepare image | |
| if isinstance(image, np.ndarray): | |
| image = Image.fromarray(image).convert('RGB') | |
| elif isinstance(image, str): | |
| image = Image.open(image).convert('RGB') | |
| # RCNN features | |
| transform = transforms.Compose([ | |
| transforms.Resize((224, 224)), | |
| transforms.ToTensor(), | |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |
| ]) | |
| img_tensor = transform(image).unsqueeze(0).to(device) | |
| rcnn_features = rcnn_backbone(img_tensor) | |
| rcnn_feat_vector = torch.nn.functional.adaptive_avg_pool2d( | |
| rcnn_features['pool'], (1, 1) | |
| ).flatten().cpu().numpy() | |
| # LLaVA Phi-3-Mini features (FAST - direct vision encoder, no text generation) | |
| if USE_LLAVA and llava_model is not None: | |
| # CRITICAL: Ensure patch_size is set before processing | |
| if hasattr(llava_processor, 'image_processor'): | |
| llava_processor.image_processor.patch_size = 14 | |
| llava_processor.patch_size = 14 | |
| prompt = "USER: <image>\nASSISTANT:" | |
| inputs = llava_processor(text=prompt, images=image, return_tensors="pt") | |
| inputs = {k: v.to(device) for k, v in inputs.items()} | |
| # Extract visual features directly (10-20x faster than generate()) | |
| # Get vision tower | |
| if hasattr(llava_model, 'get_vision_tower'): | |
| vision_tower = llava_model.get_vision_tower() | |
| elif hasattr(llava_model, 'vision_tower'): | |
| vision_tower = llava_model.vision_tower | |
| else: | |
| vision_tower = None | |
| # Use vision tower directly if available | |
| if vision_tower is not None and 'pixel_values' in inputs: | |
| image_outputs = vision_tower(inputs['pixel_values']) | |
| # Handle different output types | |
| if hasattr(image_outputs, 'pooler_output'): | |
| llava_feat_vector = image_outputs.pooler_output.squeeze().cpu().numpy() | |
| elif hasattr(image_outputs, 'last_hidden_state'): | |
| llava_feat_vector = image_outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy() | |
| elif isinstance(image_outputs, tuple): | |
| llava_feat_vector = image_outputs[0].mean(dim=1).squeeze().cpu().numpy() | |
| else: | |
| if image_outputs.dim() > 2: | |
| llava_feat_vector = image_outputs.mean(dim=1).squeeze().cpu().numpy() | |
| else: | |
| llava_feat_vector = image_outputs.squeeze().cpu().numpy() | |
| else: | |
| # Fallback: use model forward pass (still much faster than generate) | |
| outputs = llava_model( | |
| input_ids=inputs['input_ids'], | |
| attention_mask=inputs.get('attention_mask'), | |
| pixel_values=inputs.get('pixel_values'), | |
| output_hidden_states=True | |
| ) | |
| llava_feat_vector = outputs.hidden_states[-1].mean(dim=1).squeeze().cpu().numpy() | |
| # Ensure proper shape | |
| if llava_feat_vector.ndim > 1: | |
| llava_feat_vector = llava_feat_vector.flatten() | |
| # Resize to 1024 dimensions | |
| if llava_feat_vector.shape[0] != 1024: | |
| if llava_feat_vector.shape[0] < 1024: | |
| llava_feat_vector = np.pad(llava_feat_vector, (0, 1024 - llava_feat_vector.shape[0])) | |
| else: | |
| llava_feat_vector = llava_feat_vector[:1024] | |
| else: | |
| # Use zeros when LLaVA is disabled (maintains compatibility) | |
| llava_feat_vector = np.zeros(1024, dtype=np.float32) | |
| # Combine features | |
| combined_features = np.concatenate([rcnn_feat_vector, llava_feat_vector]) | |
| combined_features = combined_features / np.linalg.norm(combined_features) | |
| # Clean up | |
| del img_tensor | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| return combined_features.reshape(1, -1).astype('float32') | |
| # Search similar images | |
| def search_similar_images(query_image, top_k=3): | |
| if query_image is None: | |
| return [] | |
| try: | |
| # Extract features | |
| query_features = extract_features(query_image, rcnn_backbone, llava_model, llava_processor) | |
| # Search in FAISS | |
| distances, indices = faiss_index.search(query_features, int(top_k)) | |
| # Get similar images | |
| results = [] | |
| for i, (dist, idx) in enumerate(zip(distances[0], indices[0])): | |
| # Get original path and try to find the image | |
| original_path = image_paths[idx] | |
| # Try multiple path strategies | |
| img = None | |
| img_found = False | |
| # Strategy 1: Check if original path exists | |
| if os.path.exists(original_path): | |
| img = Image.open(original_path).convert('RGB') | |
| img_found = True | |
| else: | |
| # Strategy 2: Try relative path from IMAGE_BASE_DIR | |
| filename = Path(original_path).name | |
| category = Path(original_path).parent.name | |
| relative_path = os.path.join(IMAGE_BASE_DIR, category, filename) | |
| if os.path.exists(relative_path): | |
| img = Image.open(relative_path).convert('RGB') | |
| img_found = True | |
| else: | |
| # Strategy 3: Search for filename in IMAGE_BASE_DIR | |
| for root, dirs, files in os.walk(IMAGE_BASE_DIR): | |
| if filename in files: | |
| found_path = os.path.join(root, filename) | |
| img = Image.open(found_path).convert('RGB') | |
| img_found = True | |
| break | |
| if img_found and img is not None: | |
| # Add image with metadata | |
| category_name = Path(original_path).parent.name.replace('_', ' ').title() | |
| label = f"#{i+1} - {category_name}\nSimilarity: {1/(1+dist):.3f}" | |
| results.append((img, label)) | |
| else: | |
| # Create placeholder image if not found | |
| placeholder = Image.new('RGB', (224, 224), color='lightgray') | |
| label = f"#{i+1} - Image not available\n{Path(original_path).name}" | |
| results.append((placeholder, label)) | |
| return results | |
| except Exception as e: | |
| print(f"Error in search: {str(e)}") | |
| return [] | |
| # Initialize models and index | |
| print("Initializing system...") | |
| rcnn_backbone, llava_model, llava_processor = load_models() | |
| faiss_index, image_paths = load_faiss_index() | |
| print("β System ready!") | |
| # Create Gradio interface | |
| with gr.Blocks(title="CBIR - Content-Based Image Retrieval", theme=gr.themes.Soft()) as demo: | |
| gr.Markdown("# π Content-Based Image Retrieval System") | |
| gr.Markdown(""" | |
| Upload any image to find visually similar images from the database. | |
| This system uses **Faster R-CNN** + **LLaVA 1.5** for intelligent image understanding. | |
| """) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| input_image = gr.Image(type="pil", label="π€ Upload Your Image") | |
| top_k_slider = gr.Slider( | |
| minimum=1, | |
| maximum=10, | |
| value=5, | |
| step=1, | |
| label="Number of Results", | |
| info="How many similar images to retrieve" | |
| ) | |
| search_btn = gr.Button("π Search Similar Images", variant="primary", size="lg") | |
| with gr.Accordion("βΉοΈ About", open=False): | |
| llava_status = "Enabled β" if USE_LLAVA else "Disabled" | |
| gr.Markdown(f""" | |
| **System Information:** | |
| - Device: `{device}` | |
| - Indexed Images: `{faiss_index.ntotal:,}` | |
| - Feature Dimensions: `2048` (RCNN: 1024 + LLaVA: 1024) | |
| - Vision Model: Faster R-CNN ResNet-50 | |
| - Language-Vision Model: LLaVA Phi-3-Mini (lightweight ~4GB) | |
| - LLaVA Status: `{llava_status}` | |
| - Similarity Metric: L2 Distance | |
| - Memory Usage: ~6-8GB (optimized) | |
| *Using LLaVA Phi-3-Mini for efficient vision-language understanding.* | |
| """) | |
| with gr.Column(scale=2): | |
| output_gallery = gr.Gallery( | |
| label="π― Similar Images", | |
| columns=3, | |
| rows=2, | |
| height="auto", | |
| object_fit="contain", | |
| show_label=True | |
| ) | |
| status_text = gr.Textbox( | |
| label="Status", | |
| placeholder="Upload an image and click search...", | |
| interactive=False | |
| ) | |
| # Example images section | |
| example_dir = os.path.join(IMAGE_BASE_DIR, "examples") | |
| if os.path.exists(example_dir): | |
| example_images = [os.path.join(example_dir, f) for f in os.listdir(example_dir) | |
| if f.lower().endswith(('.jpg', '.jpeg', '.png'))][:5] | |
| if example_images: | |
| gr.Markdown("### π‘ Try These Examples:") | |
| gr.Examples( | |
| examples=[[img] for img in example_images], | |
| inputs=input_image | |
| ) | |
| def search_with_status(image, k): | |
| if image is None: | |
| return [], "β οΈ Please upload an image first" | |
| results = search_similar_images(image, k) | |
| if results: | |
| return results, f"β Found {len(results)} similar images" | |
| else: | |
| return [], "β No results found or error occurred" | |
| search_btn.click( | |
| fn=search_with_status, | |
| inputs=[input_image, top_k_slider], | |
| outputs=[output_gallery, status_text] | |
| ) | |
| # Launch app | |
| if __name__ == "__main__": | |
| demo.launch( | |
| server_name="0.0.0.0", # Allow external connections | |
| server_port=7860, | |
| share=True # Create public Gradio link | |
| ) | |