# event_retrieval/zero_shot_segmentation.py
# Uploaded by sanskar753 ("Upload folder using huggingface_hub", commit 02d3a85, verified).
import os
import sys
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
import matplotlib.pyplot as plt
# --- 1. SETUP: PATHS, PATCHES, AND MODEL DEFINITIONS ---
# This section ensures the script can load your BLIP-2 model correctly.
# --- PATH SETUP & MONKEY PATCHES (Copied from previous scripts) ---
# Locate the directory this script lives in; fall back to the current working
# directory when __file__ is not defined (e.g. interactive sessions).
try:
    script_dir = os.path.dirname(__file__)
except NameError:
    script_dir = os.getcwd()

# A LAVIS checkout is expected next to the project root, i.e. <root>/LAVIS/lavis.
path_to_project_root = os.path.abspath(os.path.join(script_dir, ".."))
path_to_lavis_parent_dir = os.path.join(path_to_project_root, "LAVIS")

_lavis_checkout_found = os.path.isdir(path_to_lavis_parent_dir) and os.path.isdir(
    os.path.join(path_to_lavis_parent_dir, "lavis")
)
if not _lavis_checkout_found:
    # No local checkout: use the fixed location from the original environment.
    path_to_lavis_parent_dir = "/media/RTCIN7TBDriveB/Interns/RDT2/gte3kor/LAVIS"

# Put the chosen directory first so `import lavis` resolves to it.
sys.path.insert(0, path_to_lavis_parent_dir)
from lavis.models import load_model_and_preprocess
from lavis.models.blip2_models.blip2_qformer import Blip2Qformer
import inspect
# Keep a handle on the stock implementation so the patch can delegate to it.
_original_torch_load_state_dict = nn.Module.load_state_dict

# Some PyTorch versions have no ``assign`` keyword on load_state_dict. Probe the
# signature once at import time instead of on every call.
_LOAD_STATE_DICT_ACCEPTS_ASSIGN = (
    "assign" in inspect.signature(_original_torch_load_state_dict).parameters
)


def patched_load_state_dict(self, state_dict, strict=True, assign=False):
    """Replacement for ``nn.Module.load_state_dict`` that tolerates LAVIS checkpoints.

    For ``Blip2Qformer`` modules, the two prediction-head tensors are truncated
    along dim 0 when the checkpoint holds more rows than the model, so the load
    does not fail on that size mismatch. ``state_dict`` is modified in place for
    those keys. If the module has any parameter on the meta device, ``assign``
    is forced to ``True``.

    Args:
        self: The module being loaded.
        state_dict: Mapping of parameter names to tensors.
        strict: Forwarded unchanged to the original implementation.
        assign: Forwarded to the original implementation when it supports the
            keyword; silently dropped otherwise.

    Returns:
        Whatever the original ``load_state_dict`` returns.
    """
    if isinstance(self, Blip2Qformer):
        model_state_dict = self.state_dict()
        for key in ("Qformer.cls.predictions.bias", "Qformer.cls.predictions.decoder.weight"):
            if key not in state_dict or key not in model_state_dict:
                continue
            target_rows = model_state_dict[key].shape[0]
            # Only slice when the checkpoint is larger. A smaller checkpoint is
            # left untouched so the original loader reports the size mismatch,
            # rather than ``narrow`` raising an out-of-range error here.
            if state_dict[key].shape[0] > target_rows:
                state_dict[key] = state_dict[key].narrow(0, 0, target_rows)

    # NOTE(review): applied to every module type, not only Blip2Qformer. The
    # original nesting was not recoverable from the source; confirm.
    if any(p.is_meta for p in self.parameters()):
        assign = True

    if _LOAD_STATE_DICT_ACCEPTS_ASSIGN:
        return _original_torch_load_state_dict(self, state_dict, strict=strict, assign=assign)
    return _original_torch_load_state_dict(self, state_dict, strict=strict)


