"""
Core ML logic for CLIP-based Image Retrieval with Progressive Pipeline Steering.
Stripped of Flask — designed to be imported by app.py (Gradio).
"""
import os
import json
from pathlib import Path
from typing import List, Dict, Tuple
import torch
import numpy as np
import open_clip
from PIL import Image
# Try importing Groq (optional — falls back to hardcoded attributes)
try:
from groq import Groq
except ImportError:
Groq = None
# =============================================================================
# CONFIGURATION
# =============================================================================
# open_clip architecture name; loaded below with pretrained="openai" weights
MODEL_NAME = "ViT-B-16"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Dataset configurations — paths relative to repo root.
# Each entry: display name, images directory, precomputed-embeddings .npz,
# and a short human-readable description.
DATASETS = {
    "flickr": {
        "name": "Flickr Images",
        "images_dir": Path("Images"),
        "embeddings_file": Path("image_embeddings.npz"),
        "description": "Flickr image dataset",
    },
    "stanford_dogs": {
        "name": "Stanford Dogs",
        "images_dir": Path("stanford_dogs_subset"),
        "embeddings_file": Path("stanford_dogs_embeddings.npz"),
        "description": "Stanford Dogs dataset",
    },
    "celeba": {
        "name": "CelebA Faces",
        "images_dir": Path("celeba_subset"),
        "embeddings_file": Path("celeba_embeddings.npz"),
        "description": "CelebA celebrity faces",
    },
}
# =============================================================================
# GLOBAL STATE
# =============================================================================
# Key into DATASETS selecting the active dataset
_current_dataset: str = "flickr"
_model = None  # cached CLIP model (populated lazily by load_clip_model)
_preprocess = None  # OpenAI CLIP image preprocessing transform
_tokenizer = None  # OpenAI CLIP text tokenizer
_image_embeddings = None  # L2-normalised (for cosine retrieval)
_image_embeddings_raw = None  # Raw (for SAE encode — not normalised)
_image_names = None  # image file names aligned row-for-row with the embeddings
# Groq client (initialised lazily)
_groq_client = None
_groq_checked = False  # True once an init attempt has been made (success or not)
# =============================================================================
# HELPERS
# =============================================================================
def _get_groq_client():
    """Lazily initialise the Groq client from the GROQ_API_KEY env var / secret."""
    global _groq_client, _groq_checked
    if not _groq_checked:
        # First call: attempt initialisation exactly once; later calls
        # return whatever this attempt produced (possibly None).
        _groq_checked = True
        api_key = os.getenv("GROQ_API_KEY")
        if not api_key or Groq is None:
            reason = "GROQ_API_KEY not set" if not api_key else "groq package not installed"
            print(f"⚠️ {reason} — using fallback attributes")
        else:
            try:
                _groq_client = Groq(api_key=api_key)
                print("✅ Groq client initialised")
            except Exception as exc:
                print(f"⚠️ Could not init Groq: {exc}")
    return _groq_client
# =============================================================================
# MODEL + EMBEDDINGS
# =============================================================================
def load_clip_model():
    """Load OpenAI CLIP ViT-B/16 model via open_clip (cached after first call).

    Uses pretrained='openai' to load the exact same weights as the original
    OpenAI CLIP package (same CDN, identical embeddings).
    """
    global _model, _preprocess, _tokenizer
    if _model is not None and _preprocess is not None:
        return _model, _preprocess  # already cached
    print(f"🔄 Loading OpenAI CLIP model: {MODEL_NAME}")
    _model, _, _preprocess = open_clip.create_model_and_transforms(
        MODEL_NAME, pretrained="openai", device=DEVICE,
    )
    _tokenizer = open_clip.get_tokenizer(MODEL_NAME)
    _model.eval()
    # Inference only — freeze every weight.
    for param in _model.parameters():
        param.requires_grad = False
    print(f"✅ CLIP loaded on {DEVICE}")
    return _model, _preprocess
