Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
| from PIL import Image | |
# Redirect the Hugging Face and LAVIS caches to a folder next to this script.
# The environment variables are assigned here, ahead of the torch /
# transformers / lavis imports that follow further down the file.
CACHE_DIR_NAME = "model_cache"

if '__file__' in locals():
    script_dir = os.path.dirname(os.path.abspath(__file__))
else:
    # No __file__ (e.g. interactive session): fall back to the working directory.
    script_dir = os.getcwd()

local_cache_path = os.path.join(script_dir, CACHE_DIR_NAME)
os.makedirs(local_cache_path, exist_ok=True)
os.environ.update({
    'HF_HOME': local_cache_path,
    'LAVIS_CACHE_ROOT': local_cache_path,
})

print("--- CUSTOM CACHE ACTIVATED ---")
print(f"Hugging Face and LAVIS cache redirected to: {local_cache_path}")
print("---------------------------------")

# GPU visibility / TensorFlow memory-growth settings.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7,8,9'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| from transformers import AutoModel, AutoProcessor | |
def _contains_lavis_package(candidate_dir):
    """Return True when *candidate_dir* is a directory holding a ``lavis`` sub-directory."""
    return os.path.isdir(candidate_dir) and os.path.isdir(os.path.join(candidate_dir, "lavis"))


# Locate the LAVIS checkout: first as a sibling of this script's parent
# directory, then at a fixed fallback location; abort if neither exists.
try:
    script_dir = os.path.dirname(__file__)
except NameError:
    script_dir = os.getcwd()

path_to_project_root = os.path.abspath(os.path.join(script_dir, ".."))
path_to_lavis_parent_dir = os.path.join(path_to_project_root, "LAVIS")

if not _contains_lavis_package(path_to_lavis_parent_dir):
    print(f"Warning: Relative LAVIS path {path_to_lavis_parent_dir} not found or invalid. Trying historical path.")
    path_to_lavis_parent_dir = "/media/RTCIN7TBDriveB/Interns/RDT2/gte3kor/LAVIS"
    if not _contains_lavis_package(path_to_lavis_parent_dir):
        print(f"ERROR: Could not find a valid LAVIS package location. Last tried: {path_to_lavis_parent_dir}")
        sys.exit(1)

# Put the checkout first on the import path so `import lavis` resolves to it.
sys.path.insert(0, path_to_lavis_parent_dir)
print(f"INFO: Added {path_to_lavis_parent_dir} to sys.path for LAVIS.")
| import torch.distributions.constraints as constraints | |
| from transformers.modeling_utils import PreTrainedModel | |
| import inspect | |
| from lavis.models.blip2_models.blip2_qformer import Blip2Qformer | |
# Bypass torch's positive-definite constraint check for meta tensors
# (presumably because the stock check needs real tensor data - confirm
# against the torch version in use). Real tensors still go through the
# original implementation.
if hasattr(constraints, '_PositiveDefinite') and hasattr(constraints._PositiveDefinite, 'check'):
    original_positive_definite_check = constraints._PositiveDefinite.check

    def patched_positive_definite_check(self, value):
        """Treat meta tensors as valid; delegate everything else to the original check.

        Args:
            value: Object being validated by the constraint.

        Returns:
            An all-True boolean tensor shaped like ``value`` when ``value`` is
            a meta tensor, otherwise whatever the original ``check`` returns.
        """
        if isinstance(value, torch.Tensor) and value.is_meta:
            # The mask must live on a real device. CPU always exists, whereas a
            # fixed CUDA index (previously 'cuda:6') fails on hosts with fewer
            # GPUs or no CUDA at all.
            return torch.ones_like(value, dtype=torch.bool, device='cpu')
        return original_positive_definite_check(self, value)

    constraints._PositiveDefinite.check = patched_positive_definite_check
    print("INFO: Patched torch.distributions.constraints._PositiveDefinite.check for meta tensors.")
