import os
import io
import asyncio
import random
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter
from fastapi import FastAPI, UploadFile, File, Query
from fastapi.responses import StreamingResponse
from huggingface_hub import snapshot_download, login
from transformers import (
BlipProcessor, BlipForConditionalGeneration,
ViTImageProcessor, AutoProcessor, AutoModelForCausalLM,
CLIPModel, CLIPProcessor
)
app = FastAPI(title="XAI Auditor Ensemble with CLIP Jury")
# --- Configuration & Paths ---
REPO_ID = "SaniaE/Image_Captioning_Ensemble"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODELS = {}
# Metadata for loading
MODEL_CONFIGS = {
"blip": {
"subfolder": "blip",
"proc_class": BlipProcessor,
"model_class": BlipForConditionalGeneration,
"base_path": "Salesforce/blip-image-captioning-large"
},
"vit": {
"subfolder": "vit",
"proc_classes": [ViTImageProcessor, AutoProcessor],
"model_class": AutoModelForCausalLM,
"base_paths": ["nlpconnect/vit-gpt2-image-captioning", "microsoft/git-large"]
},
"clip": {
"model_subfolder": "clip/clip_model",
"proc_subfolder": "clip/clip_processor"
}
}
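# Note: on_event("startup") still works but is deprecated in recent FastAPI releases
# in favour of the lifespan context manager.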
@app.on_event("startup")
async def startup_event():
global MODELS
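    # The env var name must match the access-token secret configured for this deployment (e.g. a Space secret).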
token = os.getenv("HF_Token")
if token: login(token=token)
print(f"Syncing weights from {REPO_ID}...")
local_dir = snapshot_download(repo_id=REPO_ID, token=token, local_dir="weights")
# 1. Load BLIP
cfg_b = MODEL_CONFIGS["blip"]
MODELS["blip"] = {
"model": cfg_b["model_class"].from_pretrained(os.path.join(local_dir, cfg_b["subfolder"])).to(DEVICE),
"processor": cfg_b["proc_class"].from_pretrained(cfg_b["base_path"])
}
# 2. Load ViT/GIT Ensemble
cfg_v = MODEL_CONFIGS["vit"]
MODELS["vit"] = {
"model": cfg_v["model_class"].from_pretrained(os.path.join(local_dir, cfg_v["subfolder"])).to(DEVICE),
"processor": (
cfg_v["proc_classes"][0].from_pretrained(cfg_v["base_paths"][0]),
cfg_v["proc_classes"][1].from_pretrained(cfg_v["base_paths"][1])
)
}
    # 3. Load Fine-Tuned CLIP (the CLIP Jury)
cfg_c = MODEL_CONFIGS["clip"]
MODELS["clip"] = {
"model": CLIPModel.from_pretrained(os.path.join(local_dir, cfg_c["model_subfolder"])).to(DEVICE),
"processor": CLIPProcessor.from_pretrained(os.path.join(local_dir, cfg_c["proc_subfolder"]))
}
print("All models synchronized. Auditor is active.")
# --- Utilities ---
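# Runs one captioning model end-to-end (preprocess, sample, decode) on the calling thread.
# The endpoints dispatch it through asyncio.to_thread so several samples can be drawn
# concurrently without blocking the event loop.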
def _generate_sync(m_name, image, temp, top_k, top_p):
m_data = MODELS[m_name]
if m_name == "vit":
i_proc, t_proc = m_data["processor"]
inputs = i_proc(images=image, return_tensors="pt").to(DEVICE)
ids = m_data["model"].generate(**inputs, max_length=80, do_sample=True, temperature=temp, top_k=top_k, top_p=top_p)
return t_proc.batch_decode(ids, skip_special_tokens=True)[0].strip()
else:
proc = m_data["processor"]
inputs = proc(images=image, return_tensors="pt").to(DEVICE)
ids = m_data["model"].generate(**inputs, max_length=80, do_sample=True, temperature=temp, top_k=top_k, top_p=top_p)
return proc.batch_decode(ids, skip_special_tokens=True)[0].strip()
# --- Endpoints ---
@app.post("/generate")
async def generate_captions(
file: UploadFile = File(...),
temp: float = Query(0.8),
top_k: int = Query(50),
top_p: float = Query(0.9)
):
"""Generates 5 diverse captions using the model ensemble."""
image = Image.open(file.file).convert("RGB")
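    # Sample 5 model picks (with replacement) from the two architectures so the captions vary.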
architectures = ["blip", "vit"]
selection = random.choices(architectures, k=5)
tasks = [asyncio.to_thread(_generate_sync, m, image, temp, top_k, top_p) for m in selection]
captions = await asyncio.gather(*tasks)
return {"captions": captions, "metadata": {"models_used": selection, "temp": temp}}
@app.post("/saliency")
async def get_vision_saliency(file: UploadFile = File(...)):
"""Objective Saliency: Shows what the Vision Encoder focuses on (Self-Attention)."""
image_bytes = await file.read()
orig_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
blip = MODELS["blip"]
inputs = blip["processor"](images=orig_img, return_tensors="pt").to(DEVICE)
with torch.no_grad():
outputs = blip["model"].vision_model(inputs.pixel_values, output_attentions=True)
attentions = outputs.attentions[-1] # Last layer
    # Average over the attention heads and take the CLS token's attention to the image patches
mask_1d = attentions[0, :, 0, 1:].mean(dim=0)
grid_size = int(np.sqrt(mask_1d.shape[-1]))
mask = mask_1d.view(grid_size, grid_size).cpu().numpy()
mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
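    # Upsample the patch-grid attention map to the input resolution and smooth it
    # so the overlay does not show hard patch boundaries.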
mask_img = Image.fromarray((mask * 255).astype('uint8')).resize(orig_img.size, resample=Image.BICUBIC)
mask_img = mask_img.filter(ImageFilter.GaussianBlur(radius=10))
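    # Colorize the mask with the 'magma' colormap and blend it over the original image.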
heatmap = plt.get_cmap('magma')(np.array(mask_img)/255.0)
heatmap_img = Image.fromarray((heatmap[:, :, :3] * 255).astype('uint8')).convert("RGB")
blended = Image.blend(orig_img, heatmap_img, alpha=0.6)
buf = io.BytesIO()
blended.save(buf, format="PNG")
buf.seek(0)
return StreamingResponse(buf, media_type="image/png")
@app.post("/audit")
async def internal_debate_audit(file: UploadFile = File(...), user_prompt: str = Query(...)):
"""The CLIP-Powered Jury: Compares User Intent vs. Model Perception."""
image_bytes = await file.read()
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
# 1. Model Perception
blip_caption = await asyncio.to_thread(_generate_sync, "blip", image, 0.7, 50, 0.9)
# 2. CLIP Scoring (Multimodal Alignment)
clip_m = MODELS["clip"]["model"]
clip_p = MODELS["clip"]["processor"]
    # Truncate so that captions longer than CLIP's 77-token text limit do not raise an error.
    inputs = clip_p(text=[user_prompt, blip_caption], images=image, return_tensors="pt", padding=True, truncation=True).to(DEVICE)
with torch.no_grad():
outputs = clip_m(**inputs)
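    # Softmax over the two candidate texts yields relative alignment scores:
    # probs[0] = user prompt, probs[1] = BLIP caption.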
probs = outputs.logits_per_image.softmax(dim=-1).cpu().numpy()[0]
u_score, m_score = float(probs[0]), float(probs[1])
# 3. Decision Logic
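    # The 0.35 grounding floor and 0.15 agreement band are heuristic cut-offs;
    # adjust them if the verdicts prove too strict or too lenient in practice.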
if u_score < 0.35:
verdict = "Perspective Divergence: Intent not grounded in image."
elif abs(u_score - m_score) < 0.15:
verdict = "Consensus: High Alignment."
else:
verdict = "Model Bias Detected."
return {
"perspectives": {"user": user_prompt, "ai": blip_caption},
"audit_scores": {"intent_grounding": round(u_score, 4), "ai_grounding": round(m_score, 4)},
"verdict": verdict
    }
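
# Example usage, assuming this file is saved as app.py and served locally:
#   uvicorn app:app --host 0.0.0.0 --port 7860
#   curl -X POST -F "file=@photo.jpg" "http://localhost:7860/generate?temp=0.8&top_k=50&top_p=0.9"
#   curl -X POST -F "file=@photo.jpg" -o saliency.png "http://localhost:7860/saliency"
#   curl -X POST -F "file=@photo.jpg" "http://localhost:7860/audit?user_prompt=a%20dog%20on%20a%20beach"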