# Spaces: Runtime error
| import os | |
| import sys | |
| from PIL import Image, ImageDraw, ImageFont | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| # --- 1. SETUP: PATHS, PATCHES, AND MODEL DEFINITIONS --- | |
| # This section ensures the script can load your BLIP-2 model correctly. | |
| # --- PATH SETUP & MONKEY PATCHES (Copied from previous scripts) --- | |
# Locate the directory that contains the ``lavis`` package and put it first on
# ``sys.path`` so the LAVIS imports that follow resolve to that checkout.
try:
    script_dir = os.path.dirname(__file__)
except NameError:
    # ``__file__`` is undefined in interactive sessions; use the working dir.
    script_dir = os.getcwd()
path_to_project_root = os.path.abspath(os.path.join(script_dir, ".."))

# Last-resort location on the original development machine.
_LAVIS_FALLBACK_DIR = "/media/RTCIN7TBDriveB/Interns/RDT2/gte3kor/LAVIS"


def _resolve_lavis_parent_dir(project_root, environ=None):
    """Return the directory expected to contain the ``lavis`` package.

    Candidates are tried in order:

    1. the ``LAVIS_PATH`` environment variable, if it names a directory that
       contains a ``lavis`` sub-directory;
    2. ``<project_root>/LAVIS``, under the same condition;
    3. the hard-coded development-machine path, returned unconditionally so
       the historical behaviour is preserved.

    Args:
        project_root: Directory in which a ``LAVIS`` checkout is looked up.
        environ: Mapping to read ``LAVIS_PATH`` from; defaults to
            ``os.environ``.

    Returns:
        The selected directory path as a string.
    """
    environ = os.environ if environ is None else environ

    def has_lavis_package(parent_dir):
        return os.path.isdir(os.path.join(parent_dir, "lavis"))

    override = environ.get("LAVIS_PATH")
    if override and has_lavis_package(override):
        return override
    bundled = os.path.join(project_root, "LAVIS")
    if has_lavis_package(bundled):
        return bundled
    return _LAVIS_FALLBACK_DIR


path_to_lavis_parent_dir = _resolve_lavis_parent_dir(path_to_project_root)
if not os.path.isdir(path_to_lavis_parent_dir):
    # Inserting a missing directory is harmless, but say so: the LAVIS imports
    # can then only succeed if the package is installed in the environment.
    print(
        f"WARNING: LAVIS checkout not found at '{path_to_lavis_parent_dir}'; "
        "relying on an installed 'lavis' package. Set LAVIS_PATH to override."
    )
sys.path.insert(0, path_to_lavis_parent_dir)
| from lavis.models import load_model_and_preprocess | |
| from lavis.models.blip2_models.blip2_qformer import Blip2Qformer | |
| import inspect | |
import copy

_original_torch_load_state_dict = nn.Module.load_state_dict

# Not every PyTorch release accepts the ``assign`` keyword; probe the original
# signature once at import time instead of on every load.
_LOAD_STATE_DICT_ACCEPTS_ASSIGN = (
    "assign" in inspect.signature(_original_torch_load_state_dict).parameters
)

# Checkpoint entries whose first dimension may differ from the model's.
_QFORMER_NARROWED_KEYS = (
    "Qformer.cls.predictions.bias",
    "Qformer.cls.predictions.decoder.weight",
)


def patched_load_state_dict(self, state_dict, strict=True, assign=False):
    """Replacement for ``nn.Module.load_state_dict`` with LAVIS workarounds.

    For ``Blip2Qformer`` instances, the prediction-head tensors listed in
    ``_QFORMER_NARROWED_KEYS`` are narrowed along dimension 0 to the model's
    size when the checkpoint's size differs. If any parameter of the module is
    a meta tensor, ``assign`` is forced to ``True``.

    The caller's ``state_dict`` is never modified: a shallow copy is made the
    first time an entry has to be replaced.

    Args:
        self: The module being loaded into.
        state_dict: Mapping of parameter names to tensors.
        strict: Forwarded unchanged to the original implementation.
        assign: Forwarded when the installed PyTorch supports it.

    Returns:
        Whatever the original ``nn.Module.load_state_dict`` returns.
    """
    if isinstance(self, Blip2Qformer):
        model_state_dict = self.state_dict()
        narrowed = None
        for key in _QFORMER_NARROWED_KEYS:
            if key not in state_dict or key not in model_state_dict:
                continue
            target_rows = model_state_dict[key].shape[0]
            if state_dict[key].shape[0] == target_rows:
                continue
            if narrowed is None:
                # copy.copy keeps the mapping type and any ``_metadata``
                # attribute attached to the checkpoint's OrderedDict.
                narrowed = copy.copy(state_dict)
            narrowed[key] = state_dict[key].narrow(0, 0, target_rows)
        if narrowed is not None:
            state_dict = narrowed

    # NOTE(review): the original nesting of this check was lost in the source;
    # it is applied to every module here — confirm that matches the intent.
    if any(p.is_meta for p in self.parameters()):
        assign = True

    if _LOAD_STATE_DICT_ACCEPTS_ASSIGN:
        return _original_torch_load_state_dict(self, state_dict, strict=strict, assign=assign)
    return _original_torch_load_state_dict(self, state_dict, strict=strict)