# Disable the mean-based initialisation of newly added embedding rows when the
# installed transformers version provides that hook.
if hasattr(PreTrainedModel, '_init_added_embeddings_weights_with_mean'):
    # Reference to the stock implementation; nothing in the visible code reads it.
    original_init_added_weights = PreTrainedModel._init_added_embeddings_weights_with_mean

    def patched_init_added_weights(
        self,
        new_embeddings_module,
        old_embeddings_module,
        num_added_tokens,
        *args,
        **kwargs,
    ):
        """Do nothing: all arguments are accepted and ignored, and None is returned."""
        return None

    PreTrainedModel._init_added_embeddings_weights_with_mean = patched_init_added_weights
    print("INFO: Patched PreTrainedModel._init_added_embeddings_weights_with_mean for meta tensors.")
print("--- Applying robust patch for load_state_dict (Vocab Mismatch & Meta Devices) ---")
_original_torch_load_state_dict = nn.Module.load_state_dict

# Q-Former output-head entries whose leading (vocabulary) dimension may be
# larger in the checkpoint than in the instantiated model.
_QFORMER_VOCAB_KEYS = (
    "Qformer.cls.predictions.bias",
    "Qformer.cls.predictions.decoder.weight",
)

# The wrapped method may or may not accept ``assign``; probe once at patch
# time instead of inspecting the signature on every call.
_LOAD_STATE_DICT_ACCEPTS_ASSIGN = 'assign' in inspect.signature(_original_torch_load_state_dict).parameters


def patched_load_state_dict(self, state_dict, strict=True, assign=False):
    """Replacement for ``nn.Module.load_state_dict`` with two accommodations.

    1. For ``Blip2Qformer`` instances, the vocabulary-sized output-head
       tensors in ``state_dict`` are trimmed to the model's vocabulary size
       when the checkpoint holds more rows. The trimmed tensors are written
       back into ``state_dict`` itself, so the caller's mapping is modified
       in place.
    2. If any parameter of the module is a meta tensor, ``assign`` is forced
       to True.

    Args:
        self: Module receiving the weights.
        state_dict: Mapping of parameter names to tensors.
        strict: Forwarded unchanged to the original implementation.
        assign: Forwarded to the original implementation when it supports the
            keyword; overridden to True for modules with meta parameters.

    Returns:
        Whatever the original ``load_state_dict`` returns.
    """
    if isinstance(self, Blip2Qformer):
        model_state_dict = self.state_dict()
        for key in _QFORMER_VOCAB_KEYS:
            if key not in state_dict or key not in model_state_dict:
                continue
            model_vocab_size = model_state_dict[key].shape[0]
            # Trim only when the checkpoint has extra rows. A smaller
            # checkpoint cannot be narrowed and is left for the original
            # implementation to report as a size mismatch.
            if state_dict[key].shape[0] > model_vocab_size:
                state_dict[key] = state_dict[key].narrow(0, 0, model_vocab_size)

    # NOTE(review): the source's indentation was lost; this check is assumed
    # to apply to every module, not only Blip2Qformer - confirm.
    if any(p.is_meta for p in self.parameters()):
        assign = True

    if _LOAD_STATE_DICT_ACCEPTS_ASSIGN:
        return _original_torch_load_state_dict(self, state_dict, strict=strict, assign=assign)
    return _original_torch_load_state_dict(self, state_dict, strict=strict)


nn.Module.load_state_dict = patched_load_state_dict
print("INFO: Monkey-patched nn.Module.load_state_dict with robust handler.")
| from lavis.models import load_model_and_preprocess | |
class MyModel(nn.Module):
    """Small feed-forward adapter: Linear -> ReLU -> Linear."""

    def __init__(self, input_dim=256, hidden_dim=128, output_dim=256):
        """Create the two linear layers and the activation between them.

        Args:
            input_dim: Size of the last dimension of the input.
            hidden_dim: Width of the intermediate layer.
            output_dim: Size of the last dimension of the output.
        """
        super().__init__()
        # Attribute names define the state-dict keys, so they are kept as-is.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Apply fc1, ReLU and fc2 in sequence and return the result."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