def get_text_embedding(text: str, normalize: bool = True) -> np.ndarray:
    """Get a single CLIP text embedding using open_clip (OpenAI weights)."""
    load_clip_model()  # guarantees _model and _tokenizer are populated
    token_batch = _tokenizer([text]).to(DEVICE)
    with torch.no_grad():
        features = _model.encode_text(token_batch).float()
        if normalize:
            # L2-normalise so dot products are cosine similarities.
            features = features / features.norm(dim=-1, keepdim=True)
    return features[0].cpu().numpy()
def load_or_compute_embeddings() -> Tuple[np.ndarray, List[str]]:
    """Load (or, as a fallback, compute) embeddings for the current dataset.

    Populates both module-level caches:
      * ``_image_embeddings``     — L2-normalised, used for cosine retrieval.
      * ``_image_embeddings_raw`` — un-normalised, used for SAE encode.

    Returns:
        Tuple of (L2-normalised embeddings array, list of image file names).

    Raises:
        FileNotFoundError: if neither the .npz file nor the images dir exists.
    """
    global _image_embeddings, _image_embeddings_raw, _image_names
    if _image_embeddings is not None and _image_names is not None:
        return _image_embeddings, _image_names
    cfg = DATASETS[_current_dataset]
    emb_path = cfg["embeddings_file"]
    if emb_path.exists():
        print(f"📂 Loading embeddings from {emb_path}")
        data = np.load(emb_path, allow_pickle=True)
        raw = data["embeddings"].astype(np.float32)
        _image_embeddings_raw = raw
        # L2-normalise for cosine retrieval (epsilon guards zero vectors)
        norms = np.linalg.norm(raw, axis=-1, keepdims=True) + 1e-8
        _image_embeddings = raw / norms
        _image_names = data["image_names"].tolist()
        print(f"✅ Loaded {len(_image_embeddings)} embeddings")
        return _image_embeddings, _image_names
    # Fallback: compute on the fly (slow on CPU)
    images_dir = cfg["images_dir"]
    if not images_dir.exists():
        raise FileNotFoundError(
            f"Neither embeddings ({emb_path}) nor images ({images_dir}) found. "
            "Upload at least the .npz embeddings file."
        )
    print("🔄 Computing embeddings from scratch (this may take a while on CPU)…")
    model, preprocess = load_clip_model()
    image_files = sorted(images_dir.glob("*.jpg"))[:500]  # cap for CPU feasibility
    embeddings_list, valid_names = [], []
    batch_size = 16
    for i in range(0, len(image_files), batch_size):
        batch_files = image_files[i : i + batch_size]
        batch_tensors, batch_names = [], []
        for p in batch_files:
            try:
                img = preprocess(Image.open(p).convert("RGB"))
                batch_tensors.append(img)
                batch_names.append(p.name)
            except Exception as exc:
                # Best-effort: skip unreadable/corrupt images and keep going.
                print(f"⚠️ Skipping {p.name}: {exc}")
        if batch_tensors:
            image_input = torch.stack(batch_tensors).to(DEVICE)
            with torch.no_grad():
                embs = model.encode_image(image_input).float()
            # Keep RAW embeddings here; normalisation happens once below.
            embeddings_list.extend(embs.cpu().numpy())
            valid_names.extend(batch_names)
    # BUGFIX: the original compute path never set _image_embeddings_raw
    # (get_raw_embeddings() then returned None) and saved normalised vectors
    # to the .npz, while the load path above expects the file to hold RAW
    # vectors. Store raw, normalise separately, save raw.
    raw = np.array(embeddings_list, dtype=np.float32)
    _image_embeddings_raw = raw
    norms = np.linalg.norm(raw, axis=-1, keepdims=True) + 1e-8
    _image_embeddings = raw / norms
    _image_names = valid_names
    np.savez_compressed(emb_path, embeddings=raw, image_names=np.array(_image_names))
    print(f"✅ Computed & saved {len(_image_embeddings)} embeddings")
    return _image_embeddings, _image_names