nn.Module.load_state_dict = patched_load_state_dict
print("INFO: Patches for LAVIS are active.")
| # --- Your Adapter Model Definition --- | |
class MyModel(nn.Module):
    """Adapter module that wraps a single ``TrainableEltwiseLayer``.

    The layer is stored as ``scaling_layer``; that attribute name is part of
    the saved state-dict keys and must not change.
    """

    def __init__(self):
        super().__init__()
        self.scaling_layer = TrainableEltwiseLayer()

    def forward(self, x):
        """Return ``x`` passed through the scaling layer."""
        scaled = self.scaling_layer(x)
        return scaled
class TrainableEltwiseLayer(nn.Module):
    """Learnable element-wise scaling.

    Holds one trainable parameter, ``weights``, of shape ``(1, 256)``
    initialised to ones, and multiplies its input by it (with broadcasting).
    """

    def __init__(self):
        super().__init__()
        initial_scale = torch.ones(1, 256)
        self.weights = nn.Parameter(initial_scale)

    def forward(self, x):
        """Return ``x`` multiplied element-wise by ``weights``."""
        return x * self.weights
| # --- Import and Setup SAM --- | |
| from segment_anything import sam_model_registry, SamAutomaticMaskGenerator | |
| # --- 2. GLOBAL SETUP: LOAD MODELS, DEFINE PROMPTS & COLORS --- | |
| # This happens once when the app starts. | |
# GPU ordinal used on the original multi-GPU machine.
_PREFERRED_CUDA_INDEX = 9
if torch.cuda.is_available():
    # Moving a model to "cuda:9" fails with an invalid-device-ordinal error
    # when fewer than ten GPUs are visible, so fall back to the first GPU.
    _cuda_index = (
        _PREFERRED_CUDA_INDEX
        if _PREFERRED_CUDA_INDEX < torch.cuda.device_count()
        else 0
    )
    DEVICE = torch.device(f"cuda:{_cuda_index}")
else:
    DEVICE = torch.device("cpu")
SAM_CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"
ADAPTER_PATH = "global_adapter_model.pth"
print(f"Using device: {DEVICE}")

# Load BLIP-2 Model
print("Loading base BLIP-2 model (gen3_322_840)...")
BASE_MODEL, VIS_PROCESSORS, TEXT_PROCESSORS = load_model_and_preprocess(
    name="blip2", model_type="gen3_322_840", is_eval=True, device=DEVICE
)

# Load the adapter; a missing checkpoint is tolerated and leaves the adapter at
# its initial weights.
print("Loading fine-tuned adapter model...")
ADAPTER_MODEL = MyModel().to(DEVICE)
if os.path.exists(ADAPTER_PATH):
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    ADAPTER_MODEL.load_state_dict(torch.load(ADAPTER_PATH, map_location=DEVICE))
    print(f"Successfully loaded fine-tuned adapter from '{ADAPTER_PATH}'.")
else:
    print("WARNING: Adapter model not found. Using an untrained adapter.")
ADAPTER_MODEL.eval()

# Load SAM Model
print(f"Loading Segment Anything Model from {SAM_CHECKPOINT_PATH}...")
if not os.path.isfile(SAM_CHECKPOINT_PATH):
    # Fail early with an actionable message instead of a bare open() error.
    raise FileNotFoundError(
        f"SAM checkpoint '{SAM_CHECKPOINT_PATH}' not found; place the vit_h "
        "checkpoint next to this script or update SAM_CHECKPOINT_PATH."
    )
