# CBIR-System / app.py
# (Hugging Face Space by IT4CHI2311; last commit 09421ab "Made changes")
import gradio as gr
import torch
import torchvision
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision import transforms
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
import numpy as np
import faiss
import pickle
import os
from pathlib import Path
import gc
# Configuration
# Select GPU when available; all models and tensors are moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Set base directory for images (can be configured)
# Used as the fallback root when the absolute paths stored in the index
# no longer exist on this machine (see search_similar_images).
IMAGE_BASE_DIR = os.getenv('IMAGE_BASE_DIR', './images')
# Use LLaVA Phi-3-Mini (lightweight ~4GB model)
# When False, the LLaVA half of the combined feature vector is zero-filled
# so the indexed dimensionality stays compatible.
USE_LLAVA = True
LLAVA_MODEL = "xtuner/llava-phi-3-mini-hf" # Much lighter than llava-1.5-7b
print(f"LLaVA Phi-3-Mini (lightweight) enabled: {USE_LLAVA}")
# Load models with memory optimization
@torch.no_grad()
def load_models():
    """Load the retrieval models once at startup.

    Returns:
        tuple: (rcnn_backbone, llava_model, llava_processor). The LLaVA
        entries are ``None`` when ``USE_LLAVA`` is disabled.
    """
    print("Loading models with memory optimization...")
    print(f"Device: {device}")

    # 1. Load Faster R-CNN
    print("Loading Faster R-CNN...")
    # NOTE(review): `pretrained=True` is the legacy torchvision API
    # (deprecated in favor of `weights=`); kept as-is for compatibility
    # with the version this Space was built against — confirm before changing.
    rcnn_model = fasterrcnn_resnet50_fpn(pretrained=True).to(device)
    rcnn_model.eval()
    # Only the FPN backbone is needed for feature extraction.
    rcnn_backbone = rcnn_model.backbone
    print("βœ“ Faster R-CNN loaded")

    # 2. Load LLaVA Phi-3-Mini
    llava_model, llava_processor = None, None
    if USE_LLAVA:
        print(f"Loading LLaVA Phi-3-Mini from {LLAVA_MODEL}...")
        # Load Processor
        from transformers import LlavaProcessor
        llava_processor = LlavaProcessor.from_pretrained(LLAVA_MODEL)

        # --- CRITICAL FIX START ---
        # Explicitly set patch_size to prevent: 'int' and 'NoneType' error
        if hasattr(llava_processor, 'image_processor'):
            llava_processor.image_processor.patch_size = 14
        llava_processor.patch_size = 14
        # --- CRITICAL FIX END ---

        # Load Model with memory-efficient settings: fp16 + device_map on
        # GPU, fp32 on CPU (fp16 CPU inference is unsupported/slow).
        has_cuda = torch.cuda.is_available()
        llava_model = LlavaForConditionalGeneration.from_pretrained(
            LLAVA_MODEL,
            torch_dtype=torch.float16 if has_cuda else torch.float32,
            low_cpu_mem_usage=True,
            device_map="auto" if has_cuda else None,
        )
        if not has_cuda:
            llava_model = llava_model.to(device)
        llava_model.eval()
        print("βœ“ LLaVA Phi-3-Mini loaded and patch_size configured")

    # Release any transient load-time allocations.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return rcnn_backbone, llava_model, llava_processor
# Load FAISS index
def load_faiss_index():
print("Loading FAISS index...")
index = faiss.read_index('faiss_index.bin')
with open('image_paths.pkl', 'rb') as f:
image_paths = pickle.load(f)
print(f"βœ“ FAISS index loaded with {index.ntotal} images")
return index, image_paths
# Extract features from query image (memory optimized)
# ImageNet preprocessing for the R-CNN backbone — built once at import time
# instead of being reconstructed on every query (loop-invariant hoist).
_RCNN_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def _rcnn_feature_vector(image, rcnn_backbone):
    """Global-average-pooled backbone features for one RGB PIL image (1-D numpy)."""
    img_tensor = _RCNN_TRANSFORM(image).unsqueeze(0).to(device)
    feats = rcnn_backbone(img_tensor)
    # The FPN backbone returns a dict; 'pool' is its coarsest level.
    return torch.nn.functional.adaptive_avg_pool2d(
        feats['pool'], (1, 1)
    ).flatten().cpu().numpy()