# =============================================================================
# DATASET MANAGEMENT
# =============================================================================
def get_available_datasets() -> Dict[str, Dict]:
    """Return info about each dataset (name, availability)."""
    # Availability flags reflect the filesystem at call time.
    return {
        key: {
            "name": cfg["name"],
            "description": cfg["description"],
            "has_embeddings": cfg["embeddings_file"].exists(),
            "has_images": cfg["images_dir"].exists(),
        }
        for key, cfg in DATASETS.items()
    }
def get_current_dataset() -> str:
    """Return the key of the currently active dataset (e.g. "flickr")."""
    return _current_dataset
def switch_dataset(dataset_key: str):
    """Switch active dataset and reload embeddings."""
    global _current_dataset, _image_embeddings, _image_embeddings_raw, _image_names
    already_active = dataset_key == _current_dataset and _image_embeddings is not None
    if already_active:
        return
    if dataset_key not in DATASETS:
        raise ValueError(f"Unknown dataset: {dataset_key}")
    _current_dataset = dataset_key
    # Invalidate all caches so the next load targets the new dataset.
    _image_embeddings = None
    _image_embeddings_raw = None
    _image_names = None
    load_or_compute_embeddings()
def get_raw_embeddings() -> np.ndarray:
    """Return raw (un-normalised) image embeddings for SAE encode."""
    if _image_embeddings_raw is None:
        # Loading populates the raw cache as a side effect.
        load_or_compute_embeddings()
    return _image_embeddings_raw
def get_image_path(image_name: str) -> Path:
    """Resolve an image name to its full path under the current dataset."""
    cfg = DATASETS[_current_dataset]
    return cfg["images_dir"] / image_name
