"""Gradio demo: zero-shot image classification with a LAVIS BLIP-2 feature extractor.

The script redirects the Hugging Face / LAVIS caches to a directory next to
this file, locates a LAVIS checkout and puts it on ``sys.path``, applies a few
monkey patches needed to load the checkpoint, loads the models, and serves a
small Gradio UI that scores an uploaded image against a fixed set of text
prompts.

NOTE: the order of the top-level statements matters. The cache environment
variables are set before ``torch`` / ``transformers`` are imported, and the
``load_state_dict`` patch is installed before the model is loaded.
"""
import os
import sys
from PIL import Image

# --- Cache redirection (must run before the ML libraries are imported) -------
CACHE_DIR_NAME = "model_cache"
script_dir = (
    os.path.dirname(os.path.abspath(__file__))
    if "__file__" in globals()
    else os.getcwd()
)
local_cache_path = os.path.join(script_dir, CACHE_DIR_NAME)
os.makedirs(local_cache_path, exist_ok=True)
os.environ['HF_HOME'] = local_cache_path
os.environ['LAVIS_CACHE_ROOT'] = local_cache_path
print("--- CUSTOM CACHE ACTIVATED ---")
print(f"Hugging Face and LAVIS cache redirected to: {local_cache_path}")
print("---------------------------------")

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7,8,9'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

import functools
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import AutoModel, AutoProcessor

# --- Locate the LAVIS checkout -------------------------------------------------
# First try "<parent of this script>/LAVIS", then a historical absolute path.
HISTORICAL_LAVIS_DIR = "/media/RTCIN7TBDriveB/Interns/RDT2/gte3kor/LAVIS"


def _is_lavis_checkout(path):
    """Return True if *path* is a directory containing a ``lavis`` sub-directory."""
    return os.path.isdir(path) and os.path.isdir(os.path.join(path, "lavis"))


path_to_project_root = os.path.abspath(os.path.join(script_dir, ".."))
path_to_lavis_parent_dir = os.path.join(path_to_project_root, "LAVIS")
if not _is_lavis_checkout(path_to_lavis_parent_dir):
    print(f"Warning: Relative LAVIS path {path_to_lavis_parent_dir} not found or invalid. Trying historical path.")
    path_to_lavis_parent_dir = HISTORICAL_LAVIS_DIR
    if not _is_lavis_checkout(path_to_lavis_parent_dir):
        print(f"ERROR: Could not find a valid LAVIS package location. Last tried: {path_to_lavis_parent_dir}")
        sys.exit(1)
sys.path.insert(0, path_to_lavis_parent_dir)
print(f"INFO: Added {path_to_lavis_parent_dir} to sys.path for LAVIS.")

import torch.distributions.constraints as constraints
from transformers.modeling_utils import PreTrainedModel
import inspect
from lavis.models.blip2_models.blip2_qformer import Blip2Qformer

# --- Monkey patches ------------------------------------------------------------
if hasattr(constraints, '_PositiveDefinite') and hasattr(constraints._PositiveDefinite, 'check'):
    original_positive_definite_check = constraints._PositiveDefinite.check

    def patched_positive_definite_check(self, value):
        """Report meta tensors as valid; defer to the original check otherwise.

        A meta tensor carries no data, so the real check cannot be evaluated
        on it. The all-True result is created on the CPU so that the patch
        works on any host (the previous hard-coded ``cuda:6`` failed on
        machines without that device).
        """
        if isinstance(value, torch.Tensor) and value.is_meta:
            return torch.ones_like(value, dtype=torch.bool, device='cpu')
        return original_positive_definite_check(self, value)

    constraints._PositiveDefinite.check = patched_positive_definite_check
    print("INFO: Patched torch.distributions.constraints._PositiveDefinite.check for meta tensors.")

