Spaces:
Runtime error
Runtime error
File size: 10,911 Bytes
02d3a85 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 | import os
import sys
from PIL import Image
CACHE_DIR_NAME = "model_cache"

# Directory containing this script; falls back to the current working
# directory when __file__ is not defined (e.g. interactive sessions).
script_dir = (
    os.path.dirname(os.path.abspath(__file__))
    if '__file__' in locals()
    else os.getcwd()
)
local_cache_path = os.path.join(script_dir, CACHE_DIR_NAME)
os.makedirs(local_cache_path, exist_ok=True)

# Redirect the Hugging Face and LAVIS caches into the local directory. This is
# done before transformers / lavis are imported further down, presumably so
# they pick the location up at import time.
os.environ.update(HF_HOME=local_cache_path, LAVIS_CACHE_ROOT=local_cache_path)

print(
    "--- CUSTOM CACHE ACTIVATED ---",
    f"Hugging Face and LAVIS cache redirected to: {local_cache_path}",
    "---------------------------------",
    sep="\n",
)

# GPU visibility / allocation settings (the TF flag only matters if
# TensorFlow is loaded by some dependency).
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0,1,2,3,4,5,6,7,8,9',
    'TF_FORCE_GPU_ALLOW_GROWTH': 'true',
})
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import AutoModel, AutoProcessor
def _is_lavis_checkout(parent_dir):
    """Return True if ``parent_dir`` is a directory that contains a ``lavis`` directory."""
    return os.path.isdir(parent_dir) and os.path.isdir(os.path.join(parent_dir, "lavis"))


# Recompute the script directory, falling back to the CWD without __file__.
try:
    script_dir = os.path.dirname(__file__)
except NameError:
    script_dir = os.getcwd()

# First choice: a LAVIS checkout sitting next to this project's directory.
path_to_project_root = os.path.abspath(os.path.join(script_dir, ".."))
path_to_lavis_parent_dir = os.path.join(path_to_project_root, "LAVIS")

if not _is_lavis_checkout(path_to_lavis_parent_dir):
    print(f"Warning: Relative LAVIS path {path_to_lavis_parent_dir} not found or invalid. Trying historical path.")
    # Second choice: the absolute location used on the original machine.
    path_to_lavis_parent_dir = "/media/RTCIN7TBDriveB/Interns/RDT2/gte3kor/LAVIS"
    if not _is_lavis_checkout(path_to_lavis_parent_dir):
        print(f"ERROR: Could not find a valid LAVIS package location. Last tried: {path_to_lavis_parent_dir}")
        sys.exit(1)

# Put the checkout first on sys.path so `import lavis` resolves to it.
sys.path.insert(0, path_to_lavis_parent_dir)
print(f"INFO: Added {path_to_lavis_parent_dir} to sys.path for LAVIS.")
import torch.distributions.constraints as constraints
from transformers.modeling_utils import PreTrainedModel
import inspect
from lavis.models.blip2_models.blip2_qformer import Blip2Qformer
if hasattr(constraints, '_PositiveDefinite') and hasattr(constraints._PositiveDefinite, 'check'):
    original_positive_definite_check = constraints._PositiveDefinite.check

    def patched_positive_definite_check(self, value):
        """Positive-definite constraint check that tolerates meta tensors.

        Meta tensors carry no data, so instead of running the original check
        this reports every matrix in the batch as valid. Any other input is
        delegated to the original ``_PositiveDefinite.check``.

        Args:
            value: the tensor being validated (matrices in the last two dims).

        Returns:
            A bool tensor with one entry per matrix (shape ``value.shape[:-2]``).
        """
        if isinstance(value, torch.Tensor) and value.is_meta:
            # Build the mask on the CPU: the previous hard-coded 'cuda:6'
            # crashed on hosts without that GPU (or without CUDA at all).
            # Shape matches the real check, which yields one flag per matrix.
            return torch.ones(value.shape[:-2], dtype=torch.bool)
        return original_positive_definite_check(self, value)

    constraints._PositiveDefinite.check = patched_positive_definite_check
    print("INFO: Patched torch.distributions.constraints._PositiveDefinite.check for meta tensors.")