# =============================================================================
# LLM FEEDBACK
# =============================================================================
# Hand-curated positive/negative attribute lists used by _fallback_feedback()
# when the Groq LLM is unavailable. Keys are matched case-insensitively
# against the user query; values list observable visual attributes to add
# ("positive") or suppress ("negative") when steering the CLIP embedding.
_FALLBACK_ATTRIBUTES: Dict[str, Dict] = {
    # ── Stanford Dogs (7) ──
    "a golden retriever": {
        "positive": ["golden retriever dog", "light golden fur", "medium to long wavy coat", "floppy ears", "broad head", "friendly face"],
        "negative": ["short fur dog", "dark colored fur", "pointed upright ears", "small dog breed", "cat"],
    },
    "dog on the beach": {
        "positive": ["dog on sand", "ocean waves", "beach scenery", "wet fur", "running on beach", "sunny outdoor"],
        "negative": ["indoor", "snow", "city street", "forest", "furniture"],
    },
    "dog looking guilty": {
        "positive": ["sad", "droopy", "ashamed", "looking down", "submissive", "avoiding eye contact", "head down"],
        "negative": ["happy", "playful", "energetic", "excited", "proud", "confident"],
    },
    "friendly looking dog": {
        "positive": ["happy", "playful", "gentle", "cute", "adorable", "wagging tail", "soft eyes"],
        "negative": ["aggressive", "scary", "angry", "mean", "threatening", "snarling"],
    },
    "aggressive looking dog": {
        "positive": ["snarling", "baring teeth", "aggressive stance", "tense body", "raised hackles", "intense stare"],
        "negative": ["gentle", "cute", "relaxed", "playful", "sleeping", "wagging tail"],
    },
    "nervous looking dog": {
        "positive": ["anxious", "scared", "worried", "trembling", "wide eyes", "ears back", "cowering", "tail tucked"],
        "negative": ["confident", "relaxed", "happy", "calm", "bold", "playful"],
    },
    "hyper active dog": {
        "positive": ["running", "jumping", "playing", "energetic", "fast movement", "mouth open", "excited"],
        "negative": ["sleeping", "lying down", "calm", "still", "resting", "lazy"],
    },
    # ── Flickr (7) ──
    "a person riding a bicycle": {
        "positive": ["cyclist", "bicycle", "riding bike", "pedaling", "wheels", "helmet"],
        "negative": ["walking", "car", "sitting on bench", "standing still", "motorcycle"],
    },
    "a dog playing": {
        "positive": ["playful dog", "running", "fetching", "jumping", "toy in mouth", "energetic", "outdoor play"],
        "negative": ["sleeping", "sitting still", "calm", "resting", "sad"],
    },
    "an exciting action scene": {
        "positive": ["sports", "jumping", "running", "movement", "dynamic action", "fast motion", "athletic"],
        "negative": ["standing", "sitting", "calm", "still", "resting", "portrait"],
    },
    "a joyful moment": {
        "positive": ["smiling", "laughing", "celebrating", "happy faces", "hugging", "arms raised", "bright colors"],
        "negative": ["sad", "crying", "angry", "alone", "dark", "serious face"],
    },
    "a kid having fun": {
        "positive": ["child playing", "laughing kid", "outdoor play", "toys", "smiling child", "running", "playground"],
        "negative": ["adult", "serious", "crying", "sitting still", "elderly", "office"],
    },
    "peaceful scene": {
        "positive": ["calm water", "sunset", "nature", "quiet", "serene landscape", "soft light", "still"],
        "negative": ["crowded", "noisy", "urban", "traffic", "construction", "chaotic"],
    },
    "a photo with motion": {
        "positive": ["motion blur", "running", "moving fast", "dynamic", "action", "speed", "blurred background"],
        "negative": ["still", "static", "portrait", "posed", "standing", "sharp focus"],
    },
    # ── CelebA (8) ──
    "wearing eyeglasses": {
        "positive": ["eyeglasses", "spectacles", "glasses frames", "lenses on face", "reading glasses"],
        "negative": ["no glasses", "bare face", "sunglasses", "contact lenses"],
    },
    "a person smiling": {
        "positive": ["smiling", "teeth showing", "happy expression", "grin", "cheerful face", "bright eyes"],
        "negative": ["frowning", "serious face", "sad", "neutral expression", "angry"],
    },
    "looking guilty": {
        "positive": ["worried", "nervous", "serious", "sad", "secretive", "uncomfortable", "looking away", "avoiding eye contact"],
        "negative": ["smiling", "happy", "confident", "laughing", "relaxed", "proud"],
    },
    "looking happy": {
        "positive": ["smiling", "laughing", "joyful", "bright eyes", "cheerful", "grinning", "radiant"],
        "negative": ["sad", "frowning", "crying", "angry", "tired", "bored"],
    },
    "looking sad": {
        "positive": ["frowning", "tearful", "downcast eyes", "drooping mouth", "somber", "melancholy", "looking down"],
        "negative": ["smiling", "happy", "laughing", "excited", "cheerful", "energetic"],
    },
    "looking suspicious": {
        "positive": ["narrowed eyes", "side glance", "raised eyebrow", "squinting", "tense jaw", "furrowed brow"],
        "negative": ["smiling", "relaxed", "open face", "friendly", "trusting", "wide eyes"],
    },
    "looking tired": {
        "positive": ["droopy eyes", "yawning", "dark circles", "half closed eyes", "slouching", "exhausted"],
        "negative": ["alert", "energetic", "wide awake", "bright eyes", "smiling", "active"],
    },
    "looking confident": {
        "positive": ["upright posture", "direct eye contact", "chin up", "strong stance", "composed", "assertive"],
        "negative": ["slouching", "looking down", "nervous", "fidgeting", "shy", "uncertain"],
    },
}
def _fallback_feedback(query: str) -> Dict:
    """Query-specific fallback when Groq is unavailable or rate-limited."""
    normalised = query.strip().lower()
    # Case-insensitive exact match against the curated table.
    matched = next(
        (val for key, val in _FALLBACK_ATTRIBUTES.items() if key.lower() == normalised),
        None,
    )
    if matched is not None:
        return {
            "positive": [
                {"attribute": a, "weight": 0.8, "rationale": "fallback"}
                for a in matched["positive"]
            ],
            "negative": [
                {"attribute": a, "weight": 0.6, "rationale": "fallback"}
                for a in matched["negative"]
            ],
            "alpha": 0.4,
            "beta": 0.4,
        }
    # Generic fallback for unknown queries
    return {
        "positive": [
            {"attribute": query, "weight": 0.9, "rationale": "Original query"},
        ],
        "negative": [],
        "alpha": 0.4,
        "beta": 0.4,
    }
