# event_retrieval/zero_shot_classification.py
# Uploaded by sanskar753 using huggingface_hub (commit 02d3a85, verified).
import os
import sys
from PIL import Image
# --- Cache and GPU environment setup ---------------------------------------
# These environment variables are written before the torch / transformers /
# LAVIS imports further down in this file.
CACHE_DIR_NAME = "model_cache"

# Anchor the cache next to this script; fall back to the working directory
# when ``__file__`` is not defined (e.g. interactive sessions).
if "__file__" in globals():
    script_dir = os.path.dirname(os.path.abspath(__file__))
else:
    script_dir = os.getcwd()

local_cache_path = os.path.join(script_dir, CACHE_DIR_NAME)
os.makedirs(local_cache_path, exist_ok=True)

# Point both the Hugging Face and the LAVIS caches at the same local folder.
os.environ.update(
    HF_HOME=local_cache_path,
    LAVIS_CACHE_ROOT=local_cache_path,
)
print("--- CUSTOM CACHE ACTIVATED ---")
print(f"Hugging Face and LAVIS cache redirected to: {local_cache_path}")
print("---------------------------------")

# GPU visibility and TensorFlow memory-growth settings.
os.environ.update(
    CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7,8,9",
    TF_FORCE_GPU_ALLOW_GROWTH="true",
)
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import AutoModel, AutoProcessor
def _contains_lavis_package(candidate_dir):
    """Return True when *candidate_dir* is a directory holding a ``lavis`` sub-directory."""
    return os.path.isdir(candidate_dir) and os.path.isdir(os.path.join(candidate_dir, "lavis"))


# --- Locate the LAVIS checkout and make it importable -----------------------
try:
    script_dir = os.path.dirname(__file__)
except NameError:
    # ``__file__`` is undefined in some interactive contexts.
    script_dir = os.getcwd()

# First choice: a ``LAVIS`` folder that sits beside this script's parent directory.
path_to_project_root = os.path.abspath(os.path.join(script_dir, ".."))
path_to_lavis_parent_dir = os.path.join(path_to_project_root, "LAVIS")

if not _contains_lavis_package(path_to_lavis_parent_dir):
    print(f"Warning: Relative LAVIS path {path_to_lavis_parent_dir} not found or invalid. Trying historical path.")
    # Second choice: the absolute location used on the original machine.
    path_to_lavis_parent_dir = "/media/RTCIN7TBDriveB/Interns/RDT2/gte3kor/LAVIS"
    if not _contains_lavis_package(path_to_lavis_parent_dir):
        print(f"ERROR: Could not find a valid LAVIS package location. Last tried: {path_to_lavis_parent_dir}")
        sys.exit(1)

# Put the checkout first on sys.path so ``import lavis`` resolves to it.
sys.path.insert(0, path_to_lavis_parent_dir)
print(f"INFO: Added {path_to_lavis_parent_dir} to sys.path for LAVIS.")
import torch.distributions.constraints as constraints
from transformers.modeling_utils import PreTrainedModel
import inspect
from lavis.models.blip2_models.blip2_qformer import Blip2Qformer
# --- Patch: make the positive-definite constraint tolerate meta tensors -----
if hasattr(constraints, '_PositiveDefinite') and hasattr(constraints._PositiveDefinite, 'check'):
    original_positive_definite_check = constraints._PositiveDefinite.check

    def patched_positive_definite_check(self, value):
        """Constraint check that short-circuits for meta tensors.

        Meta tensors carry no data, so instead of running the real check an
        all-True boolean tensor shaped like ``value`` is returned. Every other
        input is forwarded to the original ``check`` unchanged.
        """
        if isinstance(value, torch.Tensor) and value.is_meta:
            # The original always materialised the result on 'cuda:6', which
            # raises on hosts with fewer than seven visible GPUs (including
            # CPU-only hosts). Keep that device when it exists, otherwise use
            # the CPU, which is always available.
            if torch.cuda.is_available() and torch.cuda.device_count() > 6:
                result_device = 'cuda:6'
            else:
                result_device = 'cpu'
            return torch.ones_like(value, dtype=torch.bool, device=result_device)
        return original_positive_definite_check(self, value)

    constraints._PositiveDefinite.check = patched_positive_definite_check
    print("INFO: Patched torch.distributions.constraints._PositiveDefinite.check for meta tensors.")
