"""Zero-shot segmentation demo: SAM proposes masks, a fine-tuned BLIP-2 labels them.

This setup section configures import paths, patches ``nn.Module.load_state_dict``
for LAVIS checkpoints, defines the feature-scaling adapter, loads all models once
at import time, and defines the label prompts and their overlay colors.
"""
import inspect
import os
import sys

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, ImageDraw, ImageFont

# --- 1. SETUP: PATHS, PATCHES, AND MODEL DEFINITIONS ---
# This section ensures the script can load your BLIP-2 model correctly.

# --- PATH SETUP & MONKEY PATCHES (Copied from previous scripts) ---
try:
    script_dir = os.path.dirname(__file__)
except NameError:
    # __file__ is not defined when run interactively; use the working directory.
    script_dir = os.getcwd()
path_to_project_root = os.path.abspath(os.path.join(script_dir, ".."))
path_to_lavis_parent_dir = os.path.join(path_to_project_root, "LAVIS")
if not (
    os.path.isdir(path_to_lavis_parent_dir)
    and os.path.isdir(os.path.join(path_to_lavis_parent_dir, "lavis"))
):
    # No ../LAVIS checkout next to this script: fall back to a hard-coded path.
    path_to_lavis_parent_dir = "/media/RTCIN7TBDriveB/Interns/RDT2/gte3kor/LAVIS"
# Prepend so this LAVIS checkout is found before any other copy on sys.path.
sys.path.insert(0, path_to_lavis_parent_dir)

# Imported only after the sys.path change above so the LAVIS checkout resolves.
from lavis.models import load_model_and_preprocess  # noqa: E402
from lavis.models.blip2_models.blip2_qformer import Blip2Qformer  # noqa: E402
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry  # noqa: E402

_original_torch_load_state_dict = nn.Module.load_state_dict
# Older torch releases lack the `assign` keyword; check once at import time
# instead of re-inspecting the signature on every load.
_LOAD_STATE_DICT_SUPPORTS_ASSIGN = (
    "assign" in inspect.signature(_original_torch_load_state_dict).parameters
)


def patched_load_state_dict(self, state_dict, strict=True, assign=False):
    """Drop-in replacement for ``nn.Module.load_state_dict`` used while loading LAVIS.

    For ``Blip2Qformer`` models, the Q-Former prediction-head tensors listed below
    are truncated along dim 0 to the model's size when the checkpoint's tensor is
    larger (presumably a vocabulary-size mismatch -- TODO confirm). A smaller
    checkpoint tensor is left untouched so torch reports its own size-mismatch
    error. Modules that still hold meta-device parameters are loaded with
    ``assign=True`` (when this torch version supports it).

    Note: the truncation rewrites entries of the caller's ``state_dict`` in place.

    Returns:
        Whatever the original ``nn.Module.load_state_dict`` returns.
    """
    if isinstance(self, Blip2Qformer):
        model_state_dict = self.state_dict()
        for key in ["Qformer.cls.predictions.bias", "Qformer.cls.predictions.decoder.weight"]:
            if key in state_dict and key in model_state_dict:
                target_rows = model_state_dict[key].shape[0]
                # Only shrink: narrowing a smaller tensor to a larger length raises.
                if state_dict[key].shape[0] > target_rows:
                    state_dict[key] = state_dict[key].narrow(0, 0, target_rows)
    if any(p.is_meta for p in self.parameters()):
        assign = True
    if _LOAD_STATE_DICT_SUPPORTS_ASSIGN:
        return _original_torch_load_state_dict(self, state_dict, strict=strict, assign=assign)
    return _original_torch_load_state_dict(self, state_dict, strict=strict)


nn.Module.load_state_dict = patched_load_state_dict
print("INFO: Patches for LAVIS are active.")