if hasattr(PreTrainedModel, '_init_added_embeddings_weights_with_mean'):
    # Kept so the original implementation can be restored if ever needed.
    original_init_added_weights = PreTrainedModel._init_added_embeddings_weights_with_mean

    def patched_init_added_weights(self, new_embeddings_module, old_embeddings_module, num_added_tokens, *args, **kwargs):
        """Skip the mean-based initialisation of added embeddings entirely.

        NOTE: this is an unconditional no-op, not only for meta tensors.
        """
        return

    PreTrainedModel._init_added_embeddings_weights_with_mean = patched_init_added_weights
    print("INFO: Patched PreTrainedModel._init_added_embeddings_weights_with_mean for meta tensors.")

print("--- Applying robust patch for load_state_dict (Vocab Mismatch & Meta Devices) ---")
_original_torch_load_state_dict = nn.Module.load_state_dict

# Older torch releases have no ``assign`` keyword; inspect the signature once
# instead of on every call.
_LOAD_STATE_DICT_ACCEPTS_ASSIGN = (
    'assign' in inspect.signature(_original_torch_load_state_dict).parameters
)

# Checkpoint entries whose first dimension is the vocabulary size.
_VOCAB_SIZED_KEYS = (
    "Qformer.cls.predictions.bias",
    "Qformer.cls.predictions.decoder.weight",
)


def patched_load_state_dict(self, state_dict, strict=True, assign=False):
    """Drop-in replacement for ``nn.Module.load_state_dict``.

    Two adjustments are made before delegating to the original method:

    * For ``Blip2Qformer`` modules, vocabulary-sized checkpoint tensors that
      have more rows than the model's are truncated to the model's row count.
      The caller's ``state_dict`` is not modified; a shallow copy is patched.
      A checkpoint tensor with *fewer* rows is left untouched so that the
      original ``load_state_dict`` reports the size mismatch.
    * If any parameter of the module is a meta tensor, ``assign=True`` is
      forced (when the installed torch supports the keyword).

    Args:
        state_dict: mapping of parameter names to tensors.
        strict: forwarded unchanged.
        assign: forwarded, possibly overridden to True as described above.

    Returns:
        Whatever the original ``load_state_dict`` returns.
    """
    if isinstance(self, Blip2Qformer):
        model_state_dict = self.state_dict()
        truncated = {}
        for key in _VOCAB_SIZED_KEYS:
            if key not in state_dict or key not in model_state_dict:
                continue
            ckpt_tensor = state_dict[key]
            model_vocab_size = model_state_dict[key].shape[0]
            if ckpt_tensor.shape[0] > model_vocab_size:
                truncated[key] = ckpt_tensor.narrow(0, 0, model_vocab_size)
        if truncated:
            # Copy before patching so the caller's dict stays intact, and
            # carry over the optional ``_metadata`` attribute that torch
            # attaches to state dicts.
            metadata = getattr(state_dict, '_metadata', None)
            state_dict = OrderedDict(state_dict)
            if metadata is not None:
                state_dict._metadata = metadata
            state_dict.update(truncated)

    if any(p.is_meta for p in self.parameters()):
        assign = True

    if _LOAD_STATE_DICT_ACCEPTS_ASSIGN:
        return _original_torch_load_state_dict(self, state_dict, strict=strict, assign=assign)
    return _original_torch_load_state_dict(self, state_dict, strict=strict)


nn.Module.load_state_dict = patched_load_state_dict
print("INFO: Monkey-patched nn.Module.load_state_dict with robust handler.")

from lavis.models import load_model_and_preprocess