# --- Patch: disable PreTrainedModel._init_added_embeddings_weights_with_mean -
_MEAN_INIT_HOOK_NAME = '_init_added_embeddings_weights_with_mean'
if hasattr(PreTrainedModel, _MEAN_INIT_HOOK_NAME):
    # Keep a handle on the original implementation before replacing it.
    original_init_added_weights = getattr(PreTrainedModel, _MEAN_INIT_HOOK_NAME)

    def patched_init_added_weights(self, new_embeddings_module, old_embeddings_module, num_added_tokens, *args, **kwargs):
        """No-op stand-in: accepts the hook's arguments, does nothing, returns None."""
        return None

    setattr(PreTrainedModel, _MEAN_INIT_HOOK_NAME, patched_init_added_weights)
    print("INFO: Patched PreTrainedModel._init_added_embeddings_weights_with_mean for meta tensors.")
print("--- Applying robust patch for load_state_dict (Vocab Mismatch & Meta Devices) ---")
_original_torch_load_state_dict = nn.Module.load_state_dict

# Whether the installed PyTorch's ``load_state_dict`` accepts ``assign``.
# Probed once here instead of on every call of the patched method.
_LOAD_STATE_DICT_ACCEPTS_ASSIGN = 'assign' in inspect.signature(_original_torch_load_state_dict).parameters

# Checkpoint entries whose first dimension is compared against the model's.
_QFORMER_VOCAB_SIZED_KEYS = ("Qformer.cls.predictions.bias", "Qformer.cls.predictions.decoder.weight")


def patched_load_state_dict(self, state_dict, strict=True, assign=False):
    """Replacement for ``nn.Module.load_state_dict`` with two extra behaviours.

    * For ``Blip2Qformer`` instances, the two prediction-head entries are
      trimmed along dim 0 when the checkpoint tensor has more rows than the
      model's tensor, so the shapes match.
    * When any parameter of the module is a meta tensor, ``assign`` is forced
      to ``True``.

    Args:
        self: The module being loaded into.
        state_dict: Mapping of parameter names to tensors. It is not modified;
            trimming happens on a shallow copy.
        strict: Forwarded to the original ``load_state_dict``.
        assign: Forwarded when the installed PyTorch supports it.

    Returns:
        Whatever the original ``load_state_dict`` returns.
    """
    if isinstance(self, Blip2Qformer):
        model_state_dict = self.state_dict()
        trimmed_state_dict = None
        for key in _QFORMER_VOCAB_SIZED_KEYS:
            if key not in state_dict or key not in model_state_dict:
                continue
            ckpt_tensor = state_dict[key]
            model_vocab_size = model_state_dict[key].shape[0]
            # Only a larger checkpoint can be narrowed. A smaller one is left
            # untouched so the original loader reports the size mismatch
            # itself instead of ``narrow`` failing with an unrelated message.
            if ckpt_tensor.shape[0] > model_vocab_size:
                if trimmed_state_dict is None:
                    # Copy lazily so the caller's mapping is never mutated.
                    trimmed_state_dict = state_dict.copy()
                    # ``.copy()`` drops instance attributes; carry over the
                    # ``_metadata`` attribute when the source mapping has one.
                    metadata = getattr(state_dict, "_metadata", None)
                    if metadata is not None:
                        trimmed_state_dict._metadata = metadata
                trimmed_state_dict[key] = ckpt_tensor.narrow(0, 0, model_vocab_size)
        if trimmed_state_dict is not None:
            state_dict = trimmed_state_dict

    if any(p.is_meta for p in self.parameters()):
        assign = True

    if _LOAD_STATE_DICT_ACCEPTS_ASSIGN:
        return _original_torch_load_state_dict(self, state_dict, strict=strict, assign=assign)
    return _original_torch_load_state_dict(self, state_dict, strict=strict)