# Install the patch globally; every nn.Module now loads through the wrapper.
nn.Module.load_state_dict = patched_load_state_dict
print("INFO: Patches for LAVIS are active.")
# --- Your Adapter Model Definition ---
class MyModel(nn.Module):
    """Adapter module that passes its input through one ``TrainableEltwiseLayer``."""

    def __init__(self):
        super().__init__()
        # Submodule attribute names become prefixes of the state-dict keys, so
        # this name must stay as-is for saved adapter weights to load.
        self.scaling_layer = TrainableEltwiseLayer()

    def forward(self, x):
        """Return ``x`` after the scaling layer has been applied."""
        scaled = self.scaling_layer(x)
        return scaled
class TrainableEltwiseLayer(nn.Module):
    """Learnable element-wise scaling of a feature vector.

    Holds one weight per feature in a ``(1, num_features)`` parameter that is
    multiplied into the input with broadcasting. The weights start at one, so a
    freshly constructed layer returns its input unchanged.

    Args:
        num_features: Size of the last input dimension. Defaults to 256, the
            width this script has always used.
    """

    def __init__(self, num_features=256):
        super().__init__()
        self.weights = nn.Parameter(torch.ones(1, num_features))

    def forward(self, x):
        """Return ``x`` multiplied element-wise by the learned weights."""
        return x * self.weights
# --- Import and Setup SAM ---
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
# --- 2. GLOBAL SETUP: LOAD MODELS, DEFINE PROMPTS & COLORS ---
# This happens once when the app starts.
# GPU index this script was written for. Hosts that expose fewer GPUs fall back
# to GPU 0 instead of failing with an invalid device ordinal; hosts without
# CUDA use the CPU.
_PREFERRED_CUDA_INDEX = 9
if not torch.cuda.is_available():
    DEVICE = torch.device("cpu")
elif torch.cuda.device_count() > _PREFERRED_CUDA_INDEX:
    DEVICE = torch.device(f"cuda:{_PREFERRED_CUDA_INDEX}")
else:
    DEVICE = torch.device("cuda:0")
SAM_CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"
ADAPTER_PATH = "global_adapter_model.pth"
print(f"Using device: {DEVICE}")

# Load BLIP-2 Model
print("Loading base BLIP-2 model (gen3_322_840)...")
BASE_MODEL, VIS_PROCESSORS, TEXT_PROCESSORS = load_model_and_preprocess(
    name="blip2", model_type="gen3_322_840", is_eval=True, device=DEVICE
)

# Load the adapter; a missing checkpoint is tolerated and leaves the adapter at
# its freshly initialised weights.
print("Loading fine-tuned adapter model...")
ADAPTER_MODEL = MyModel().to(DEVICE)
if os.path.exists(ADAPTER_PATH):
    # NOTE(security): torch.load unpickles the file; only point ADAPTER_PATH at
    # checkpoints from a trusted source.
    ADAPTER_MODEL.load_state_dict(torch.load(ADAPTER_PATH, map_location=DEVICE))
    print(f"Successfully loaded fine-tuned adapter from '{ADAPTER_PATH}'.")
else:
    print("WARNING: Adapter model not found. Using an untrained adapter.")
ADAPTER_MODEL.eval()

# Load SAM Model
print(f"Loading Segment Anything Model from {SAM_CHECKPOINT_PATH}...")
sam = sam_model_registry["vit_h"](checkpoint=SAM_CHECKPOINT_PATH).to(DEVICE)
SAM_GENERATOR = SamAutomaticMaskGenerator(sam)
print("All models loaded successfully.")
# Define the full list of prompts and colors
# Display label -> text prompt that is matched against each image crop.
CLASSIFICATION_PROMPTS = {
    "Red Light": "a photo of a red traffic light",
    "Yellow Light": "a photo of a yellow traffic light",
    "Green Light": "a photo of a green traffic light",
    "Black/Off Light": "a photo of a traffic light that is off or unlit",
    "Multibulb Light": "a photo of a traffic light with multiple bulbs",
    "Countdown Timer": "a photo of a traffic light with a digital countdown timer",
    "Left and U-turn Sign": "a photo of a traffic sign with a left arrow and a U-turn arrow",
    "Left and Straight Sign": "a photo of a traffic sign with a left arrow and a straight arrow",
    "Multi-shape Sign": "a photo of a traffic sign with multiple shapes or complex symbols",
    "Pedestrian": "a photo of a pedestrian or a person walking in the street",
    "Bicycle": "a photo of a bicycle or a person on a bike",
    "Unknown": "an unrecognizable or unknown object"  # A generic 'Unknown'
}