def _select_device(preferred_cuda_index=6):
    """Choose the torch device used for inference.

    Args:
        preferred_cuda_index: CUDA device index to use when that many GPUs
            are visible to the process.

    Returns:
        ``cuda:<preferred_cuda_index>`` when it exists, otherwise ``cuda:0``
        when any CUDA device is visible, otherwise ``cpu``.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    # `is_available()` only says that at least one GPU exists; requesting an
    # index beyond the visible device count fails at first use.
    if torch.cuda.device_count() > preferred_cuda_index:
        return torch.device(f"cuda:{preferred_cuda_index}")
    return torch.device("cuda:0")


DEVICE = _select_device()
GLOBAL_MODEL_PATH = "global_adapter_model.pth"
print(f"Using device: {DEVICE}")

print("Loading base BLIP-2 model (gen3_322_840)...")
BASE_MODEL, VIS_PROCESSORS, TEXT_PROCESSORS = load_model_and_preprocess(
    name="blip2", model_type="gen3_322_840", is_eval=True, device=DEVICE
)
print("Base model loaded successfully.")

print("Loading fine-tuned adapter model...")
# NOTE(review): ADAPTER_MODEL is loaded here, but nothing in the visible code
# runs it during classification - confirm whether it is still needed.
ADAPTER_MODEL = MyModel().to(DEVICE)
if os.path.exists(GLOBAL_MODEL_PATH):
    # SECURITY: torch.load unpickles the file; only load checkpoints from a
    # trusted source (consider weights_only=True where the installed torch
    # version supports it).
    ADAPTER_MODEL.load_state_dict(torch.load(GLOBAL_MODEL_PATH, map_location=DEVICE))
    print(f"Successfully loaded fine-tuned adapter from '{GLOBAL_MODEL_PATH}'.")
else:
    print("WARNING: No saved adapter model found.")
ADAPTER_MODEL.eval()
# Zero-shot label set: display label -> text prompt passed to the text encoder.
# Scores are paired with labels by position, so insertion order is significant.
# Commented-out entries are candidate classes that are currently disabled.
CLASSIFICATION_PROMPTS = {
    # --- Traffic Light States ---
    # "Red Light On": "a photo of a traffic signal with the red light illuminated",
    # "Yellow Light On": "a photo of a traffic signal with the yellow light glowing",
    # "Green Light On": "a photo of a traffic signal with a shining green light",
    # "All Lights Off": "a photo of a traffic signal where all the lights are off",

    # --- Pedestrian & Bicycle Symbols ---
    # "Pedestrian Signal": "a pedestrian crosswalk signal showing the illuminated symbol of a person, either walking or standing still.",
    # "Bicycle Symbol": "the illuminated symbol of a bicycle on a green traffic signal for bikes",

    # --- Regulatory Signs ---
    "Speed Limit Sign": "speed limit sign",
    # "End of Speed Limit": "a grey or white circular sign with black diagonal lines crossing out a number, indicating the end of a speed zone",
    # "Weight Limit Sign": "a circular traffic sign with a red border showing a weight limit in tonnes, often with an icon of a truck",
    "No Parking Sign": "No Parking Sign",
    "No Entry Sign": "No Entry Sign",
    # "End of All Prohibitions": "a white or grey circular sign with multiple black diagonal lines, indicating the end of all previous prohibitions",

    # --- Highway Signs ---
    # "Highway Start": "a rectangular blue or green sign with a white pictogram of a motorway or highway, indicating the start of a major road",
    # "Highway End": "a rectangular blue or green highway sign with a diagonal red line across it, indicating the end of the motorway",

    # --- Mandatory & Warning Signs ---
    # "Roundabout Sign": "a circular sign with three white arrows chasing each other in a circle, indicating a mandatory roundabout ahead",
    "Road Works Warning": "Road Works Warning",

    # --- Other Signs & Physical Properties ---
    # "Left and U-turn Sign": "a rectangular traffic sign showing both a left-turn arrow symbol and a U-turn arrow symbol",
    # "Left and Straight Sign": "a rectangular traffic sign showing both a left-turn arrow symbol and a straight-ahead arrow symbol",
    # "Multibulb (LED Grid) Light": "a traffic light where the main signal is made from a grid of many small LED bulbs",

    # --- Error States & Distractors ---
    # "Malfunctioning Light": "a malfunctioning or broken traffic light showing a distorted or unrecognizable symbol",
    # "Real Person Walking": "a photo of an actual person walking on the street, not a symbol",
    # "Real Bicycle": "a photo of a real bicycle being ridden on the road, not a symbol"
}
def classify_image_from_upload(input_image: Image.Image):
    """Score an uploaded image against every prompt in ``CLASSIFICATION_PROMPTS``.

    Args:
        input_image: PIL image supplied by the Gradio upload widget, or None
            when nothing has been uploaded.

    Returns:
        A 3-tuple ``(similarities, probabilities, text)``:
          * ``similarities`` maps each label to the cosine similarity between
            the image embedding and that label's prompt embedding;
          * ``probabilities`` maps each label to the softmax of those
            similarities;
          * ``text`` lists ``"label: probability"`` lines, highest first,
            formatted to four decimals.
        ``({}, {}, "")`` is returned when there is no image or no prompt.
    """
    if input_image is None or not CLASSIFICATION_PROMPTS:
        return {}, {}, ""

    labels = list(CLASSIFICATION_PROMPTS)
    raw_image = input_image.convert("RGB")
    image_processed = VIS_PROCESSORS["eval"](raw_image).unsqueeze(0).to(DEVICE)
    text_processed = [TEXT_PROCESSORS["eval"](CLASSIFICATION_PROMPTS[label]) for label in labels]

    with torch.no_grad():
        # Keep only the first projected token of each embedding sequence.
        image_features = BASE_MODEL.extract_features(
            {"image": image_processed}, mode="image"
        ).image_embeds_proj[:, 0, :]
        text_features = BASE_MODEL.extract_features(
            {"text_input": text_processed}, mode="text"
        ).text_embeds_proj[:, 0, :]

        image_norm = F.normalize(image_features, p=2, dim=-1)
        text_norm = F.normalize(text_features, p=2, dim=-1)

        # One image row against N prompts. Drop only the image axis: a bare
        # squeeze() would also collapse the prompt axis when N == 1, leaving
        # a 0-d tensor that cannot be zipped with the labels.
        similarities = (image_norm @ text_norm.t()).squeeze(0)
        # NOTE(review): softmax is applied to raw cosine similarities with no
        # temperature scaling - confirm this is intended.
        probabilities = torch.softmax(similarities, dim=-1)

    sim_results = dict(zip(labels, similarities.tolist()))
    prob_results_dict = dict(zip(labels, probabilities.tolist()))

    sorted_probs = sorted(prob_results_dict.items(), key=lambda item: item[1], reverse=True)
    prob_string_output = "\n".join(f"{label}: {prob:.4f}" for label, prob in sorted_probs)
    return sim_results, prob_results_dict, prob_string_output
# Gradio front end: an upload panel on the left, score displays on the right.
# NOTE(review): the source's indentation was lost, so the nesting of the rows
# and columns below is a reconstruction - confirm against the intended layout.
with gr.Blocks(theme=gr.themes.Soft(), title="Image Classifier") as demo:
    gr.Markdown("# Cityscapes Zero-Shot Classifier")
    gr.Markdown("Upload an image to see the model's output as both charts and numerical points.")

    with gr.Row(variant="panel"):
        # Left: image upload and trigger button.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image")
            classify_button = gr.Button("Classify Image", variant="primary")

        # Right: similarity chart on top, probability chart and text below.
        with gr.Column(scale=2):
            gr.Markdown("### Cosine Similarity (Chart)")
            similarity_output_chart = gr.Label(
                label="Similarity Scores",
                num_top_classes=len(CLASSIFICATION_PROMPTS),
            )

            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Probability - Softmax (Chart)")
                    probability_output_chart = gr.Label(
                        label="Probabilities",
                        num_top_classes=len(CLASSIFICATION_PROMPTS),
                    )
                with gr.Column():
                    gr.Markdown("### Probability - Softmax (Points)")
                    probability_output_points = gr.Textbox(
                        label="Numerical Probabilities",
                        lines=len(CLASSIFICATION_PROMPTS),
                        interactive=False,
                    )

    # The three return values of the classifier feed the three output widgets in order.
    classify_button.click(
        fn=classify_image_from_upload,
        inputs=image_input,
        outputs=[
            similarity_output_chart,
            probability_output_chart,
            probability_output_points,
        ],
    )

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=8008, share=True)