class MyModel(nn.Module):
    """Two-layer MLP adapter: Linear -> ReLU -> Linear."""

    def __init__(self, input_dim=256, hidden_dim=128, output_dim=256):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Apply fc1, ReLU and fc2 to *x* and return the result."""
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


# --- Device selection and model loading ----------------------------------------
PREFERRED_GPU_INDEX = 6


def _select_device():
    """Pick the preferred GPU if it exists, else the first GPU, else the CPU."""
    if not torch.cuda.is_available():
        return torch.device("cpu")
    if torch.cuda.device_count() > PREFERRED_GPU_INDEX:
        return torch.device(f"cuda:{PREFERRED_GPU_INDEX}")
    return torch.device("cuda:0")


DEVICE = _select_device()
GLOBAL_MODEL_PATH = "global_adapter_model.pth"
print(f"Using device: {DEVICE}")

print("Loading base BLIP-2 model (gen3_322_840)...")
BASE_MODEL, VIS_PROCESSORS, TEXT_PROCESSORS = load_model_and_preprocess(
    name="blip2", model_type="gen3_322_840", is_eval=True, device=DEVICE
)
print("Base model loaded successfully.")

print("Loading fine-tuned adapter model...")
# NOTE: the adapter is loaded here but classify_image_from_upload below does
# not use it.
ADAPTER_MODEL = MyModel().to(DEVICE)
if os.path.exists(GLOBAL_MODEL_PATH):
    # NOTE(security): torch.load unpickles the file; only load adapter files
    # from a trusted source.
    ADAPTER_MODEL.load_state_dict(torch.load(GLOBAL_MODEL_PATH, map_location=DEVICE))
    print(f"Successfully loaded fine-tuned adapter from '{GLOBAL_MODEL_PATH}'.")
else:
    print("WARNING: No saved adapter model found.")
ADAPTER_MODEL.eval()

# Label -> text prompt. Commented-out entries are currently disabled.
CLASSIFICATION_PROMPTS = {
    # --- Traffic Light States ---
    # "Red Light On": "a photo of a traffic signal with the red light illuminated",
    # "Yellow Light On": "a photo of a traffic signal with the yellow light glowing",
    # "Green Light On": "a photo of a traffic signal with a shining green light",
    # "All Lights Off": "a photo of a traffic signal where all the lights are off",

    # --- Pedestrian & Bicycle Symbols ---
    # "Pedestrian Signal": "a pedestrian crosswalk signal showing the illuminated symbol of a person, either walking or standing still.",
    # "Bicycle Symbol": "the illuminated symbol of a bicycle on a green traffic signal for bikes",

    # --- Regulatory Signs ---
    "Speed Limit Sign": "speed limit sign",
    # "End of Speed Limit": "a grey or white circular sign with black diagonal lines crossing out a number, indicating the end of a speed zone",
    # "Weight Limit Sign": "a circular traffic sign with a red border showing a weight limit in tonnes, often with an icon of a truck",
    "No Parking Sign": "No Parking Sign",
    "No Entry Sign": "No Entry Sign",
    # "End of All Prohibitions": "a white or grey circular sign with multiple black diagonal lines, indicating the end of all previous prohibitions",

    # --- Highway Signs ---
    # "Highway Start": "a rectangular blue or green sign with a white pictogram of a motorway or highway, indicating the start of a major road",
    # "Highway End": "a rectangular blue or green highway sign with a diagonal red line across it, indicating the end of the motorway",

    # --- Mandatory & Warning Signs ---
    # "Roundabout Sign": "a circular sign with three white arrows chasing each other in a circle, indicating a mandatory roundabout ahead",
    "Road Works Warning": "Road Works Warning",

    # --- Other Signs & Physical Properties ---
    # "Left and U-turn Sign": "a rectangular traffic sign showing both a left-turn arrow symbol and a U-turn arrow symbol",
    # "Left and Straight Sign": "a rectangular traffic sign showing both a left-turn arrow symbol and a straight-ahead arrow symbol",
    # "Multibulb (LED Grid) Light": "a traffic light where the main signal is made from a grid of many small LED bulbs",

    # --- Error States & Distractors ---
    # "Malfunctioning Light": "a malfunctioning or broken traffic light showing a distorted or unrecognizable symbol",
    # "Real Person Walking": "a photo of an actual person walking on the street, not a symbol",
    # "Real Bicycle": "a photo of a real bicycle being ridden on the road, not a symbol"
}


@functools.lru_cache(maxsize=1)
def _normalized_prompt_features():
    """Encode every prompt once and return the L2-normalised text features.

    The prompt set is a module-level constant, so the result is cached rather
    than recomputed for each uploaded image. Rows follow the iteration order
    of ``CLASSIFICATION_PROMPTS``.
    """
    text_processed = [TEXT_PROCESSORS["eval"](s) for s in CLASSIFICATION_PROMPTS.values()]
    with torch.no_grad():
        text_features = BASE_MODEL.extract_features(
            {"text_input": text_processed}, mode="text"
        ).text_embeds_proj[:, 0, :]
    return F.normalize(text_features, p=2, dim=-1)


def classify_image_from_upload(input_image: Image.Image):
    """Score an uploaded image against every prompt in ``CLASSIFICATION_PROMPTS``.

    Args:
        input_image: the PIL image from the Gradio upload widget, or ``None``
            when nothing has been uploaded.

    Returns:
        A 3-tuple ``(similarities, probabilities, text)``:
        ``similarities`` maps each label to the dot product of the normalised
        image and text features; ``probabilities`` maps each label to the
        softmax of those similarities; ``text`` lists the probabilities, one
        ``label: value`` line each, sorted in descending order. For ``None``
        input the result is ``({}, {}, "")``.
    """
    if input_image is None:
        return {}, {}, ""

    raw_image = input_image.convert("RGB")
    image_processed = VIS_PROCESSORS["eval"](raw_image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        # Only the first token of the projected image embedding is used.
        image_features = BASE_MODEL.extract_features(
            {"image": image_processed}, mode="image"
        ).image_embeds_proj[:, 0, :]
    image_norm = F.normalize(image_features, p=2, dim=-1)
    text_norm = _normalized_prompt_features()

    # squeeze(0) drops only the batch dimension, so the result stays 1-D even
    # with a single prompt (a bare squeeze() would yield a 0-d tensor that
    # cannot be iterated).
    similarities = (image_norm @ text_norm.t()).squeeze(0)
    # NOTE(review): softmax is applied to the raw similarities with no
    # temperature scaling, so the probabilities may be close to uniform.
    probabilities = torch.softmax(similarities, dim=-1)

    labels = list(CLASSIFICATION_PROMPTS)
    sim_results = dict(zip(labels, similarities.tolist()))
    prob_results_dict = dict(zip(labels, probabilities.tolist()))

    sorted_probs = sorted(prob_results_dict.items(), key=lambda item: item[1], reverse=True)
    prob_string_output = "\n".join(f"{label}: {prob:.4f}" for label, prob in sorted_probs)
    return sim_results, prob_results_dict, prob_string_output


# --- Gradio UI -------------------------------------------------------------------
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been lost; confirm that the probability row belongs inside
# the right-hand column.
with gr.Blocks(theme=gr.themes.Soft(), title="Image Classifier") as demo:
    gr.Markdown("# Cityscapes Zero-Shot Classifier")
    gr.Markdown("Upload an image to see the model's output as both charts and numerical points.")
    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image")
            classify_button = gr.Button("Classify Image", variant="primary")
        with gr.Column(scale=2):
            gr.Markdown("### Cosine Similarity (Chart)")
            similarity_output_chart = gr.Label(label="Similarity Scores", num_top_classes=len(CLASSIFICATION_PROMPTS))
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Probability - Softmax (Chart)")
                    probability_output_chart = gr.Label(label="Probabilities", num_top_classes=len(CLASSIFICATION_PROMPTS))
                with gr.Column():
                    gr.Markdown("### Probability - Softmax (Points)")
                    probability_output_points = gr.Textbox(
                        label="Numerical Probabilities",
                        lines=len(CLASSIFICATION_PROMPTS),
                        interactive=False
                    )
    classify_button.click(
        fn=classify_image_from_upload,
        inputs=image_input,
        outputs=[similarity_output_chart, probability_output_chart, probability_output_points]
    )

if __name__ == "__main__":
    # NOTE: share=True asks Gradio to create a public share link in addition
    # to binding on all interfaces.
    demo.launch(server_name="0.0.0.0", server_port=8008, share=True)