# `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; `pyplot.get_cmap`
# takes the same (name, lut) arguments and returns the same resampled colormap.
colors = plt.get_cmap('tab20', len(CLASSIFICATION_PROMPTS))

# Display label -> (R, G, B) tuple of ints; the colormap's alpha channel is dropped.
COLOR_PALETTE = {
    label: tuple(int(c * 255) for c in colors(i)[:3])
    for i, label in enumerate(CLASSIFICATION_PROMPTS)
}
# --- 3. HELPER AND CORE GRADIO FUNCTIONS ---
def create_legend_image(palette):
    """Render a label/colour legend as a white PIL image, one row per entry.

    Args:
        palette: Mapping of label text to a fill colour accepted by PIL.

    Returns:
        A 300-pixel-wide RGB image whose height grows with the number of entries.
    """
    row_height = 30
    label_dy = 5
    pad = 10
    canvas_size = (300, len(palette) * row_height + 2 * pad)

    legend = Image.new('RGB', canvas_size, 'white')
    painter = ImageDraw.Draw(legend)

    # Use Arial when it can be loaded, otherwise PIL's built-in font.
    try:
        font = ImageFont.truetype("arial.ttf", 14)
    except IOError:
        font = ImageFont.load_default()

    row_top = pad
    for label, color in palette.items():
        # Colour swatch on the left, label text to its right.
        swatch = [pad, row_top, pad + 40, row_top + 20]
        painter.rectangle(swatch, fill=color)
        painter.text((pad + 50, row_top + label_dy), label, fill='black', font=font)
        row_top += row_height
    return legend
# Raw BLIP-2 text features keyed by the exact tuple of prompt strings, so the
# prompt set is encoded once rather than once per crop.
_PROMPT_FEATURE_CACHE = {}


def _prompt_text_features():
    """Return BLIP-2 text features for the current prompts, encoding them on first use."""
    prompts = tuple(CLASSIFICATION_PROMPTS.values())
    cached = _PROMPT_FEATURE_CACHE.get(prompts)
    if cached is None:
        text_processed = [TEXT_PROCESSORS["eval"](s) for s in prompts]
        with torch.no_grad():
            cached = BASE_MODEL.extract_features(
                {"text_input": text_processed}, mode="text"
            ).text_embeds_proj[:, 0, :]
        # Keep only the entry for the current prompt set.
        _PROMPT_FEATURE_CACHE.clear()
        _PROMPT_FEATURE_CACHE[prompts] = cached
    return cached


def classify_crop(crop_image):
    """Score one cropped image against every prompt in ``CLASSIFICATION_PROMPTS``.

    Image and text features from the globally loaded BLIP-2 model are passed
    through the adapter and L2-normalised, then compared by dot product (cosine
    similarity). Each similarity is squashed with a sigmoid, so every score
    lies roughly between 0.27 and 0.73.

    Args:
        crop_image: Image accepted by the BLIP-2 "eval" visual processor.

    Returns:
        Dict mapping each label to its score as a Python float.
    """
    image_processed = VIS_PROCESSORS["eval"](crop_image).unsqueeze(0).to(DEVICE)
    text_features = _prompt_text_features()
    with torch.no_grad():
        image_features = BASE_MODEL.extract_features(
            {"image": image_processed}, mode="image"
        ).image_embeds_proj[:, 0, :]
        scaled_image_features = ADAPTER_MODEL(image_features)
        scaled_text_features = ADAPTER_MODEL(text_features)
        logits = F.normalize(scaled_image_features) @ F.normalize(scaled_text_features).t()
        # Drop only the batch axis, so a single-prompt setup still yields a 1-D
        # tensor that can be zipped with the labels.
        probabilities = logits.sigmoid().squeeze(0)
    return {label: prob.item() for label, prob in zip(CLASSIFICATION_PROMPTS, probabilities)}