def generate_feedback_with_weights(query: str) -> Dict:
    """Generate per-attribute weights from an LLM (or fallback).

    Asks a Groq-hosted LLM to decompose *query* into weighted positive and
    negative visual attributes plus alpha/beta steering strengths. Falls back
    to _fallback_feedback() when Groq is unavailable or any error occurs.

    Returns a dict with keys "positive", "negative" (lists of
    {"attribute", "weight", "rationale"} dicts), "alpha", and "beta".
    """
    client = _get_groq_client()
    if client is None:
        return _fallback_feedback(query)
    # NOTE: the prompt is deliberately flush-left; its text is part of the
    # runtime behavior and must not be reformatted.
    prompt = f"""You are an expert at decomposing image-search queries into visually grounded CLIP steering attributes.
CONTEXT
- We steer a CLIP ViT-B/16 query embedding by adding positive attribute embeddings and subtracting negative ones.
- Attributes must be OBSERVABLE VISUAL properties — things you can literally SEE in a photograph.
- Each attribute is 1-5 words. No abstract concepts. No full sentences.
CRITICAL RULES FOR GOOD ATTRIBUTES
- For EMOTIONS or SUBJECTIVE states: describe the BODY LANGUAGE, not the emotion word.
BAD: "guilty expression" (CLIP doesn't understand abstract emotions well)
GOOD: "looking down", "avoiding eye contact", "head lowered", "droopy ears"
- For PHYSICAL descriptions: use concrete visual details.
BAD: "beautiful dog" (vague, subjective)
GOOD: "light golden fur", "floppy ears", "broad head", "medium to long wavy coat"
- For SCENES: describe observable elements.
BAD: "exciting scene" (abstract)
GOOD: "jumping", "running", "movement", "sports"
- NEGATIVES must be the visual OPPOSITE or a common CLIP confusion.
They should suppress what CLIP tends to retrieve incorrectly for this query.
FEW-SHOT EXAMPLES (from human experts):
Query: "a dog looking guilty"
positive: ["sad", "droopy", "ashamed", "looking down", "submissive", "avoiding eye contact", "head down"]
negative: ["happy", "playful", "energetic", "excited", "proud", "confident"]
Query: "a nervous-looking dog"
positive: ["anxious", "scared", "worried", "trembling", "wide eyes", "ears back"]
negative: ["confident", "relaxed", "happy", "calm", "bold"]
Query: "a golden retriever"
positive: ["golden retriever dog", "light golden fur", "medium to long wavy coat", "floppy ears", "broad head"]
negative: ["short fur dog", "dark colored fur", "pointed upright ears", "small dog breed"]
Query: "a person looking guilty"
positive: ["worried", "nervous", "serious", "sad", "secretive", "uncomfortable"]
negative: ["smiling", "happy", "confident", "laughing", "relaxed"]
Query: "an exciting action scene"
positive: ["sports", "jumping", "running", "movement"]
negative: ["standing", "sitting", "calm", "still"]
Query: "a friendly-looking dog"
positive: ["happy", "playful", "gentle", "cute", "adorable", "wagging tail"]
negative: ["aggressive", "scary", "angry", "mean", "threatening"]
NOW GENERATE FOR THIS QUERY:
USER QUERY: "{query}"
INSTRUCTIONS:
1. Generate 5-8 POSITIVE attributes (observable visual cues that define what the user wants).
- The first 1-2 should be the most important (weight 0.8-1.0).
- The rest are supporting visual details (weight 0.4-0.7).
- For subjective/emotional queries: describe body language, posture, facial features.
2. Generate 5-8 NEGATIVE attributes (visual opposites or CLIP confusions to suppress).
- Weight 0.5-0.8 for strong opposites, 0.3-0.5 for subtle.
3. Set alpha=0.4 and beta=0.4 (safe defaults). Only increase to 0.45 for very specific queries.
Return ONLY a JSON object (no markdown, no explanation):
{{
"positive": [
{{"attribute": "short visual phrase", "weight": 0.9, "rationale": "why"}},
{{"attribute": "another phrase", "weight": 0.6, "rationale": "why"}}
],
"negative": [
{{"attribute": "short visual phrase", "weight": 0.7, "rationale": "why"}},
{{"attribute": "another phrase", "weight": 0.5, "rationale": "why"}}
],
"alpha": 0.4,
"beta": 0.4
}}"""
    try:
        response = client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[
                {
                    "role": "system",
                    "content": (
                        "You are a vision-language retrieval expert specialising in CLIP embedding steering. "
                        "You decompose queries into OBSERVABLE VISUAL attributes — body language, textures, "
                        "colours, postures, spatial arrangements — never abstract or subjective words. "
                        "Follow the few-shot examples closely. Return ONLY valid JSON."
                    ),
                },
                {"role": "user", "content": prompt},
            ],
            temperature=0.3,
            max_tokens=1000,
        )
        content = response.choices[0].message.content.strip()
        # Strip markdown fences the model may wrap around the JSON
        if content.startswith("```json"):
            content = content[7:]
        if content.startswith("```"):
            content = content[3:]
        if content.endswith("```"):
            content = content[:-3]
        content = content.strip()
        feedback = json.loads(content)
        # Validate structure: coerce missing lists, clamp weights to [0, 1]
        for attr_type in ("positive", "negative"):
            if attr_type not in feedback or not isinstance(feedback[attr_type], list):
                feedback[attr_type] = []
            for attr in feedback[attr_type]:
                attr.setdefault("weight", 0.5)
                attr["weight"] = max(0.0, min(1.0, float(attr["weight"])))
                attr.setdefault("rationale", "")
        # Clamp steering strengths to a safe range
        feedback.setdefault("alpha", 0.4)
        feedback.setdefault("beta", 0.4)
        feedback["alpha"] = max(0.1, min(0.8, float(feedback["alpha"])))
        feedback["beta"] = max(0.1, min(0.8, float(feedback["beta"])))
        return feedback
    except Exception as exc:
        # Any failure (network, rate limit, bad JSON) degrades gracefully.
        print(f"⚠️ LLM error: {exc}")
        return _fallback_feedback(query)
