| """ |
| Core ML logic for CLIP-based Image Retrieval with Progressive Pipeline Steering. |
| Stripped of Flask — designed to be imported by app.py (Gradio). |
| """ |
|
|
| import os |
| import json |
| from pathlib import Path |
| from typing import List, Dict, Tuple |
|
|
| import torch |
| import numpy as np |
| import open_clip |
| from PIL import Image |
|
|
| |
# Groq is an optional dependency: when the SDK is absent, Groq is set to None
# and the LLM-based attribute generation degrades to hand-curated fallbacks.
try:
    from groq import Groq
except ImportError:
    Groq = None
|
|
| |
| |
| |
|
|
# CLIP architecture identifier understood by open_clip; the original OpenAI
# pretrained weights are requested at load time (see load_clip_model).
MODEL_NAME = "ViT-B-16"
# Prefer GPU when available; model and tensors are moved to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
| |
# Dataset registry. Paths are relative to the working directory; each dataset
# needs either the precomputed .npz embeddings file or the images directory
# (embeddings are computed from the images when only the latter exists — see
# load_or_compute_embeddings).
DATASETS = {
    "flickr": {
        "name": "Flickr Images",
        "images_dir": Path("Images"),
        "embeddings_file": Path("image_embeddings.npz"),
        "description": "Flickr image dataset",
    },
    "stanford_dogs": {
        "name": "Stanford Dogs",
        "images_dir": Path("stanford_dogs_subset"),
        "embeddings_file": Path("stanford_dogs_embeddings.npz"),
        "description": "Stanford Dogs dataset",
    },
    "celeba": {
        "name": "CelebA Faces",
        "images_dir": Path("celeba_subset"),
        "embeddings_file": Path("celeba_embeddings.npz"),
        "description": "CelebA celebrity faces",
    },
}
|
|
| |
| |
| |
|
|
# ---- Lazily populated module-level caches ----
_current_dataset: str = "flickr"  # active key into DATASETS
_model = None            # CLIP model (populated by load_clip_model)
_preprocess = None       # CLIP image preprocessing transform
_tokenizer = None        # CLIP text tokenizer
_image_embeddings = None      # L2-normalised image embeddings (cosine retrieval)
_image_embeddings_raw = None  # un-normalised embeddings (for SAE encode)
_image_names = None           # file names aligned row-for-row with embeddings

# Groq client is created lazily on first use; _groq_checked ensures we only
# attempt initialisation once rather than on every call.
_groq_client = None
_groq_checked = False
|
|
|
|
| |
| |
| |
|
|
def _get_groq_client():
    """Lazily initialise the Groq client from the GROQ_API_KEY env var / secret."""
    global _groq_client, _groq_checked
    # Only attempt initialisation once; afterwards just hand back the cache
    # (which stays None when initialisation was impossible or failed).
    if not _groq_checked:
        _groq_checked = True
        api_key = os.getenv("GROQ_API_KEY")
        if not api_key or Groq is None:
            reason = "GROQ_API_KEY not set" if not api_key else "groq package not installed"
            print(f"⚠️ {reason} — using fallback attributes")
        else:
            try:
                _groq_client = Groq(api_key=api_key)
                print("✅ Groq client initialised")
            except Exception as exc:
                print(f"⚠️ Could not init Groq: {exc}")
    return _groq_client
|
|
|
|
| |
| |
| |
|
|
def load_clip_model():
    """Load OpenAI CLIP ViT-B/16 model via open_clip (cached after first call).

    Uses pretrained='openai' to load the exact same weights as the original
    OpenAI CLIP package (same CDN, identical embeddings).
    """
    global _model, _preprocess, _tokenizer
    # Already loaded? Return the cached pair immediately.
    if _model is not None and _preprocess is not None:
        return _model, _preprocess

    print(f"🔄 Loading OpenAI CLIP model: {MODEL_NAME}")
    _model, _, _preprocess = open_clip.create_model_and_transforms(
        MODEL_NAME, pretrained="openai", device=DEVICE,
    )
    _tokenizer = open_clip.get_tokenizer(MODEL_NAME)
    # Inference only: freeze every parameter and switch off train-mode layers.
    _model.eval()
    for param in _model.parameters():
        param.requires_grad = False
    print(f"✅ CLIP loaded on {DEVICE}")
    return _model, _preprocess