def _fit_to_1024(vec):
    """Flatten, then zero-pad or truncate so the vector is exactly 1024-d."""
    if vec.ndim > 1:
        vec = vec.flatten()
    if vec.shape[0] < 1024:
        return np.pad(vec, (0, 1024 - vec.shape[0]))
    return vec[:1024]


def _llava_feature_vector(image, llava_model, llava_processor):
    """1024-d visual embedding from LLaVA's vision tower (no text generation)."""
    # CRITICAL: re-assert patch_size before processing; some processor
    # versions lose it and raise an "'int' and 'NoneType'" error.
    if hasattr(llava_processor, 'image_processor'):
        llava_processor.image_processor.patch_size = 14
    llava_processor.patch_size = 14
    prompt = "USER: <image>\nASSISTANT:"
    inputs = llava_processor(text=prompt, images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Prefer the vision tower directly (10-20x faster than generate()).
    if hasattr(llava_model, 'get_vision_tower'):
        vision_tower = llava_model.get_vision_tower()
    elif hasattr(llava_model, 'vision_tower'):
        vision_tower = llava_model.vision_tower
    else:
        vision_tower = None

    if vision_tower is not None and 'pixel_values' in inputs:
        out = vision_tower(inputs['pixel_values'])
        # The tower's output type varies across transformers versions.
        if hasattr(out, 'pooler_output'):
            vec = out.pooler_output.squeeze().cpu().numpy()
        elif hasattr(out, 'last_hidden_state'):
            vec = out.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
        elif isinstance(out, tuple):
            vec = out[0].mean(dim=1).squeeze().cpu().numpy()
        elif out.dim() > 2:
            vec = out.mean(dim=1).squeeze().cpu().numpy()
        else:
            vec = out.squeeze().cpu().numpy()
    else:
        # Fallback: full forward pass (still much faster than generate()).
        outputs = llava_model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs.get('attention_mask'),
            pixel_values=inputs.get('pixel_values'),
            output_hidden_states=True,
        )
        vec = outputs.hidden_states[-1].mean(dim=1).squeeze().cpu().numpy()
    return _fit_to_1024(vec)


# Extract features from query image (memory optimized)
@torch.no_grad()
def extract_features(image, rcnn_backbone, llava_model, llava_processor):
    """Build a single L2-normalized query feature vector for FAISS search.

    Args:
        image: PIL Image, numpy array, or a path string.
        rcnn_backbone: Faster R-CNN FPN backbone from load_models().
        llava_model / llava_processor: LLaVA model pair, or None when disabled.

    Returns:
        numpy.ndarray: shape (1, D) float32, unit L2 norm — the R-CNN pooled
        vector concatenated with a 1024-d LLaVA vector (zeros when disabled).
    """
    # Normalize the input to an RGB PIL image.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image).convert('RGB')
    elif isinstance(image, str):
        image = Image.open(image).convert('RGB')

    rcnn_feat_vector = _rcnn_feature_vector(image, rcnn_backbone)

    if USE_LLAVA and llava_model is not None:
        llava_feat_vector = _llava_feature_vector(image, llava_model, llava_processor)
    else:
        # Zero block keeps the combined dimensionality stable when LLaVA is off.
        llava_feat_vector = np.zeros(1024, dtype=np.float32)

    combined = np.concatenate([rcnn_feat_vector, llava_feat_vector])
    norm = np.linalg.norm(combined)
    if norm > 0:  # fix: avoid producing a NaN vector for degenerate all-zero features
        combined = combined / norm

    # Clean up transient GPU/CPU allocations between queries.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return combined.reshape(1, -1).astype('float32')