# =============================================================================
# STEERING METHODS
# =============================================================================
def baseline_retrieval(
    query_emb: np.ndarray,
    embeddings: np.ndarray,
    image_names: List[str],
    top_k: int = 5,
) -> List[Dict]:
    """Pure CLIP retrieval without steering."""
    # Dot products against (normalised) image embeddings = cosine similarity.
    scores = embeddings @ query_emb
    ranked = np.argsort(scores)[::-1][:top_k]
    results = []
    for idx in ranked:
        results.append({"image_name": image_names[idx], "similarity": float(scores[idx])})
    return results
def linear_steering(
    query_emb: np.ndarray,
    positive_attrs: List[Dict],
    negative_attrs: List[Dict],
    alpha: float = 0.4,
    beta: float = 0.4,
) -> np.ndarray:
    """q' = q + α·Σ(w_i·p_i) - β·Σ(w_j·n_j)"""
    result = query_emb.copy()
    # Add all weighted positive directions, then subtract the negatives.
    for sign, coeff, attrs in ((1.0, alpha, positive_attrs), (-1.0, beta, negative_attrs)):
        for attr in attrs:
            direction = get_text_embedding(f"a photo of {attr['attribute']}")
            result += sign * coeff * attr.get("weight", 1.0) * direction
    return result / np.linalg.norm(result)