|
|
|
|
def get_text_embedding(text: str, normalize: bool = True) -> np.ndarray:
    """Get a single CLIP text embedding using open_clip (OpenAI weights)."""
    load_clip_model()  # ensures _model/_tokenizer are populated
    token_batch = _tokenizer([text]).to(DEVICE)
    with torch.no_grad():
        embedding = _model.encode_text(token_batch).float()
        if normalize:
            # L2-normalise so dot products become cosine similarities.
            embedding = embedding / embedding.norm(dim=-1, keepdim=True)
    # Batch of one → return the single row as a numpy vector.
    return embedding[0].cpu().numpy()
|
|
|
|
def load_or_compute_embeddings() -> Tuple[np.ndarray, List[str]]:
    """Load (or compute) image embeddings for the current dataset.

    Populates three module-level caches:
      * ``_image_embeddings_raw`` — un-normalised CLIP outputs (for SAE encode).
      * ``_image_embeddings``     — L2-normalised copies (for cosine retrieval).
      * ``_image_names``          — file names aligned with the embedding rows.

    Returns:
        Tuple of (L2-normalised embeddings, image names).

    Raises:
        FileNotFoundError: when neither the .npz file nor the images directory
            exists for the current dataset.
    """
    global _image_embeddings, _image_embeddings_raw, _image_names

    # Fast path: caches already populated for the current dataset.
    if _image_embeddings is not None and _image_names is not None:
        return _image_embeddings, _image_names

    cfg = DATASETS[_current_dataset]
    emb_path = cfg["embeddings_file"]

    if emb_path.exists():
        print(f"📂 Loading embeddings from {emb_path}")
        data = np.load(emb_path, allow_pickle=True)
        raw = data["embeddings"].astype(np.float32)
        _image_embeddings_raw = raw
        # Epsilon guards against division by zero on degenerate rows.
        norms = np.linalg.norm(raw, axis=-1, keepdims=True) + 1e-8
        _image_embeddings = raw / norms
        _image_names = data["image_names"].tolist()
        print(f"✅ Loaded {len(_image_embeddings)} embeddings")
        return _image_embeddings, _image_names

    images_dir = cfg["images_dir"]
    if not images_dir.exists():
        raise FileNotFoundError(
            f"Neither embeddings ({emb_path}) nor images ({images_dir}) found. "
            "Upload at least the .npz embeddings file."
        )

    print("🔄 Computing embeddings from scratch (this may take a while on CPU)…")
    model, preprocess = load_clip_model()
    image_files = sorted(images_dir.glob("*.jpg"))[:500]  # cap for CPU-only runs

    embeddings_list, valid_names = [], []
    batch_size = 16
    for i in range(0, len(image_files), batch_size):
        batch_files = image_files[i : i + batch_size]
        batch_tensors, batch_names = [], []
        for p in batch_files:
            try:
                img = preprocess(Image.open(p).convert("RGB"))
                batch_tensors.append(img)
                batch_names.append(p.name)
            except Exception as exc:
                # Best-effort: skip unreadable/corrupt images, keep going.
                print(f"⚠️ Skipping {p.name}: {exc}")
        if batch_tensors:
            image_input = torch.stack(batch_tensors).to(DEVICE)
            with torch.no_grad():
                embs = model.encode_image(image_input).float()
            # BUGFIX: keep RAW (un-normalised) outputs here. The original
            # normalised before saving and never set _image_embeddings_raw,
            # so get_raw_embeddings() returned None after a fresh compute and
            # the saved .npz disagreed with the load path above, which
            # expects raw values.
            embeddings_list.extend(embs.cpu().numpy())
            valid_names.extend(batch_names)

    raw = np.array(embeddings_list, dtype=np.float32)
    _image_embeddings_raw = raw
    norms = np.linalg.norm(raw, axis=-1, keepdims=True) + 1e-8
    _image_embeddings = raw / norms
    _image_names = valid_names
    # Persist raw embeddings so a later load reproduces identical caches.
    np.savez_compressed(emb_path, embeddings=raw, image_names=np.array(_image_names))
    print(f"✅ Computed & saved {len(_image_embeddings)} embeddings")
    return _image_embeddings, _image_names