if hasattr(PreTrainedModel, '_init_added_embeddings_weights_with_mean'):
    # Handle to the original implementation; kept for reference, not called.
    original_init_added_weights = PreTrainedModel._init_added_embeddings_weights_with_mean

    def patched_init_added_weights(self, new_embeddings_module, old_embeddings_module,
                                   num_added_tokens, *args, **kwargs):
        """No-op replacement for ``_init_added_embeddings_weights_with_mean``.

        Skips whatever initialisation transformers would otherwise apply to
        newly added embedding rows (presumably a mean-based init, per the
        method name) so it is never attempted on meta tensors.
        """
        return None

    PreTrainedModel._init_added_embeddings_weights_with_mean = patched_init_added_weights
    print("INFO: Patched PreTrainedModel._init_added_embeddings_weights_with_mean for meta tensors.")
print("--- Applying robust patch for load_state_dict (Vocab Mismatch & Meta Devices) ---")
_original_torch_load_state_dict = nn.Module.load_state_dict
# Whether this torch version's load_state_dict accepts ``assign``; checked once
# here instead of re-inspecting the signature on every load.
_LOAD_STATE_DICT_SUPPORTS_ASSIGN = (
    'assign' in inspect.signature(_original_torch_load_state_dict).parameters
)
# Q-Former checkpoint entries whose first dimension is the vocabulary size.
_QFORMER_VOCAB_KEYS = ("Qformer.cls.predictions.bias", "Qformer.cls.predictions.decoder.weight")


def patched_load_state_dict(self, state_dict, strict=True, assign=False):
    """Drop-in ``nn.Module.load_state_dict`` with two workarounds.

    1. For ``Blip2Qformer``, vocab-sized checkpoint tensors that have MORE rows
       than the model's are trimmed to the model's vocab size. Note that this
       writes the trimmed tensors back into ``state_dict`` (the caller's dict).
    2. If any of the module's parameters are on the meta device, ``assign`` is
       forced to True.

    Args:
        state_dict: mapping of parameter/buffer names to tensors.
        strict: forwarded to the original ``load_state_dict``.
        assign: forwarded when the installed torch supports it.

    Returns:
        Whatever the original ``nn.Module.load_state_dict`` returns.
    """
    if isinstance(self, Blip2Qformer):
        model_state_dict = self.state_dict()
        for key in _QFORMER_VOCAB_KEYS:
            if key not in state_dict or key not in model_state_dict:
                continue
            model_vocab_size = model_state_dict[key].shape[0]
            # Only a larger checkpoint vocab can be trimmed: narrow() raises when
            # asked for more rows than exist, so a smaller checkpoint is left
            # for the original loader to report as a size mismatch.
            if state_dict[key].shape[0] > model_vocab_size:
                state_dict[key] = state_dict[key].narrow(0, 0, model_vocab_size)
    if any(p.is_meta for p in self.parameters()):
        assign = True
    if _LOAD_STATE_DICT_SUPPORTS_ASSIGN:
        return _original_torch_load_state_dict(self, state_dict, strict=strict, assign=assign)
    return _original_torch_load_state_dict(self, state_dict, strict=strict)


nn.Module.load_state_dict = patched_load_state_dict
print("INFO: Monkey-patched nn.Module.load_state_dict with robust handler.")
from lavis.models import load_model_and_preprocess
class MyModel(nn.Module):
    """Two-layer MLP adapter: ``fc1`` -> ReLU -> ``fc2``.

    The attribute names ``fc1`` / ``relu`` / ``fc2`` define the state_dict keys
    used when a saved adapter checkpoint is loaded, so keep them stable.
    """

    def __init__(self, input_dim=256, hidden_dim=128, output_dim=256):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map ``x`` (last dim ``input_dim``) to a tensor with last dim ``output_dim``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