# --- Your Adapter Model Definition ---
class TrainableEltwiseLayer(nn.Module):
    """Learnable element-wise scale over 256-dimensional features."""

    def __init__(self):
        super().__init__()
        # Starts at ones, so an untrained layer leaves its input unchanged.
        self.weights = nn.Parameter(torch.ones(1, 256))

    def forward(self, x):
        # assumes x is (batch, 256); the (1, 256) weights broadcast over batch.
        return x * self.weights


class MyModel(nn.Module):
    """Adapter applied to both image and text embeddings before matching."""

    def __init__(self):
        super().__init__()
        self.scaling_layer = TrainableEltwiseLayer()

    def forward(self, x):
        return self.scaling_layer(x)


# --- 2. GLOBAL SETUP: LOAD MODELS, DEFINE PROMPTS & COLORS ---
# This happens once when the app starts.


def _select_device(preferred_gpu=9):
    """Choose the torch device every model is moved to.

    Returns ``cuda:<preferred_gpu>`` when that GPU exists. Returns ``cuda:0`` when
    CUDA is available but fewer GPUs are visible, because the hard-coded index
    would otherwise fail at the first ``.to(DEVICE)``. Returns the CPU when CUDA
    is unavailable.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    gpu_count = torch.cuda.device_count()
    if preferred_gpu < gpu_count:
        return torch.device(f"cuda:{preferred_gpu}")
    print(f"WARNING: cuda:{preferred_gpu} requested but only {gpu_count} GPU(s) visible; using cuda:0.")
    return torch.device("cuda:0")


DEVICE = _select_device()
SAM_CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"
ADAPTER_PATH = "global_adapter_model.pth"
print(f"Using device: {DEVICE}")

# Load BLIP-2 Model
print("Loading base BLIP-2 model (gen3_322_840)...")
BASE_MODEL, VIS_PROCESSORS, TEXT_PROCESSORS = load_model_and_preprocess(
    name="blip2", model_type="gen3_322_840", is_eval=True, device=DEVICE
)

print("Loading fine-tuned adapter model...")
ADAPTER_MODEL = MyModel().to(DEVICE)
if os.path.exists(ADAPTER_PATH):
    # The file is consumed as a plain state_dict, so restrict unpickling to
    # tensors/containers where this torch version supports `weights_only`.
    _adapter_load_kwargs = {"map_location": DEVICE}
    if "weights_only" in inspect.signature(torch.load).parameters:
        _adapter_load_kwargs["weights_only"] = True
    ADAPTER_MODEL.load_state_dict(torch.load(ADAPTER_PATH, **_adapter_load_kwargs))
    print(f"Successfully loaded fine-tuned adapter from '{ADAPTER_PATH}'.")
else:
    print("WARNING: Adapter model not found. Using an untrained adapter.")
ADAPTER_MODEL.eval()

# Load SAM Model
print(f"Loading Segment Anything Model from {SAM_CHECKPOINT_PATH}...")
sam = sam_model_registry["vit_h"](checkpoint=SAM_CHECKPOINT_PATH).to(DEVICE)
SAM_GENERATOR = SamAutomaticMaskGenerator(sam)
print("All models loaded successfully.")

# Define the full list of prompts and colors.
# Keys are the display labels; values are the text prompts scored against crops.
CLASSIFICATION_PROMPTS = {
    "Red Light": "a photo of a red traffic light",
    "Yellow Light": "a photo of a yellow traffic light",
    "Green Light": "a photo of a green traffic light",
    "Black/Off Light": "a photo of a traffic light that is off or unlit",
    "Multibulb Light": "a photo of a traffic light with multiple bulbs",
    "Countdown Timer": "a photo of a traffic light with a digital countdown timer",
    "Left and U-turn Sign": "a photo of a traffic sign with a left arrow and a U-turn arrow",
    "Left and Straight Sign": "a photo of a traffic sign with a left arrow and a straight arrow",
    "Multi-shape Sign": "a photo of a traffic sign with multiple shapes or complex symbols",
    "Pedestrian": "a photo of a pedestrian or a person walking in the street",
    "Bicycle": "a photo of a bicycle or a person on a bike",
    "Unknown": "an unrecognizable or unknown object",  # A generic 'Unknown'
}

# One tab20 color per label, as 0-255 RGB tuples for PIL/numpy.
# pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
colors = plt.get_cmap("tab20", len(CLASSIFICATION_PROMPTS))
COLOR_PALETTE = {
    label: tuple(int(c * 255) for c in colors(i)[:3])
    for i, label in enumerate(CLASSIFICATION_PROMPTS.keys())
}
# --- 3. HELPER AND CORE GRADIO FUNCTIONS ---

def create_legend_image(palette):
    """Render a color legend as a PIL image.

    Args:
        palette: Mapping of label -> RGB tuple; one row per entry, in iteration order.

    Returns:
        A white, 300 px wide RGB image with a 40x20 color swatch followed by the
        label text on each row.
    """
    item_height, text_offset, margin = 30, 5, 10
    width = 300
    height = len(palette) * item_height + 2 * margin
    legend = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(legend)
    try:
        font = ImageFont.truetype("arial.ttf", 14)
    except IOError:
        # arial.ttf could not be found/opened; use PIL's built-in default font.
        font = ImageFont.load_default()
    for i, (label, color) in enumerate(palette.items()):
        top = margin + i * item_height
        draw.rectangle([margin, top, margin + 40, top + 20], fill=color)
        draw.text((margin + 50, top + text_offset), label, fill="black", font=font)
    return legend


# Cache of the L2-normalized, adapter-scaled embeddings of CLASSIFICATION_PROMPTS
# (one row per prompt, in dict order). The prompts are constant, so they are
# encoded once instead of once per SAM mask.
_NORMALIZED_TEXT_FEATURES = None


def _normalized_text_features():
    """Return (computing on first use) the normalized, adapter-scaled prompt embeddings."""
    global _NORMALIZED_TEXT_FEATURES
    if _NORMALIZED_TEXT_FEATURES is None:
        text_processed = [TEXT_PROCESSORS["eval"](s) for s in CLASSIFICATION_PROMPTS.values()]
        with torch.no_grad():
            text_features = BASE_MODEL.extract_features(
                {"text_input": text_processed}, mode="text"
            ).text_embeds_proj[:, 0, :]
            _NORMALIZED_TEXT_FEATURES = F.normalize(ADAPTER_MODEL(text_features))
    return _NORMALIZED_TEXT_FEATURES


def classify_crop(crop_image):
    """Score one image crop against every label in CLASSIFICATION_PROMPTS.

    Args:
        crop_image: PIL image to classify (converted to RGB if needed).

    Returns:
        dict mapping each CLASSIFICATION_PROMPTS label to a float: the sigmoid
        of the cosine similarity between the adapter-scaled image embedding and
        that label's adapter-scaled prompt embedding.

    NOTE(review): cosine similarity lies in [-1, 1], so scores lie in roughly
    [0.269, 0.731]; a confidence threshold above ~0.73 can never be met.
    """
    if crop_image.mode != "RGB":
        crop_image = crop_image.convert("RGB")
    image_processed = VIS_PROCESSORS["eval"](crop_image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        image_features = BASE_MODEL.extract_features(
            {"image": image_processed}, mode="image"
        ).image_embeds_proj[:, 0, :]
        scaled_image_features = ADAPTER_MODEL(image_features)
        logits = F.normalize(scaled_image_features) @ _normalized_text_features().t()
    # logits is (1, num_prompts); take row 0 rather than squeeze() so a table
    # with a single prompt still yields a 1-D tensor that zip can iterate.
    probabilities = logits.sigmoid()[0]
    return {label: prob.item() for label, prob in zip(CLASSIFICATION_PROMPTS.keys(), probabilities)}


def segment_image_gradio(input_image, confidence_threshold):
    """Segment an uploaded image with SAM and label each mask with BLIP-2.

    Args:
        input_image: PIL image from the upload widget, or None if nothing was uploaded.
        confidence_threshold: A mask is labeled only if its best score is
            strictly greater than this value.

    Returns:
        RGB PIL image: the input with every labeled region tinted in its
        COLOR_PALETTE color at 50% opacity; unlabeled pixels are unchanged.

    Raises:
        gr.Error: if no image was uploaded.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first!")
    # One RGB copy feeds SAM, the per-mask crops and the final blend, so an
    # RGBA/P upload never reaches the BLIP-2 image processor in a non-RGB mode.
    rgb_image = input_image.convert("RGB")
    image_np = np.array(rgb_image)

    # 1. Generate all masks with SAM
    print("Generating masks with SAM...")
    masks = SAM_GENERATOR.generate(image_np)
    print(f"Found {len(masks)} potential masks.")

    # Prepare output canvas
    segmentation_layer = np.zeros_like(image_np, dtype=np.uint8)
    # Best label score written so far at each pixel; a later mask only recolors
    # pixels where its own score is higher.
    score_map = np.zeros(image_np.shape[:2], dtype=np.float32)
    labeled = np.zeros(image_np.shape[:2], dtype=bool)

    # 2. Classify each mask and build the segmentation map
    print("Classifying each mask...")
    for mask_data in sorted(masks, key=lambda m: m["area"]):  # smallest area first
        x, y, w, h = (int(v) for v in mask_data["bbox"])
        if w < 1 or h < 1:
            # Degenerate box: the crop would be empty and cannot be classified.
            continue
        scores = classify_crop(rgb_image.crop((x, y, x + w, y + h)))
        best_label = max(scores, key=scores.get)
        best_score = scores[best_label]
        if best_score <= confidence_threshold:
            continue
        pixels_to_update = mask_data["segmentation"] & (best_score > score_map)
        segmentation_layer[pixels_to_update] = COLOR_PALETTE[best_label]
        score_map[pixels_to_update] = best_score
        labeled |= pixels_to_update

    # 3. Blend the original image with the segmentation layer.
    # Alpha 128 on labeled pixels only; a uniform alpha would also overlay the
    # layer's black background and darken every unlabeled pixel.
    overlay = Image.fromarray(segmentation_layer).convert("RGBA")
    overlay.putalpha(Image.fromarray(np.where(labeled, 128, 0).astype(np.uint8)))
    final_image = Image.alpha_composite(rgb_image.convert("RGBA"), overlay)
    print("Segmentation complete.")
    return final_image.convert("RGB")


