File size: 10,911 Bytes
02d3a85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import os
import sys
from PIL import Image

# Name of the folder (created beside this script) that receives all model downloads.
CACHE_DIR_NAME = "model_cache"

# Anchor the cache next to this file; when __file__ is not defined
# (e.g. an interactive session) fall back to the current working directory.
if '__file__' in locals():
    script_dir = os.path.dirname(os.path.abspath(__file__))
else:
    script_dir = os.getcwd()

local_cache_path = os.path.join(script_dir, CACHE_DIR_NAME)
os.makedirs(local_cache_path, exist_ok=True)

# Point both cache roots at the same local folder. This happens ahead of the
# torch / transformers / LAVIS imports further down the file.
os.environ.update({
    'HF_HOME': local_cache_path,
    'LAVIS_CACHE_ROOT': local_cache_path,
})

print("--- CUSTOM CACHE ACTIVATED ---")
print(f"Hugging Face and LAVIS cache redirected to: {local_cache_path}")
print("---------------------------------")

# GPU visibility / allocator settings, exported before the deep-learning imports.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0,1,2,3,4,5,6,7,8,9',
    'TF_FORCE_GPU_ALLOW_GROWTH': 'true',
})

import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import AutoModel, AutoProcessor

def _holds_lavis_package(parent_dir):
    """Return True when *parent_dir* is a directory containing a ``lavis`` sub-directory."""
    return os.path.isdir(parent_dir) and os.path.isdir(os.path.join(parent_dir, "lavis"))


# Locate the LAVIS checkout: first as a sibling "LAVIS" folder one level above
# this script, then at a fixed machine-specific location.
try:
    script_dir = os.path.dirname(__file__)
except NameError:
    # No __file__ (interactive use): resolve relative to the working directory.
    script_dir = os.getcwd()

path_to_project_root = os.path.abspath(os.path.join(script_dir, ".."))
path_to_lavis_parent_dir = os.path.join(path_to_project_root, "LAVIS")

if not _holds_lavis_package(path_to_lavis_parent_dir):
    print(f"Warning: Relative LAVIS path {path_to_lavis_parent_dir} not found or invalid. Trying historical path.")
    path_to_lavis_parent_dir = "/media/RTCIN7TBDriveB/Interns/RDT2/gte3kor/LAVIS"
    if not _holds_lavis_package(path_to_lavis_parent_dir):
        # Neither candidate is usable; nothing below can work without LAVIS.
        print(f"ERROR: Could not find a valid LAVIS package location. Last tried: {path_to_lavis_parent_dir}")
        sys.exit(1)

# Front of sys.path so this checkout wins over any other installed copy.
sys.path.insert(0, path_to_lavis_parent_dir)
print(f"INFO: Added {path_to_lavis_parent_dir} to sys.path for LAVIS.")

import torch.distributions.constraints as constraints
from transformers.modeling_utils import PreTrainedModel
import inspect
from lavis.models.blip2_models.blip2_qformer import Blip2Qformer
if hasattr(constraints, '_PositiveDefinite') and hasattr(constraints._PositiveDefinite, 'check'):
    original_positive_definite_check = constraints._PositiveDefinite.check
    def patched_positive_definite_check(self, value):
        if isinstance(value, torch.Tensor) and value.is_meta: return torch.ones_like(value, dtype=torch.bool, device='cuda:6')
        return original_positive_definite_check(self, value)
    constraints._PositiveDefinite.check = patched_positive_definite_check
    print("INFO: Patched torch.distributions.constraints._PositiveDefinite.check for meta tensors.")
if hasattr(PreTrainedModel, '_init_added_embeddings_weights_with_mean'):
    # Handle on the stock implementation; the replacement below never calls it.
    original_init_added_weights = PreTrainedModel._init_added_embeddings_weights_with_mean

    def patched_init_added_weights(self, new_embeddings_module, old_embeddings_module, num_added_tokens, *args, **kwargs):
        """Do nothing: stand-in that disables the mean-based init of added embedding rows.

        All arguments are accepted and ignored; the function always returns ``None``.
        """
        return None

    PreTrainedModel._init_added_embeddings_weights_with_mean = patched_init_added_weights
    print("INFO: Patched PreTrainedModel._init_added_embeddings_weights_with_mean for meta tensors.")