sam = sam_model_registry["vit_h"](checkpoint=SAM_CHECKPOINT_PATH).to(DEVICE)
SAM_GENERATOR = SamAutomaticMaskGenerator(sam)
print("All models loaded successfully.")
| # Define the full list of prompts and colors | |
# Display label -> text prompt that is embedded and compared with each crop.
CLASSIFICATION_PROMPTS = {
    "Red Light": "a photo of a red traffic light",
    "Yellow Light": "a photo of a yellow traffic light",
    "Green Light": "a photo of a green traffic light",
    "Black/Off Light": "a photo of a traffic light that is off or unlit",
    "Multibulb Light": "a photo of a traffic light with multiple bulbs",
    "Countdown Timer": "a photo of a traffic light with a digital countdown timer",
    "Left and U-turn Sign": "a photo of a traffic sign with a left arrow and a U-turn arrow",
    "Left and Straight Sign": "a photo of a traffic sign with a left arrow and a straight arrow",
    "Multi-shape Sign": "a photo of a traffic sign with multiple shapes or complex symbols",
    "Pedestrian": "a photo of a pedestrian or a person walking in the street",
    "Bicycle": "a photo of a bicycle or a person on a bike",
    "Unknown": "an unrecognizable or unknown object",  # A generic 'Unknown'
}
# ``matplotlib.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in
# 3.9; ``pyplot.get_cmap`` takes the same (name, lut) arguments.
colors = plt.get_cmap('tab20', len(CLASSIFICATION_PROMPTS))
# Display label -> (R, G, B) tuple of ints in 0..255, one colour per label.
COLOR_PALETTE = {
    label: tuple(int(c * 255) for c in colors(i)[:3])
    for i, label in enumerate(CLASSIFICATION_PROMPTS)
}
| # --- 3. HELPER AND CORE GRADIO FUNCTIONS --- | |
def create_legend_image(palette):
    """Build a PIL image listing every label next to its colour swatch.

    Args:
        palette: Mapping of label text to a fill colour accepted by PIL.

    Returns:
        A 300-pixel-wide white RGB image with one 30-pixel row per entry and a
        10-pixel margin above and below.
    """
    row_height = 30
    label_nudge = 5
    border = 10
    canvas_size = (300, len(palette) * row_height + 2 * border)

    legend = Image.new('RGB', canvas_size, 'white')
    pen = ImageDraw.Draw(legend)

    # Use Arial when the font file can be found; otherwise PIL's built-in font.
    try:
        font = ImageFont.truetype("arial.ttf", 14)
    except IOError:
        font = ImageFont.load_default()

    row_top = border
    for label, color in palette.items():
        swatch_box = [border, row_top, border + 40, row_top + 20]
        pen.rectangle(swatch_box, fill=color)
        pen.text((border + 50, row_top + label_nudge), label, fill='black', font=font)
        row_top += row_height
    return legend
# Prompt embeddings are identical for every crop, so they are computed on the
# first call and reused afterwards.
_PROMPT_TEXT_FEATURES = None


def _prompt_text_features():
    """Return the adapter-scaled, L2-normalised prompt embeddings (cached).

    The cache assumes ``CLASSIFICATION_PROMPTS`` and the loaded models do not
    change after the first call.
    """
    global _PROMPT_TEXT_FEATURES
    if _PROMPT_TEXT_FEATURES is None:
        text_processed = [TEXT_PROCESSORS["eval"](s) for s in CLASSIFICATION_PROMPTS.values()]
        with torch.no_grad():
            text_features = BASE_MODEL.extract_features(
                {"text_input": text_processed}, mode="text"
            ).text_embeds_proj[:, 0, :]
            _PROMPT_TEXT_FEATURES = F.normalize(ADAPTER_MODEL(text_features))
    return _PROMPT_TEXT_FEATURES


def classify_crop(crop_image):
    """Classifies a single cropped image using the globally loaded models.

    The crop is embedded with the BLIP-2 model, scaled by the adapter,
    L2-normalised and compared with every prompt embedding; the sigmoid of
    each similarity is reported as that label's score.

    Args:
        crop_image: Image accepted by ``VIS_PROCESSORS["eval"]``.

    Returns:
        Dict mapping each key of ``CLASSIFICATION_PROMPTS`` to a float score.
    """
    image_processed = VIS_PROCESSORS["eval"](crop_image).unsqueeze(0).to(DEVICE)
    text_features = _prompt_text_features()
    with torch.no_grad():
        image_features = BASE_MODEL.extract_features(
            {"image": image_processed}, mode="image"
        ).image_embeds_proj[:, 0, :]
        scaled_image_features = ADAPTER_MODEL(image_features)
        logits = F.normalize(scaled_image_features) @ text_features.t()
        # Drop only the batch axis so a single-prompt setup still yields a
        # 1-D tensor rather than a 0-d scalar.
        probabilities = logits.sigmoid().squeeze(0)
    return dict(zip(CLASSIFICATION_PROMPTS.keys(), probabilities.tolist()))