# --- 4. BUILD THE GRADIO INTERFACE ---
legend_pil = create_legend_image(COLOR_PALETTE)

with gr.Blocks(theme=gr.themes.Soft(), title="Zero-Shot Segmenter") as demo:
    gr.Markdown("# Zero-Shot Segmentation with SAM and Fine-Tuned BLIP-2")
    gr.Markdown("Upload an image to segment it based on natural language descriptions. Models are pre-loaded.")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Image")
            # NOTE(review): scores top out near 0.73 (see classify_crop), so values
            # above that on this slider label nothing.
            confidence_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.6, step=0.05, label="Confidence Threshold")
            submit_btn = gr.Button("Generate Segmentation", variant="primary")
            gr.Image(value=legend_pil, label="Color Legend", interactive=False)
        with gr.Column(scale=2):
            output_image = gr.Image(label="Segmented Output", type="pil")
    submit_btn.click(
        fn=segment_image_gradio,
        inputs=[input_image, confidence_slider],
        outputs=output_image,
    )

# --- 5. LAUNCH THE APP ---
if __name__ == "__main__":
    # NOTE(review): server_name="0.0.0.0" listens on all interfaces and share=True
    # also publishes a public Gradio share link -- confirm both are intended.
    demo.launch(server_name="0.0.0.0", server_port=8008, share=True)  # Using a different port