def _select_device(preferred_cuda_index=6):
    """Pick the torch device used for inference.

    Returns ``cuda:<preferred_cuda_index>`` when CUDA is available and enough
    GPUs are visible, ``cuda:0`` when CUDA is available but fewer GPUs are
    visible (previously this crashed), and the CPU otherwise.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    if torch.cuda.device_count() > preferred_cuda_index:
        return torch.device(f"cuda:{preferred_cuda_index}")
    return torch.device("cuda:0")


DEVICE = _select_device()
GLOBAL_MODEL_PATH = "global_adapter_model.pth"  # resolved relative to the CWD
print(f"Using device: {DEVICE}")

print("Loading base BLIP-2 model (gen3_322_840)...")
BASE_MODEL, VIS_PROCESSORS, TEXT_PROCESSORS = load_model_and_preprocess(
    name="blip2", model_type="gen3_322_840", is_eval=True, device=DEVICE
)
print("Base model loaded successfully.")

print("Loading fine-tuned adapter model...")
ADAPTER_MODEL = MyModel().to(DEVICE)
if os.path.exists(GLOBAL_MODEL_PATH):
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ADAPTER_MODEL.load_state_dict(torch.load(GLOBAL_MODEL_PATH, map_location=DEVICE))
    print(f"Successfully loaded fine-tuned adapter from '{GLOBAL_MODEL_PATH}'.")
else:
    print(f"WARNING: No saved adapter model found at '{GLOBAL_MODEL_PATH}'.")
ADAPTER_MODEL.eval()
# Zero-shot label set. Each key is the label reported in the result dicts
# returned by classify_image_from_upload (and shown in the UI charts); each
# value is the text prompt passed through TEXT_PROCESSORS and encoded by
# BASE_MODEL's text branch.
# Insertion order matters: the scores are computed in values() order and are
# paired back to labels by zipping with keys().
# Comment entries in or out to change the active label set.
# NOTE(review): the active entries use the bare label name as the prompt,
# while the commented-out entries use longer descriptive prompts.
CLASSIFICATION_PROMPTS = {
    # --- Traffic Light States ---
    # "Red Light On": "a photo of a traffic signal with the red light illuminated",
    # "Yellow Light On": "a photo of a traffic signal with the yellow light glowing",
    # "Green Light On": "a photo of a traffic signal with a shining green light",
    # "All Lights Off": "a photo of a traffic signal where all the lights are off",
    # --- Pedestrian & Bicycle Symbols ---
    # "Pedestrian Signal": "a pedestrian crosswalk signal showing the illuminated symbol of a person, either walking or standing still.",
    # "Bicycle Symbol": "the illuminated symbol of a bicycle on a green traffic signal for bikes",
    # --- Regulatory Signs ---
    "Speed Limit Sign": "speed limit sign",
    # "End of Speed Limit": "a grey or white circular sign with black diagonal lines crossing out a number, indicating the end of a speed zone",
    # "Weight Limit Sign": "a circular traffic sign with a red border showing a weight limit in tonnes, often with an icon of a truck",
    "No Parking Sign": "No Parking Sign",
    "No Entry Sign": "No Entry Sign",
    # "End of All Prohibitions": "a white or grey circular sign with multiple black diagonal lines, indicating the end of all previous prohibitions",
    # --- Highway Signs ---
    # "Highway Start": "a rectangular blue or green sign with a white pictogram of a motorway or highway, indicating the start of a major road",
    # "Highway End": "a rectangular blue or green highway sign with a diagonal red line across it, indicating the end of the motorway",
    # --- Mandatory & Warning Signs ---
    # "Roundabout Sign": "a circular sign with three white arrows chasing each other in a circle, indicating a mandatory roundabout ahead",
    "Road Works Warning": "Road Works Warning",
    # --- Other Signs & Physical Properties ---
    # "Left and U-turn Sign": "a rectangular traffic sign showing both a left-turn arrow symbol and a U-turn arrow symbol",
    # "Left and Straight Sign": "a rectangular traffic sign showing both a left-turn arrow symbol and a straight-ahead arrow symbol",
    # "Multibulb (LED Grid) Light": "a traffic light where the main signal is made from a grid of many small LED bulbs",
    # --- Error States & Distractors ---
    # "Malfunctioning Light": "a malfunctioning or broken traffic light showing a distorted or unrecognizable symbol",
    # "Real Person Walking": "a photo of an actual person walking on the street, not a symbol",
    # "Real Bicycle": "a photo of a real bicycle being ridden on the road, not a symbol"
}
def classify_image_from_upload(input_image: "Image.Image"):
    """Zero-shot classify an uploaded image against CLASSIFICATION_PROMPTS.

    Args:
        input_image: PIL image from the Gradio upload widget, or None when
            nothing has been uploaded.

    Returns:
        ``(sim_results, prob_results_dict, prob_string_output)``:
        label -> cosine similarity, label -> softmax probability, and a
        newline-separated ``"label: prob"`` string sorted by descending
        probability. Returns ``({}, {}, "")`` when ``input_image`` is None.
    """
    if input_image is None:
        return {}, {}, ""

    labels = list(CLASSIFICATION_PROMPTS.keys())
    text_prompt_list = list(CLASSIFICATION_PROMPTS.values())

    raw_image = input_image.convert("RGB")
    # unsqueeze(0) adds a batch dimension of 1 to the processed image.
    image_processed = VIS_PROCESSORS["eval"](raw_image).unsqueeze(0).to(DEVICE)
    text_processed = [TEXT_PROCESSORS["eval"](prompt) for prompt in text_prompt_list]

    with torch.no_grad():
        # [:, 0, :] keeps only the first projected token of each input.
        image_features = BASE_MODEL.extract_features(
            {"image": image_processed}, mode="image"
        ).image_embeds_proj[:, 0, :]
        text_features = BASE_MODEL.extract_features(
            {"text_input": text_processed}, mode="text"
        ).text_embeds_proj[:, 0, :]
        image_norm = F.normalize(image_features, p=2, dim=-1)
        text_norm = F.normalize(text_features, p=2, dim=-1)
        # (1, D) @ (D, P) -> (1, P), assuming one text feature row per prompt.
        # Take row 0 rather than .squeeze(): with a single active prompt,
        # squeeze() collapses to a 0-d tensor that zip() cannot iterate.
        similarities = (image_norm @ text_norm.t())[0]
        # NOTE(review): no temperature is applied, so the logits are cosine
        # similarities bounded to [-1, 1] and the probabilities cannot be
        # strongly peaked.
        probabilities = torch.softmax(similarities, dim=-1)

    # One host transfer per tensor instead of one .item() call per label.
    sim_results = dict(zip(labels, similarities.tolist()))
    prob_results_dict = dict(zip(labels, probabilities.tolist()))
    sorted_probs = sorted(prob_results_dict.items(), key=lambda item: item[1], reverse=True)
    prob_string_output = "\n".join(f"{label}: {prob:.4f}" for label, prob in sorted_probs)
    return sim_results, prob_results_dict, prob_string_output
# Number of active labels; sizes both charts and the numeric text box.
_NUM_LABELS = len(CLASSIFICATION_PROMPTS)

with gr.Blocks(theme=gr.themes.Soft(), title="Image Classifier") as demo:
    gr.Markdown("# Cityscapes Zero-Shot Classifier")
    gr.Markdown("Upload an image to see the model's output as both charts and numerical points.")

    # Row 1: upload widget and trigger button, next to the similarity chart.
    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image")
            classify_button = gr.Button("Classify Image", variant="primary")
        with gr.Column(scale=2):
            gr.Markdown("### Cosine Similarity (Chart)")
            similarity_output_chart = gr.Label(label="Similarity Scores", num_top_classes=_NUM_LABELS)

    # Row 2: the softmax probabilities, once as a chart and once as text.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Probability - Softmax (Chart)")
            probability_output_chart = gr.Label(label="Probabilities", num_top_classes=_NUM_LABELS)
        with gr.Column():
            gr.Markdown("### Probability - Softmax (Points)")
            probability_output_points = gr.Textbox(
                label="Numerical Probabilities", lines=_NUM_LABELS, interactive=False
            )

    # Output components are listed in the same order as the tuple returned
    # by classify_image_from_upload.
    classify_button.click(
        fn=classify_image_from_upload,
        inputs=image_input,
        outputs=[similarity_output_chart, probability_output_chart, probability_output_points],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=8008, share=True)