print("--- Applying robust patch for load_state_dict (Vocab Mismatch & Meta Devices) ---")
_original_torch_load_state_dict = nn.Module.load_state_dict
# Whether the installed torch accepts ``assign``; probed once at patch time
# instead of re-inspecting the signature on every load.
_load_state_dict_accepts_assign = 'assign' in inspect.signature(_original_torch_load_state_dict).parameters


def patched_load_state_dict(self, state_dict, strict=True, assign=False):
    """Replacement for ``nn.Module.load_state_dict`` with two adjustments.

    1. For ``Blip2Qformer`` instances, the vocabulary-sized prediction-head
       tensors of the checkpoint are cut down to the model's row count when
       the checkpoint has more rows than the model.
    2. If any parameter of the module is a meta tensor, ``assign=True`` is
       forced (when the installed torch supports that keyword).

    Args:
        state_dict: Mapping of parameter names to tensors. Not modified; a
            shallow copy is made when entries have to be trimmed.
        strict: Forwarded unchanged to the original implementation.
        assign: Forwarded to the original implementation when supported.

    Returns:
        Whatever the original ``load_state_dict`` returns.
    """
    if isinstance(self, Blip2Qformer):
        import copy  # needed only on this branch

        model_state_dict = self.state_dict()
        trimmed = {}
        for key in ("Qformer.cls.predictions.bias", "Qformer.cls.predictions.decoder.weight"):
            if key not in state_dict or key not in model_state_dict:
                continue
            model_vocab_size = model_state_dict[key].shape[0]
            # Only an over-long checkpoint tensor can be narrowed. A shorter one
            # is passed through so torch reports the size mismatch itself.
            if state_dict[key].shape[0] > model_vocab_size:
                trimmed[key] = state_dict[key].narrow(0, 0, model_vocab_size)
        if trimmed:
            # Shallow copy keeps the mapping type and instance attributes
            # (e.g. ``_metadata``) while leaving the caller's dict untouched.
            state_dict = copy.copy(state_dict)
            state_dict.update(trimmed)

    if any(p.is_meta for p in self.parameters()):
        assign = True

    if _load_state_dict_accepts_assign:
        return _original_torch_load_state_dict(self, state_dict, strict=strict, assign=assign)
    return _original_torch_load_state_dict(self, state_dict, strict=strict)


nn.Module.load_state_dict = patched_load_state_dict
print("INFO: Monkey-patched nn.Module.load_state_dict with robust handler.")

from lavis.models import load_model_and_preprocess

class MyModel(nn.Module):
    """Small adapter network: linear layer, ReLU, linear layer.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the intermediate layer.
        output_dim: Size of the last dimension of the output.
    """

    def __init__(self, input_dim=256, hidden_dim=128, output_dim=256):
        super().__init__()
        # Attribute names are also the state-dict key prefixes (fc1.*, fc2.*).
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

# CUDA ordinal this deployment prefers. It is only used when that many GPUs are
# actually visible; otherwise the first GPU (or the CPU) is used instead of
# failing later with an invalid-device error.
_PREFERRED_CUDA_INDEX = 6
if torch.cuda.is_available():
    _cuda_index = _PREFERRED_CUDA_INDEX if torch.cuda.device_count() > _PREFERRED_CUDA_INDEX else 0
    DEVICE = torch.device(f"cuda:{_cuda_index}")
else:
    DEVICE = torch.device("cpu")
# Relative path: resolved against the current working directory.
GLOBAL_MODEL_PATH = "global_adapter_model.pth"
print(f"Using device: {DEVICE}")