|
|
|
|
| |
| |
| |
|
|
def get_available_datasets() -> Dict[str, Dict]:
    """Return info about each dataset (name, availability)."""
    return {
        key: {
            "name": cfg["name"],
            "description": cfg["description"],
            # Availability is checked on every call so freshly uploaded
            # files are picked up without a restart.
            "has_embeddings": cfg["embeddings_file"].exists(),
            "has_images": cfg["images_dir"].exists(),
        }
        for key, cfg in DATASETS.items()
    }
|
|
|
|
def get_current_dataset() -> str:
    """Return the key (into DATASETS) of the currently active dataset."""
    return _current_dataset
|
|
|
|
def switch_dataset(dataset_key: str):
    """Switch active dataset and reload embeddings."""
    global _current_dataset, _image_embeddings, _image_embeddings_raw, _image_names

    # Already active and cached — nothing to do.
    if dataset_key == _current_dataset and _image_embeddings is not None:
        return

    if dataset_key not in DATASETS:
        raise ValueError(f"Unknown dataset: {dataset_key}")

    # Drop every cache tied to the previous dataset, then eagerly reload so
    # callers see the new embeddings immediately.
    _current_dataset = dataset_key
    _image_embeddings = _image_embeddings_raw = _image_names = None
    load_or_compute_embeddings()
|
|
|
|
def get_raw_embeddings() -> np.ndarray:
    """Return raw (un-normalised) image embeddings for SAE encode."""
    # Populate the cache on first use.
    if _image_embeddings_raw is None:
        load_or_compute_embeddings()
    return _image_embeddings_raw
|
|
|
|
def get_image_path(image_name: str) -> Path:
    """Resolve an image name to its full path under the current dataset."""
    images_dir = DATASETS[_current_dataset]["images_dir"]
    return images_dir / image_name