nn.Module.load_state_dict = patched_load_state_dict
print("INFO: Monkey-patched nn.Module.load_state_dict with robust handler.")
from lavis.models import load_model_and_preprocess
class MyModel(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the hidden layer.
        output_dim: Size of the last dimension of the output.
    """

    def __init__(self, input_dim=256, hidden_dim=128, output_dim=256):
        super().__init__()
        # Attribute names define the state-dict keys (fc1.*, fc2.*).
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Apply fc1, ReLU and fc2 in sequence and return the result."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
# --- Device selection and model loading -------------------------------------
# GPU index this script prefers. The original picked "cuda:6" whenever any
# CUDA device existed, which is an invalid ordinal on hosts with fewer than
# seven visible GPUs; fall back to the first GPU, then to the CPU.
_PREFERRED_CUDA_INDEX = 6
if not torch.cuda.is_available():
    DEVICE = torch.device("cpu")
elif torch.cuda.device_count() > _PREFERRED_CUDA_INDEX:
    DEVICE = torch.device(f"cuda:{_PREFERRED_CUDA_INDEX}")
else:
    DEVICE = torch.device("cuda:0")

# NOTE: relative path, so it is resolved against the current working directory.
GLOBAL_MODEL_PATH = "global_adapter_model.pth"
print(f"Using device: {DEVICE}")

print("Loading base BLIP-2 model (gen3_322_840)...")
BASE_MODEL, VIS_PROCESSORS, TEXT_PROCESSORS = load_model_and_preprocess(
    name="blip2", model_type="gen3_322_840", is_eval=True, device=DEVICE
)
print("Base model loaded successfully.")

print("Loading fine-tuned adapter model...")
ADAPTER_MODEL = MyModel().to(DEVICE)
if os.path.exists(GLOBAL_MODEL_PATH):
    # NOTE(security): torch.load unpickles the file. Only load checkpoints
    # from a trusted source; consider ``weights_only=True`` if the installed
    # PyTorch version supports it.
    ADAPTER_MODEL.load_state_dict(torch.load(GLOBAL_MODEL_PATH, map_location=DEVICE))
    print(f"Successfully loaded fine-tuned adapter from '{GLOBAL_MODEL_PATH}'.")
else:
    # The adapter keeps its randomly initialised weights in this case.
    print("WARNING: No saved adapter model found.")
ADAPTER_MODEL.eval()
# Maps a display label (key) to the text prompt (value) that is compared with
# the uploaded image. classify_image_from_upload() encodes the values and
# reports results under the keys; the Gradio widgets size themselves from
# len(CLASSIFICATION_PROMPTS). Entries that are commented out are disabled
# classes: they take no part in the similarity or softmax computation.
CLASSIFICATION_PROMPTS = {
    # --- Traffic Light States ---
    # "Red Light On": "a photo of a traffic signal with the red light illuminated",
    # "Yellow Light On": "a photo of a traffic signal with the yellow light glowing",
    # "Green Light On": "a photo of a traffic signal with a shining green light",
    # "All Lights Off": "a photo of a traffic signal where all the lights are off",
    # --- Pedestrian & Bicycle Symbols ---
    # "Pedestrian Signal": "a pedestrian crosswalk signal showing the illuminated symbol of a person, either walking or standing still.",
    # "Bicycle Symbol": "the illuminated symbol of a bicycle on a green traffic signal for bikes",
    # --- Regulatory Signs ---
    "Speed Limit Sign": "speed limit sign",
    # "End of Speed Limit": "a grey or white circular sign with black diagonal lines crossing out a number, indicating the end of a speed zone",
    # "Weight Limit Sign": "a circular traffic sign with a red border showing a weight limit in tonnes, often with an icon of a truck",
    "No Parking Sign": "No Parking Sign",
    "No Entry Sign": "No Entry Sign",
    # "End of All Prohibitions": "a white or grey circular sign with multiple black diagonal lines, indicating the end of all previous prohibitions",
    # --- Highway Signs ---
    # "Highway Start": "a rectangular blue or green sign with a white pictogram of a motorway or highway, indicating the start of a major road",
    # "Highway End": "a rectangular blue or green highway sign with a diagonal red line across it, indicating the end of the motorway",
    # --- Mandatory & Warning Signs ---
    # "Roundabout Sign": "a circular sign with three white arrows chasing each other in a circle, indicating a mandatory roundabout ahead",
    "Road Works Warning": "Road Works Warning",
    # --- Other Signs & Physical Properties ---
    # "Left and U-turn Sign": "a rectangular traffic sign showing both a left-turn arrow symbol and a U-turn arrow symbol",
    # "Left and Straight Sign": "a rectangular traffic sign showing both a left-turn arrow symbol and a straight-ahead arrow symbol",
    # "Multibulb (LED Grid) Light": "a traffic light where the main signal is made from a grid of many small LED bulbs",
    # --- Error States & Distractors ---
    # "Malfunctioning Light": "a malfunctioning or broken traffic light showing a distorted or unrecognizable symbol",
    # "Real Person Walking": "a photo of an actual person walking on the street, not a symbol",
    # "Real Bicycle": "a photo of a real bicycle being ridden on the road, not a symbol"
}
def classify_image_from_upload(input_image: Image.Image):
    """Score an uploaded image against every prompt in CLASSIFICATION_PROMPTS.

    Args:
        input_image: The uploaded PIL image, or ``None`` when nothing was
            uploaded.

    Returns:
        A 3-tuple ``(similarities, probabilities, text)``:
        * ``similarities``: label -> cosine similarity between the image
          embedding and that label's prompt embedding.
        * ``probabilities``: label -> softmax over those similarities (raw
          cosine values, no temperature applied).
        * ``text``: one ``"label: p"`` line per label, highest probability
          first, formatted to four decimals.
        For a ``None`` input the result is ``({}, {}, "")``.
    """
    if input_image is None:
        return {}, {}, ""

    labels = list(CLASSIFICATION_PROMPTS.keys())
    text_prompt_list = list(CLASSIFICATION_PROMPTS.values())

    raw_image = input_image.convert("RGB")
    # unsqueeze(0) adds a leading batch dimension of size 1.
    image_processed = VIS_PROCESSORS["eval"](raw_image).unsqueeze(0).to(DEVICE)
    text_processed = [TEXT_PROCESSORS["eval"](s) for s in text_prompt_list]

    with torch.no_grad():
        # Index 0 along the second axis keeps one projected embedding per item.
        image_features = BASE_MODEL.extract_features({"image": image_processed}, mode="image").image_embeds_proj[:, 0, :]
        text_features = BASE_MODEL.extract_features({"text_input": text_processed}, mode="text").text_embeds_proj[:, 0, :]
        image_norm = F.normalize(image_features, p=2, dim=-1)
        text_norm = F.normalize(text_features, p=2, dim=-1)
        # Take the single image row explicitly. The previous ``.squeeze()``
        # also removed the prompt axis when exactly one prompt was active,
        # leaving a 0-d tensor that cannot be iterated below.
        # Assumes the product is (1, num_prompts) -- TODO confirm.
        similarities = (image_norm @ text_norm.t())[0]
        probabilities = torch.softmax(similarities, dim=-1)

    # One device-to-host transfer per tensor instead of one ``.item()`` per element.
    sim_results = dict(zip(labels, similarities.tolist()))
    prob_results_dict = dict(zip(labels, probabilities.tolist()))

    sorted_probs = sorted(prob_results_dict.items(), key=lambda item: item[1], reverse=True)
    prob_string_output = "\n".join(f"{label}: {prob:.4f}" for label, prob in sorted_probs)
    return sim_results, prob_results_dict, prob_string_output
# --- Gradio user interface ---------------------------------------------------
# Declares the page layout and wires the button to classify_image_from_upload.
# NOTE(review): the source this was recovered from had lost its indentation;
# the nesting of the second gr.Row() (inside the scale=2 column vs. directly
# under the Blocks) is reconstructed here and should be confirmed.
with gr.Blocks(theme=gr.themes.Soft(), title="Image Classifier") as demo:
    gr.Markdown("# Cityscapes Zero-Shot Classifier")
    gr.Markdown("Upload an image to see the model's output as both charts and numerical points.")
    with gr.Row(variant="panel"):
        # Input column: the image upload and the trigger button.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image")
            classify_button = gr.Button("Classify Image", variant="primary")
        # Output column: similarity chart plus the two probability views.
        with gr.Column(scale=2):
            gr.Markdown("### Cosine Similarity (Chart)")
            # num_top_classes is set to the number of active prompts.
            similarity_output_chart = gr.Label(label="Similarity Scores", num_top_classes=len(CLASSIFICATION_PROMPTS))
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Probability - Softmax (Chart)")
                    probability_output_chart = gr.Label(label="Probabilities", num_top_classes=len(CLASSIFICATION_PROMPTS))
                with gr.Column():
                    gr.Markdown("### Probability - Softmax (Points)")
                    # Read-only text box with one line per active prompt.
                    probability_output_points = gr.Textbox(
                        label="Numerical Probabilities",
                        lines=len(CLASSIFICATION_PROMPTS),
                        interactive=False
                    )
    # The three outputs receive, in order, the three values returned by
    # classify_image_from_upload.
    classify_button.click(
        fn=classify_image_from_upload,
        inputs=image_input,
        outputs=[similarity_output_chart, probability_output_chart, probability_output_points]
    )
if __name__ == "__main__":
    # Start the web UI only when this file is executed as a script.
    # NOTE(review): "0.0.0.0" together with share=True presumably exposes the
    # app beyond this machine -- confirm that exposure is intended.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 8008,
        "share": True,
    }
    demo.launch(**launch_options)