print("Loading base BLIP-2 model (gen3_322_840)...")
BASE_MODEL, VIS_PROCESSORS, TEXT_PROCESSORS = load_model_and_preprocess(
    name="blip2", model_type="gen3_322_840", is_eval=True, device=DEVICE
)
print("Base model loaded successfully.")

print("Loading fine-tuned adapter model...")
ADAPTER_MODEL = MyModel().to(DEVICE)
if os.path.exists(GLOBAL_MODEL_PATH):
    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be placed at this path (consider weights_only=True where the
    # installed torch supports it).
    ADAPTER_MODEL.load_state_dict(torch.load(GLOBAL_MODEL_PATH, map_location=DEVICE))
    print(f"Successfully loaded fine-tuned adapter from '{GLOBAL_MODEL_PATH}'.")
else:
    # The adapter keeps its randomly initialised weights in this case.
    print("WARNING: No saved adapter model found.")
ADAPTER_MODEL.eval()


# Maps each class label (shown in the UI) to the text prompt that is embedded
# and compared against the uploaded image. Dictionary order is significant:
# keys and values are zipped positionally with the model's scores.
# Entries that are commented out are currently disabled classes.
CLASSIFICATION_PROMPTS = {
    # --- Traffic Light States ---
    # "Red Light On": "a photo of a traffic signal with the red light illuminated",
    # "Yellow Light On": "a photo of a traffic signal with the yellow light glowing",
    # "Green Light On": "a photo of a traffic signal with a shining green light",
    # "All Lights Off": "a photo of a traffic signal where all the lights are off",
    
    # --- Pedestrian & Bicycle Symbols ---
    # "Pedestrian Signal": "a pedestrian crosswalk signal showing the illuminated symbol of a person, either walking or standing still.",
    # "Bicycle Symbol": "the illuminated symbol of a bicycle on a green traffic signal for bikes",

    # --- Regulatory Signs ---
    "Speed Limit Sign": "speed limit sign",
    # "End of Speed Limit": "a grey or white circular sign with black diagonal lines crossing out a number, indicating the end of a speed zone",
    # "Weight Limit Sign": "a circular traffic sign with a red border showing a weight limit in tonnes, often with an icon of a truck",
    "No Parking Sign": "No Parking Sign",
    "No Entry Sign": "No Entry Sign",
    # "End of All Prohibitions": "a white or grey circular sign with multiple black diagonal lines, indicating the end of all previous prohibitions",

    # --- Highway Signs ---
    # "Highway Start": "a rectangular blue or green sign with a white pictogram of a motorway or highway, indicating the start of a major road",
    # "Highway End": "a rectangular blue or green highway sign with a diagonal red line across it, indicating the end of the motorway",

    # --- Mandatory & Warning Signs ---
    # "Roundabout Sign": "a circular sign with three white arrows chasing each other in a circle, indicating a mandatory roundabout ahead",
    "Road Works Warning": "Road Works Warning",

    # --- Other Signs & Physical Properties ---
    # "Left and U-turn Sign": "a rectangular traffic sign showing both a left-turn arrow symbol and a U-turn arrow symbol",
    # "Left and Straight Sign": "a rectangular traffic sign showing both a left-turn arrow symbol and a straight-ahead arrow symbol",
    # "Multibulb (LED Grid) Light": "a traffic light where the main signal is made from a grid of many small LED bulbs",
    
    # --- Error States & Distractors ---
    # "Malfunctioning Light": "a malfunctioning or broken traffic light showing a distorted or unrecognizable symbol",
    # "Real Person Walking": "a photo of an actual person walking on the street, not a symbol",
    # "Real Bicycle": "a photo of a real bicycle being ridden on the road, not a symbol"
}

def classify_image_from_upload(input_image: Image.Image):
    """Score an uploaded image against every prompt in ``CLASSIFICATION_PROMPTS``.

    Args:
        input_image: PIL image supplied by the upload widget, or ``None`` when
            nothing has been uploaded.

    Returns:
        A 3-tuple ``(similarities, probabilities, report)``:
        ``similarities`` maps each label to the cosine similarity between the
        image embedding and that label's prompt embedding; ``probabilities``
        maps each label to the softmax of those similarities; ``report`` is a
        newline-separated ``"label: p"`` listing sorted by descending
        probability. With no image or no enabled prompts the result is
        ``({}, {}, "")``.
    """
    labels = list(CLASSIFICATION_PROMPTS)
    if input_image is None or not labels:
        return {}, {}, ""

    raw_image = input_image.convert("RGB")
    # unsqueeze(0) adds the leading batch axis of size one.
    image_processed = VIS_PROCESSORS["eval"](raw_image).unsqueeze(0).to(DEVICE)
    text_processed = [TEXT_PROCESSORS["eval"](CLASSIFICATION_PROMPTS[label]) for label in labels]

    with torch.no_grad():
        # Only index 0 along the second axis of the projected embeddings is kept
        # (presumably the first query / CLS token - confirm against LAVIS).
        image_features = BASE_MODEL.extract_features({"image": image_processed}, mode="image").image_embeds_proj[:, 0, :]
        text_features = BASE_MODEL.extract_features({"text_input": text_processed}, mode="text").text_embeds_proj[:, 0, :]
        image_norm = F.normalize(image_features, p=2, dim=-1)
        text_norm = F.normalize(text_features, p=2, dim=-1)
        # Row 0 is the single uploaded image. Explicit indexing (rather than
        # squeeze) keeps the result 1-D even when only one prompt is enabled.
        similarities = (image_norm @ text_norm.t())[0]
        # NOTE(review): softmax is taken over raw cosine similarities with no
        # temperature, so the probability spread between classes is limited.
        probabilities = torch.softmax(similarities, dim=-1)

    # One device-to-host transfer per tensor instead of one per element.
    sim_results = dict(zip(labels, similarities.tolist()))
    prob_results_dict = dict(zip(labels, probabilities.tolist()))

    sorted_probs = sorted(prob_results_dict.items(), key=lambda item: item[1], reverse=True)
    prob_string_output = "\n".join(f"{label}: {prob:.4f}" for label, prob in sorted_probs)

    return sim_results, prob_results_dict, prob_string_output

# Gradio UI: an upload column on the left and three result widgets on the right
# (similarity chart, probability chart, probability text listing).
with gr.Blocks(theme=gr.themes.Soft(), title="Image Classifier") as demo:
    gr.Markdown("# Cityscapes Zero-Shot Classifier")
    gr.Markdown("Upload an image to see the model's output as both charts and numerical points.")
    
    with gr.Row(variant="panel"):
        # Left column: image upload plus the button that triggers classification.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image")
            classify_button = gr.Button("Classify Image", variant="primary")
        
        # Right column: outputs. num_top_classes equals the number of enabled
        # prompts so every class is listed.
        with gr.Column(scale=2):
            gr.Markdown("### Cosine Similarity (Chart)")
            similarity_output_chart = gr.Label(label="Similarity Scores", num_top_classes=len(CLASSIFICATION_PROMPTS))
            
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Probability - Softmax (Chart)")
                    probability_output_chart = gr.Label(label="Probabilities", num_top_classes=len(CLASSIFICATION_PROMPTS))
                with gr.Column():
                    gr.Markdown("### Probability - Softmax (Points)")
                    # One text line per enabled prompt; read-only.
                    probability_output_points = gr.Textbox(
                        label="Numerical Probabilities", 
                        lines=len(CLASSIFICATION_PROMPTS), 
                        interactive=False
                    )
            
    # The three outputs receive, in order, the three values returned by
    # classify_image_from_upload.
    classify_button.click(
        fn=classify_image_from_upload, 
        inputs=image_input, 
        outputs=[similarity_output_chart, probability_output_chart, probability_output_points]
    )

if __name__ == "__main__":
    # NOTE(review): 0.0.0.0 binds every network interface and share=True
    # requests a public Gradio link - confirm this exposure is intended.
    demo.launch(server_name="0.0.0.0", server_port=8008, share=True)