def segment_image_gradio(input_image, confidence_threshold):
    """Segment an image with SAM, label each mask with ``classify_crop``, and overlay the result.

    Args:
        input_image: PIL image supplied by Gradio, or ``None`` when nothing was
            uploaded.
        confidence_threshold: A mask is painted only when its best label score
            is strictly greater than this value.

    Returns:
        RGB PIL image: the input blended with a half-transparent colour layer.

    Raises:
        gr.Error: If ``input_image`` is ``None``.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first!")

    # Use one RGB copy throughout, so SAM, the crop classifier and the final
    # blend all see the same pixels regardless of the upload's original mode.
    rgb_image = input_image.convert("RGB")
    image_np = np.array(rgb_image)

    # 1. Generate all masks with SAM
    print("Generating masks with SAM...")
    masks = SAM_GENERATOR.generate(image_np)
    print(f"Found {len(masks)} potential masks.")

    # Output canvas plus a per-pixel record of the best score painted so far.
    segmentation_layer = np.zeros_like(image_np, dtype=np.uint8)
    score_map = np.zeros(image_np.shape[:2], dtype=np.float32)

    # 2. Classify each mask and build the segmentation map
    print("Classifying each mask...")
    for mask_data in sorted(masks, key=lambda m: m['area']):
        mask = mask_data['segmentation']
        x, y, w, h = mask_data['bbox']
        # Clamp to at least one pixel so a degenerate box never produces an
        # empty crop. Assumes SAM can report zero width/height for
        # one-pixel-wide masks; TODO confirm.
        crop = rgb_image.crop((x, y, x + max(w, 1), y + max(h, 1)))
        scores = classify_crop(crop)
        best_label = max(scores, key=scores.get)
        best_score = scores[best_label]
        if best_score > confidence_threshold:
            # Where masks overlap, the higher-scoring label wins the pixel.
            pixels_to_update = mask & (best_score > score_map)
            segmentation_layer[pixels_to_update] = COLOR_PALETTE[best_label]
            score_map[pixels_to_update] = best_score

    # 3. Blend the original image with the segmentation layer
    # NOTE(review): the alpha of 128 covers the whole layer, so unlabelled
    # (black) pixels darken the photo as well. Kept as in the original.
    blended_image = rgb_image.convert("RGBA")
    segmentation_image = Image.fromarray(segmentation_layer).convert("RGBA")
    segmentation_image.putalpha(128)  # Make overlay semi-transparent
    final_image = Image.alpha_composite(blended_image, segmentation_image)
    print("Segmentation complete.")
    return final_image.convert("RGB")
# --- 4. BUILD THE GRADIO INTERFACE ---
# Render the legend once at start-up; it is shown as a static image in the UI.
legend_pil = create_legend_image(COLOR_PALETTE)

with gr.Blocks(theme=gr.themes.Soft(), title="Zero-Shot Segmenter") as demo:
    gr.Markdown("# Zero-Shot Segmentation with SAM and Fine-Tuned BLIP-2")
    gr.Markdown(
        "Upload an image to segment it based on natural language descriptions. Models are pre-loaded."
    )

    with gr.Row():
        # Left column: upload, threshold, trigger button and colour legend.
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Image")
            confidence_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.6,
                step=0.05,
                label="Confidence Threshold",
            )
            submit_btn = gr.Button("Generate Segmentation", variant="primary")
            gr.Image(value=legend_pil, label="Color Legend", interactive=False)

        # Right column: the blended segmentation result.
        with gr.Column(scale=2):
            output_image = gr.Image(label="Segmented Output", type="pil")

    # Clicking the button runs the pipeline on the uploaded image and threshold.
    submit_btn.click(
        fn=segment_image_gradio,
        inputs=[input_image, confidence_slider],
        outputs=output_image,
    )
# --- 5. LAUNCH THE APP ---
if __name__ == "__main__":
    # Start the app only when run as a script. It listens on every interface
    # at port 8008 ("a different port" than earlier scripts), with share on.
    launch_options = {"server_name": "0.0.0.0", "server_port": 8008, "share": True}
    demo.launch(**launch_options)