|
|
|
|
| |
| |
| |
|
|
# Hand-curated positive/negative steering attributes, used when the Groq LLM
# is unavailable or errors out. Keys are matched case-insensitively against
# the user's query (see _fallback_feedback).
_FALLBACK_ATTRIBUTES: Dict[str, Dict] = {
    # --- Dog queries ---
    "a golden retriever": {
        "positive": ["golden retriever dog", "light golden fur", "medium to long wavy coat", "floppy ears", "broad head", "friendly face"],
        "negative": ["short fur dog", "dark colored fur", "pointed upright ears", "small dog breed", "cat"],
    },
    "dog on the beach": {
        "positive": ["dog on sand", "ocean waves", "beach scenery", "wet fur", "running on beach", "sunny outdoor"],
        "negative": ["indoor", "snow", "city street", "forest", "furniture"],
    },
    "dog looking guilty": {
        "positive": ["sad", "droopy", "ashamed", "looking down", "submissive", "avoiding eye contact", "head down"],
        "negative": ["happy", "playful", "energetic", "excited", "proud", "confident"],
    },
    "friendly looking dog": {
        "positive": ["happy", "playful", "gentle", "cute", "adorable", "wagging tail", "soft eyes"],
        "negative": ["aggressive", "scary", "angry", "mean", "threatening", "snarling"],
    },
    "aggressive looking dog": {
        "positive": ["snarling", "baring teeth", "aggressive stance", "tense body", "raised hackles", "intense stare"],
        "negative": ["gentle", "cute", "relaxed", "playful", "sleeping", "wagging tail"],
    },
    "nervous looking dog": {
        "positive": ["anxious", "scared", "worried", "trembling", "wide eyes", "ears back", "cowering", "tail tucked"],
        "negative": ["confident", "relaxed", "happy", "calm", "bold", "playful"],
    },
    "hyper active dog": {
        "positive": ["running", "jumping", "playing", "energetic", "fast movement", "mouth open", "excited"],
        "negative": ["sleeping", "lying down", "calm", "still", "resting", "lazy"],
    },
    # --- General scene / activity queries ---
    "a person riding a bicycle": {
        "positive": ["cyclist", "bicycle", "riding bike", "pedaling", "wheels", "helmet"],
        "negative": ["walking", "car", "sitting on bench", "standing still", "motorcycle"],
    },
    "a dog playing": {
        "positive": ["playful dog", "running", "fetching", "jumping", "toy in mouth", "energetic", "outdoor play"],
        "negative": ["sleeping", "sitting still", "calm", "resting", "sad"],
    },
    "an exciting action scene": {
        "positive": ["sports", "jumping", "running", "movement", "dynamic action", "fast motion", "athletic"],
        "negative": ["standing", "sitting", "calm", "still", "resting", "portrait"],
    },
    "a joyful moment": {
        "positive": ["smiling", "laughing", "celebrating", "happy faces", "hugging", "arms raised", "bright colors"],
        "negative": ["sad", "crying", "angry", "alone", "dark", "serious face"],
    },
    "a kid having fun": {
        "positive": ["child playing", "laughing kid", "outdoor play", "toys", "smiling child", "running", "playground"],
        "negative": ["adult", "serious", "crying", "sitting still", "elderly", "office"],
    },
    "peaceful scene": {
        "positive": ["calm water", "sunset", "nature", "quiet", "serene landscape", "soft light", "still"],
        "negative": ["crowded", "noisy", "urban", "traffic", "construction", "chaotic"],
    },
    "a photo with motion": {
        "positive": ["motion blur", "running", "moving fast", "dynamic", "action", "speed", "blurred background"],
        "negative": ["still", "static", "portrait", "posed", "standing", "sharp focus"],
    },
    # --- Face / portrait queries ---
    "wearing eyeglasses": {
        "positive": ["eyeglasses", "spectacles", "glasses frames", "lenses on face", "reading glasses"],
        "negative": ["no glasses", "bare face", "sunglasses", "contact lenses"],
    },
    "a person smiling": {
        "positive": ["smiling", "teeth showing", "happy expression", "grin", "cheerful face", "bright eyes"],
        "negative": ["frowning", "serious face", "sad", "neutral expression", "angry"],
    },
    "looking guilty": {
        "positive": ["worried", "nervous", "serious", "sad", "secretive", "uncomfortable", "looking away", "avoiding eye contact"],
        "negative": ["smiling", "happy", "confident", "laughing", "relaxed", "proud"],
    },
    "looking happy": {
        "positive": ["smiling", "laughing", "joyful", "bright eyes", "cheerful", "grinning", "radiant"],
        "negative": ["sad", "frowning", "crying", "angry", "tired", "bored"],
    },
    "looking sad": {
        "positive": ["frowning", "tearful", "downcast eyes", "drooping mouth", "somber", "melancholy", "looking down"],
        "negative": ["smiling", "happy", "laughing", "excited", "cheerful", "energetic"],
    },
    "looking suspicious": {
        "positive": ["narrowed eyes", "side glance", "raised eyebrow", "squinting", "tense jaw", "furrowed brow"],
        "negative": ["smiling", "relaxed", "open face", "friendly", "trusting", "wide eyes"],
    },
    "looking tired": {
        "positive": ["droopy eyes", "yawning", "dark circles", "half closed eyes", "slouching", "exhausted"],
        "negative": ["alert", "energetic", "wide awake", "bright eyes", "smiling", "active"],
    },
    "looking confident": {
        "positive": ["upright posture", "direct eye contact", "chin up", "strong stance", "composed", "assertive"],
        "negative": ["slouching", "looking down", "nervous", "fidgeting", "shy", "uncertain"],
    },
}
|
|
|
|
def _fallback_feedback(query: str) -> Dict:
    """Query-specific fallback when Groq is unavailable or rate-limited."""
    normalized = query.strip().lower()

    # Case-insensitive exact match against the curated attribute table.
    for candidate, attrs in _FALLBACK_ATTRIBUTES.items():
        if candidate.lower() != normalized:
            continue
        positives = [
            {"attribute": a, "weight": 0.8, "rationale": "fallback"}
            for a in attrs["positive"]
        ]
        negatives = [
            {"attribute": a, "weight": 0.6, "rationale": "fallback"}
            for a in attrs["negative"]
        ]
        return {"positive": positives, "negative": negatives, "alpha": 0.4, "beta": 0.4}

    # Unknown query: steer with the original query text alone.
    return {
        "positive": [{"attribute": query, "weight": 0.9, "rationale": "Original query"}],
        "negative": [],
        "alpha": 0.4,
        "beta": 0.4,
    }