def segment_image_gradio(input_image, confidence_threshold):
    """The main function that Gradio will call to process an image.

    Runs SAM on the image, classifies the bounding-box crop of every mask and
    paints each confidently classified mask in its label colour; where masks
    overlap, the pixel keeps the label with the highest score.

    Args:
        input_image: PIL image from the upload widget, or ``None``.
        confidence_threshold: Minimum score a mask's best label must exceed
            for the mask to be drawn.

    Returns:
        RGB PIL image: the input blended with the colour overlay.

    Raises:
        gr.Error: If no image was supplied.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first!")

    # Convert once and use the same RGB image for SAM and for the classifier
    # crops, so uploads in other modes (RGBA, L, P, ...) are handled uniformly.
    rgb_image = input_image.convert("RGB")
    image_np = np.array(rgb_image)

    # 1. Generate all masks with SAM
    print("Generating masks with SAM...")
    masks = SAM_GENERATOR.generate(image_np)
    print(f"Found {len(masks)} potential masks.")

    # Prepare output canvas
    segmentation_layer = np.zeros_like(image_np, dtype=np.uint8)
    score_map = np.zeros(image_np.shape[:2], dtype=np.float32)

    # 2. Classify each mask and build the segmentation map
    print("Classifying each mask...")
    for mask_data in sorted(masks, key=lambda m: m['area']):
        x, y, w, h = mask_data['bbox']
        if w <= 0 or h <= 0:
            # A zero-size crop cannot be classified; skip the mask.
            continue
        crop = rgb_image.crop((x, y, x + w, y + h))
        scores = classify_crop(crop)
        best_label = max(scores, key=scores.get)
        best_score = scores[best_label]
        if best_score > confidence_threshold:
            # Only overwrite pixels whose current score is lower.
            pixels_to_update = mask_data['segmentation'] & (best_score > score_map)
            segmentation_layer[pixels_to_update] = COLOR_PALETTE[best_label]
            score_map[pixels_to_update] = best_score

    # 3. Blend the original image with the segmentation layer
    # NOTE(review): the uniform alpha also blends the black, unlabelled part
    # of the overlay, which darkens unclassified regions — confirm intended.
    blended_image = input_image.convert("RGBA")
    segmentation_image = Image.fromarray(segmentation_layer).convert("RGBA")
    segmentation_image.putalpha(128)  # Make overlay semi-transparent
    final_image = Image.alpha_composite(blended_image, segmentation_image)
    print("Segmentation complete.")
    return final_image.convert("RGB")
| # --- 4. BUILD THE GRADIO INTERFACE --- | |
# The legend depends only on the fixed palette, so it is rendered once here.
legend_pil = create_legend_image(COLOR_PALETTE)

with gr.Blocks(theme=gr.themes.Soft(), title="Zero-Shot Segmenter") as demo:
    gr.Markdown("# Zero-Shot Segmentation with SAM and Fine-Tuned BLIP-2")
    gr.Markdown("Upload an image to segment it based on natural language descriptions. Models are pre-loaded.")
    with gr.Row():
        # Left column: inputs, trigger button and the static colour legend.
        with gr.Column(scale=1):
            input_image = gr.Image(label="Upload Image", type="pil")
            confidence_slider = gr.Slider(
                label="Confidence Threshold",
                value=0.6,
                minimum=0.1,
                maximum=1.0,
                step=0.05,
            )
            submit_btn = gr.Button("Generate Segmentation", variant="primary")
            gr.Image(value=legend_pil, label="Color Legend", interactive=False)
        # Right column: the blended segmentation result.
        with gr.Column(scale=2):
            output_image = gr.Image(type="pil", label="Segmented Output")
    # Clicking the button runs the segmentation on the current inputs.
    submit_btn.click(
        fn=segment_image_gradio,
        inputs=[input_image, confidence_slider],
        outputs=output_image,
    )
| # --- 5. LAUNCH THE APP --- | |
if __name__ == "__main__":
    # Port 8008 remains the default; GRADIO_SERVER_PORT overrides it so the
    # app can run on hosts that require a specific port.
    server_port = int(os.environ.get("GRADIO_SERVER_PORT", "8008"))
    # server_name "0.0.0.0" listens on all interfaces; share=True additionally
    # requests a public Gradio share link.
    demo.launch(server_name="0.0.0.0", server_port=server_port, share=True)