# Search similar images
def search_similar_images(query_image, top_k=3):
if query_image is None:
return []
try:
# Extract features
query_features = extract_features(query_image, rcnn_backbone, llava_model, llava_processor)
# Search in FAISS
distances, indices = faiss_index.search(query_features, int(top_k))
# Get similar images
results = []
for i, (dist, idx) in enumerate(zip(distances[0], indices[0])):
# Get original path and try to find the image
original_path = image_paths[idx]
# Try multiple path strategies
img = None
img_found = False
# Strategy 1: Check if original path exists
if os.path.exists(original_path):
img = Image.open(original_path).convert('RGB')
img_found = True
else:
# Strategy 2: Try relative path from IMAGE_BASE_DIR
filename = Path(original_path).name
category = Path(original_path).parent.name
relative_path = os.path.join(IMAGE_BASE_DIR, category, filename)
if os.path.exists(relative_path):
img = Image.open(relative_path).convert('RGB')
img_found = True
else:
# Strategy 3: Search for filename in IMAGE_BASE_DIR
for root, dirs, files in os.walk(IMAGE_BASE_DIR):
if filename in files:
found_path = os.path.join(root, filename)
img = Image.open(found_path).convert('RGB')
img_found = True
break
if img_found and img is not None:
# Add image with metadata
category_name = Path(original_path).parent.name.replace('_', ' ').title()
label = f"#{i+1} - {category_name}\nSimilarity: {1/(1+dist):.3f}"
results.append((img, label))
else:
# Create placeholder image if not found
placeholder = Image.new('RGB', (224, 224), color='lightgray')
label = f"#{i+1} - Image not available\n{Path(original_path).name}"
results.append((placeholder, label))
return results
except Exception as e:
print(f"Error in search: {str(e)}")
return []
# Initialize models and index
print("Initializing system...")
rcnn_backbone, llava_model, llava_processor = load_models()
faiss_index, image_paths = load_faiss_index()
print("βœ“ System ready!")
# Create Gradio interface
with gr.Blocks(title="CBIR - Content-Based Image Retrieval", theme=gr.themes.Soft()) as demo:
gr.Markdown("# πŸ” Content-Based Image Retrieval System")
gr.Markdown("""
Upload any image to find visually similar images from the database.
This system uses **Faster R-CNN** + **LLaVA 1.5** for intelligent image understanding.
""")
with gr.Row():
with gr.Column(scale=1):
input_image = gr.Image(type="pil", label="πŸ“€ Upload Your Image")
top_k_slider = gr.Slider(
minimum=1,
maximum=10,
value=5,
step=1,
label="Number of Results",
info="How many similar images to retrieve"
)
search_btn = gr.Button("πŸ” Search Similar Images", variant="primary", size="lg")
with gr.Accordion("ℹ️ About", open=False):
llava_status = "Enabled βœ“" if USE_LLAVA else "Disabled"
gr.Markdown(f"""
**System Information:**
- Device: `{device}`
- Indexed Images: `{faiss_index.ntotal:,}`
- Feature Dimensions: `2048` (RCNN: 1024 + LLaVA: 1024)
- Vision Model: Faster R-CNN ResNet-50
- Language-Vision Model: LLaVA Phi-3-Mini (lightweight ~4GB)
- LLaVA Status: `{llava_status}`
- Similarity Metric: L2 Distance
- Memory Usage: ~6-8GB (optimized)
*Using LLaVA Phi-3-Mini for efficient vision-language understanding.*
""")
with gr.Column(scale=2):
output_gallery = gr.Gallery(
label="🎯 Similar Images",
columns=3,
rows=2,
height="auto",
object_fit="contain",
show_label=True
)
status_text = gr.Textbox(
label="Status",
placeholder="Upload an image and click search...",
interactive=False
)
# Example images section
example_dir = os.path.join(IMAGE_BASE_DIR, "examples")
if os.path.exists(example_dir):
example_images = [os.path.join(example_dir, f) for f in os.listdir(example_dir)
if f.lower().endswith(('.jpg', '.jpeg', '.png'))][:5]
if example_images:
gr.Markdown("### πŸ’‘ Try These Examples:")
gr.Examples(
examples=[[img] for img in example_images],
inputs=input_image
)
def search_with_status(image, k):
if image is None:
return [], "⚠️ Please upload an image first"
results = search_similar_images(image, k)
if results:
return results, f"βœ… Found {len(results)} similar images"
else:
return [], "❌ No results found or error occurred"
search_btn.click(
fn=search_with_status,
inputs=[input_image, top_k_slider],
outputs=[output_gallery, status_text]
)
# Launch app
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0", # Allow external connections
server_port=7860,
share=True # Create public Gradio link
)