|
|
|
|
def generate_feedback_with_weights(query: str) -> Dict:
    """Generate per-attribute weights from an LLM (or fallback).

    Asks a Groq-hosted Llama model to decompose *query* into weighted
    positive/negative visual attributes for CLIP steering. Falls back to
    _fallback_feedback when no client is configured or when any step of the
    LLM call (request, JSON parse, validation) fails.

    Returns:
        Dict with "positive"/"negative" lists (items have "attribute",
        "weight" clamped to [0, 1], "rationale") plus "alpha"/"beta"
        steering strengths clamped to [0.1, 0.8].
    """
    client = _get_groq_client()
    if client is None:
        return _fallback_feedback(query)

    # Few-shot prompt: rules, expert examples, then the user's query. The
    # model is asked to return bare JSON matching the template at the bottom.
    prompt = f"""You are an expert at decomposing image-search queries into visually grounded CLIP steering attributes.

CONTEXT
- We steer a CLIP ViT-B/16 query embedding by adding positive attribute embeddings and subtracting negative ones.
- Attributes must be OBSERVABLE VISUAL properties — things you can literally SEE in a photograph.
- Each attribute is 1-5 words. No abstract concepts. No full sentences.

CRITICAL RULES FOR GOOD ATTRIBUTES
- For EMOTIONS or SUBJECTIVE states: describe the BODY LANGUAGE, not the emotion word.
  BAD: "guilty expression" (CLIP doesn't understand abstract emotions well)
  GOOD: "looking down", "avoiding eye contact", "head lowered", "droopy ears"
- For PHYSICAL descriptions: use concrete visual details.
  BAD: "beautiful dog" (vague, subjective)
  GOOD: "light golden fur", "floppy ears", "broad head", "medium to long wavy coat"
- For SCENES: describe observable elements.
  BAD: "exciting scene" (abstract)
  GOOD: "jumping", "running", "movement", "sports"
- NEGATIVES must be the visual OPPOSITE or a common CLIP confusion.
  They should suppress what CLIP tends to retrieve incorrectly for this query.

FEW-SHOT EXAMPLES (from human experts):

Query: "a dog looking guilty"
positive: ["sad", "droopy", "ashamed", "looking down", "submissive", "avoiding eye contact", "head down"]
negative: ["happy", "playful", "energetic", "excited", "proud", "confident"]

Query: "a nervous-looking dog"
positive: ["anxious", "scared", "worried", "trembling", "wide eyes", "ears back"]
negative: ["confident", "relaxed", "happy", "calm", "bold"]

Query: "a golden retriever"
positive: ["golden retriever dog", "light golden fur", "medium to long wavy coat", "floppy ears", "broad head"]
negative: ["short fur dog", "dark colored fur", "pointed upright ears", "small dog breed"]

Query: "a person looking guilty"
positive: ["worried", "nervous", "serious", "sad", "secretive", "uncomfortable"]
negative: ["smiling", "happy", "confident", "laughing", "relaxed"]

Query: "an exciting action scene"
positive: ["sports", "jumping", "running", "movement"]
negative: ["standing", "sitting", "calm", "still"]

Query: "a friendly-looking dog"
positive: ["happy", "playful", "gentle", "cute", "adorable", "wagging tail"]
negative: ["aggressive", "scary", "angry", "mean", "threatening"]

NOW GENERATE FOR THIS QUERY:

USER QUERY: "{query}"

INSTRUCTIONS:
1. Generate 5-8 POSITIVE attributes (observable visual cues that define what the user wants).
   - The first 1-2 should be the most important (weight 0.8-1.0).
   - The rest are supporting visual details (weight 0.4-0.7).
   - For subjective/emotional queries: describe body language, posture, facial features.
2. Generate 5-8 NEGATIVE attributes (visual opposites or CLIP confusions to suppress).
   - Weight 0.5-0.8 for strong opposites, 0.3-0.5 for subtle.
3. Set alpha=0.4 and beta=0.4 (safe defaults). Only increase to 0.45 for very specific queries.

Return ONLY a JSON object (no markdown, no explanation):
{{
  "positive": [
    {{"attribute": "short visual phrase", "weight": 0.9, "rationale": "why"}},
    {{"attribute": "another phrase", "weight": 0.6, "rationale": "why"}}
  ],
  "negative": [
    {{"attribute": "short visual phrase", "weight": 0.7, "rationale": "why"}},
    {{"attribute": "another phrase", "weight": 0.5, "rationale": "why"}}
  ],
  "alpha": 0.4,
  "beta": 0.4
}}"""

    try:
        response = client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[
                {
                    "role": "system",
                    "content": (
                        "You are a vision-language retrieval expert specialising in CLIP embedding steering. "
                        "You decompose queries into OBSERVABLE VISUAL attributes — body language, textures, "
                        "colours, postures, spatial arrangements — never abstract or subjective words. "
                        "Follow the few-shot examples closely. Return ONLY valid JSON."
                    ),
                },
                {"role": "user", "content": prompt},
            ],
            temperature=0.3,  # low temperature: we want a stable, parseable JSON structure
            max_tokens=1000,
        )
        content = response.choices[0].message.content.strip()

        # Strip optional markdown code fences the model may wrap around the
        # JSON. A "```json" opener leaves a leading newline, so the second
        # startswith check cannot double-strip.
        if content.startswith("```json"):
            content = content[7:]
        if content.startswith("```"):
            content = content[3:]
        if content.endswith("```"):
            content = content[:-3]
        content = content.strip()

        feedback = json.loads(content)

        # Validate structure: force the attribute lists to exist, default
        # missing weights/rationales, and clamp weights into [0, 1].
        for attr_type in ("positive", "negative"):
            if attr_type not in feedback or not isinstance(feedback[attr_type], list):
                feedback[attr_type] = []
            for attr in feedback[attr_type]:
                attr.setdefault("weight", 0.5)
                attr["weight"] = max(0.0, min(1.0, float(attr["weight"])))
                attr.setdefault("rationale", "")

        # Clamp steering strengths to a safe range.
        feedback.setdefault("alpha", 0.4)
        feedback.setdefault("beta", 0.4)
        feedback["alpha"] = max(0.1, min(0.8, float(feedback["alpha"])))
        feedback["beta"] = max(0.1, min(0.8, float(feedback["beta"])))
        return feedback

    except Exception as exc:
        # Covers API failures, malformed JSON, and unexpected response shapes
        # (e.g. attribute items that are not dicts).
        print(f"⚠️ LLM error: {exc}")
        return _fallback_feedback(query)