def subspace_steering(
    query_emb: np.ndarray,
    positive_attrs: List[Dict],
    negative_attrs: List[Dict],
) -> np.ndarray:
    """Contrastive subspace (centroid-based) steering."""
    pos = [get_text_embedding(f"a photo of {a['attribute']}") for a in positive_attrs]
    neg = [get_text_embedding(f"a photo of {a['attribute']}") for a in negative_attrs]
    # Both sides are required to form a contrastive direction.
    if not (pos and neg):
        return query_emb
    axis = np.mean(pos, axis=0) - np.mean(neg, axis=0)
    axis /= np.linalg.norm(axis)
    shifted = query_emb + 0.5 * axis
    return shifted / np.linalg.norm(shifted)
def energy_based_steering(
    query_emb: np.ndarray,
    positive_attrs: List[Dict],
    negative_attrs: List[Dict],
    n_steps: int = 30,
    lr: float = 0.05,
) -> np.ndarray:
    """Energy-based steering via gradient descent."""
    state = query_emb.copy()
    anchor = query_emb.copy()
    pos = [(get_text_embedding(f"a photo of {a['attribute']}"), a.get("weight", 1.0)) for a in positive_attrs]
    neg = [(get_text_embedding(f"a photo of {a['attribute']}"), a.get("weight", 1.0)) for a in negative_attrs]
    for _step in range(n_steps):
        grad = np.zeros_like(state)
        # Pull toward positives, push away from negatives (half strength).
        for emb, w in pos:
            grad -= w * (emb - state)
        for emb, w in neg:
            grad += 0.5 * w * (emb - state)
        # Regulariser keeps the result close to the original query.
        grad += 0.1 * (state - anchor)
        state = state - lr * grad
    return state / np.linalg.norm(state)
def weighted_energy_steering(
    query_emb: np.ndarray,
    positive_attrs: List[Dict],
    negative_attrs: List[Dict],
    n_steps: int = 30,
    lr: float = 0.05,
) -> np.ndarray:
    """Weighted energy steering with normalised per-attribute weights."""
    state = query_emb.copy()
    anchor = query_emb.copy()
    pos_w = [a.get("weight", 1.0) for a in positive_attrs]
    neg_w = [a.get("weight", 1.0) for a in negative_attrs]
    # Rescale so each list's weights sum to its length (mean weight 1.0).
    # Totals are hoisted out of the comprehensions — they are invariant.
    pos_total = sum(pos_w)
    if pos_total > 0:
        pos_w = [w / pos_total * len(pos_w) for w in pos_w]
    neg_total = sum(neg_w)
    if neg_total > 0:
        neg_w = [w / neg_total * len(neg_w) for w in neg_w]
    pos_e = [get_text_embedding(f"a photo of {a['attribute']}") for a in positive_attrs]
    neg_e = [get_text_embedding(f"a photo of {a['attribute']}") for a in negative_attrs]
    for _step in range(n_steps):
        grad = np.zeros_like(state)
        for emb, w in zip(pos_e, pos_w):
            grad -= w * (emb - state)
        for emb, w in zip(neg_e, neg_w):
            grad += w * (emb - state)
        # Regulariser keeps the result close to the original query.
        grad += 0.1 * (state - anchor)
        state = state - lr * grad
    return state / np.linalg.norm(state)