|
|
|
|
| |
| |
| |
|
|
def baseline_retrieval(
    query_emb: np.ndarray,
    embeddings: np.ndarray,
    image_names: List[str],
    top_k: int = 5,
) -> List[Dict]:
    """Pure CLIP retrieval without steering."""
    # Dot product against every image embedding, highest score first.
    scores = embeddings @ query_emb
    ranked = np.argsort(scores)[::-1]
    return [
        {"image_name": image_names[idx], "similarity": float(scores[idx])}
        for idx in ranked[:top_k]
    ]
|
|
|
|
def linear_steering(
    query_emb: np.ndarray,
    positive_attrs: List[Dict],
    negative_attrs: List[Dict],
    alpha: float = 0.4,
    beta: float = 0.4,
) -> np.ndarray:
    """q' = q + α·Σ(w_i·p_i) - β·Σ(w_j·n_j)"""
    result = query_emb.copy()
    # One pass per polarity: add weighted positives, subtract weighted negatives.
    for sign, scale, attrs in ((1.0, alpha, positive_attrs), (-1.0, beta, negative_attrs)):
        for attr in attrs:
            direction = get_text_embedding(f"a photo of {attr['attribute']}")
            result += sign * scale * attr.get("weight", 1.0) * direction
    # Re-normalise so cosine retrieval stays well-defined.
    return result / np.linalg.norm(result)
|
|
|
|
def subspace_steering(
    query_emb: np.ndarray,
    positive_attrs: List[Dict],
    negative_attrs: List[Dict],
) -> np.ndarray:
    """Contrastive subspace (centroid-based) steering.

    Shifts the query along the unit direction pointing from the
    negative-attribute centroid toward the positive-attribute centroid.

    Args:
        query_emb: query embedding vector.
        positive_attrs / negative_attrs: attribute dicts with an "attribute" key.

    Returns:
        A re-normalised steered embedding, or *query_emb* unchanged when
        either attribute list is empty (a contrast needs both poles).
    """
    # Check emptiness BEFORE encoding: the original computed every text
    # embedding first and then threw all that work away when one list was
    # empty. Same result, no wasted CLIP forward passes.
    if not positive_attrs or not negative_attrs:
        return query_emb

    pos_embs = [get_text_embedding(f"a photo of {a['attribute']}") for a in positive_attrs]
    neg_embs = [get_text_embedding(f"a photo of {a['attribute']}") for a in negative_attrs]
    direction = np.mean(pos_embs, axis=0) - np.mean(neg_embs, axis=0)
    direction /= np.linalg.norm(direction)
    # Fixed step of 0.5 along the contrast direction, then re-normalise.
    steered = query_emb + 0.5 * direction
    return steered / np.linalg.norm(steered)
|
|
|
|
def energy_based_steering(
    query_emb: np.ndarray,
    positive_attrs: List[Dict],
    negative_attrs: List[Dict],
    n_steps: int = 30,
    lr: float = 0.05,
) -> np.ndarray:
    """Energy-based steering via gradient descent."""
    anchor = query_emb.copy()
    current = query_emb.copy()
    # Pre-encode all attribute texts once, paired with their weights.
    pos_terms = [
        (get_text_embedding(f"a photo of {a['attribute']}"), a.get("weight", 1.0))
        for a in positive_attrs
    ]
    neg_terms = [
        (get_text_embedding(f"a photo of {a['attribute']}"), a.get("weight", 1.0))
        for a in negative_attrs
    ]

    for _ in range(n_steps):
        # Quadratic tether keeps the result near the original query;
        # positives attract at full strength, negatives repel at half.
        grad = 0.1 * (current - anchor)
        for emb, w in pos_terms:
            grad -= w * (emb - current)
        for emb, w in neg_terms:
            grad += 0.5 * w * (emb - current)
        current -= lr * grad
    return current / np.linalg.norm(current)
|
|
|
|
def weighted_energy_steering(
    query_emb: np.ndarray,
    positive_attrs: List[Dict],
    negative_attrs: List[Dict],
    n_steps: int = 30,
    lr: float = 0.05,
) -> np.ndarray:
    """Weighted energy steering with normalised per-attribute weights."""
    anchor = query_emb.copy()
    current = query_emb.copy()

    def _rescaled(attrs: List[Dict]) -> List[float]:
        # Rescale so the weights sum to len(attrs), i.e. mean weight 1.0.
        weights = [a.get("weight", 1.0) for a in attrs]
        total = sum(weights)
        if total > 0:
            weights = [w / total * len(weights) for w in weights]
        return weights

    pos_w = _rescaled(positive_attrs)
    neg_w = _rescaled(negative_attrs)
    pos_embs = [get_text_embedding(f"a photo of {a['attribute']}") for a in positive_attrs]
    neg_embs = [get_text_embedding(f"a photo of {a['attribute']}") for a in negative_attrs]

    for _ in range(n_steps):
        # Tether to the original query, attract to positives, repel negatives.
        grad = 0.1 * (current - anchor)
        for emb, w in zip(pos_embs, pos_w):
            grad -= w * (emb - current)
        for emb, w in zip(neg_embs, neg_w):
            grad += w * (emb - current)
        current -= lr * grad
    return current / np.linalg.norm(current)
|
|