| from __future__ import annotations |
|
|
| import traceback |
|
|
| import pandas as pd |
| import streamlit as st |
| from sklearn.metrics.pairwise import cosine_similarity |
|
|
| from src.context_analysis import analyze_object_contexts |
| from src.ijepa_localization import IJepaPatchLocalizer, iou |
| from src.obstacle_dataset import ( |
| DEFAULT_OBSTACLE_DATASET, |
| load_balanced_obstacle_rows, |
| load_obstacle_image, |
| load_obstacle_rows, |
| parse_yolo_boxes, |
| ) |
| from src.prototypes import build_class_prototypes, guess_objects_with_prototypes |
| from src.small_head import guess_objects_with_head, train_small_head |
| from src.visualization import draw_yolo_with_heatmap |
|
|
|
|
# Dataset split used by the app; the sidebar assigns it to `split` without exposing a widget.
DEFAULT_SPLIT = "train"
# Initial value of the sidebar "I-JEPA model" text input.
DEFAULT_MODEL = "facebook/ijepa_vith14_1k"
|
|
|
|
@st.cache_resource(show_spinner=False)
def get_localizer(model_name: str) -> IJepaPatchLocalizer:
    """Construct the I-JEPA patch localizer for ``model_name``.

    Decorated with ``st.cache_resource`` (spinner disabled) so the localizer
    object is shared across Streamlit reruns instead of being rebuilt.
    """
    localizer = IJepaPatchLocalizer(model_name=model_name)
    return localizer
|
|
|
|
@st.cache_data(show_spinner=False)
def get_rows(
    dataset_name: str,
    split: str,
    max_samples: int,
    min_objects: int,
    sample_mode: str,
    random_seed: int,
):
    """Load the dataset rows to probe, choosing the loader from ``sample_mode``.

    ``"balanced single/multiple"`` delegates to ``load_balanced_obstacle_rows``
    (which does not receive ``min_objects``); any other mode delegates to
    ``load_obstacle_rows`` with the ``min_objects`` filter. The result is
    cached by ``st.cache_data`` with the spinner disabled.
    """
    wants_balanced = sample_mode == "balanced single/multiple"
    if not wants_balanced:
        return load_obstacle_rows(
            dataset_name,
            split,
            max_samples,
            min_objects=min_objects,
            random_seed=random_seed,
        )
    return load_balanced_obstacle_rows(
        dataset_name,
        split,
        max_samples,
        random_seed=random_seed,
    )
|
|
|
|
def _mean_agreement(guesses):
    """Return the mean ``agreement`` of ``guesses`` rounded to 4 places, or None if empty."""
    if not guesses:
        return None
    return round(sum(guess.agreement for guess in guesses) / len(guesses), 4)


def _assign_nearest_neighbor_guesses(results, image_embeddings, yolo_labels):
    """Fill each result's object guess from its most similar other sample, in place."""
    if not results:
        # No rows were loaded, so there is nothing to annotate (previously an IndexError).
        return
    if len(image_embeddings) < 2:
        only_result = results[0]
        only_result["ijepa_object_guess"] = "needs at least 2 samples"
        only_result["nearest_labeled_sample"] = None
        only_result["object_guess_similarity"] = None
        return
    similarities = cosine_similarity(image_embeddings)
    for index, result in enumerate(results):
        # Mask the diagonal so a sample can never be picked as its own neighbour.
        similarities[index, index] = -1
        neighbor = int(similarities[index].argmax())
        result["ijepa_object_guess"] = yolo_labels[neighbor]
        result["nearest_labeled_sample"] = neighbor
        result["object_guess_similarity"] = round(float(similarities[index, neighbor]), 4)


def run_localization_probe(
    dataset_name: str,
    split: str,
    model_name: str,
    max_samples: int,
    min_objects: int,
    sample_mode: str,
    saliency_threshold: float,
    analyze_context: bool,
    use_prototypes: bool,
    prototype_samples: int,
    train_head: bool,
    head_train_samples: int,
    random_seed: int,
):
    """Run the I-JEPA probe over sampled rows and compare it with the YOLO labels.

    Args:
        dataset_name: Dataset identifier forwarded to the row/image loaders.
        split: Dataset split forwarded to the loaders.
        model_name: Model identifier passed to ``get_localizer``.
        max_samples: Maximum number of rows requested from ``get_rows``.
        min_objects: Minimum-object filter forwarded to ``get_rows``.
        sample_mode: Sampling mode forwarded to ``get_rows``.
        saliency_threshold: Threshold forwarded to ``draw_yolo_with_heatmap``.
        analyze_context: When true, run ``analyze_object_contexts`` per image.
        use_prototypes: When true, build class prototypes and match each object.
        prototype_samples: Number of rows requested for building prototypes.
        train_head: When true, train the small head and classify each object.
        head_train_samples: Number of rows requested for training the head.
        random_seed: Base seed; prototype and head rows use offsets of
            10_000 and 20_000 so they are drawn with different seeds.

    Returns:
        A ``(results, overlays)`` tuple: a ``pandas.DataFrame`` with one row
        per sample and a list with one overlay dict per sample. Both are
        empty when no rows were loaded.
    """
    rows = get_rows(dataset_name, split, max_samples, min_objects, sample_mode, random_seed)
    localizer = get_localizer(model_name)

    prototypes = {}
    if use_prototypes:
        prototype_rows = load_balanced_obstacle_rows(
            dataset_name,
            split,
            prototype_samples,
            random_seed=random_seed + 10_000,
        )
        prototypes = build_class_prototypes(dataset_name, split, prototype_rows, localizer)

    trained_head = None
    if train_head:
        head_rows = load_balanced_obstacle_rows(
            dataset_name,
            split,
            head_train_samples,
            random_seed=random_seed + 20_000,
        )
        trained_head = train_small_head(dataset_name, split, head_rows, localizer)

    # Head metadata is identical for every sample, so read it once up front.
    head_train_accuracy = trained_head.train_accuracy if trained_head else None
    head_train_objects = trained_head.train_objects if trained_head else None
    head_parameter_count = trained_head.parameter_count if trained_head else None
    prototype_class_count = len(prototypes)

    results = []
    overlays = []
    image_embeddings = []
    yolo_labels = []

    total_rows = len(rows)
    progress = st.progress(0)
    for index, row in enumerate(rows):
        image = load_obstacle_image(dataset_name, row, split)
        yolo_boxes = parse_yolo_boxes(row)
        localization = localizer.localize(image)
        yolo_xyxy = [box.to_xyxy(*image.size) for box in yolo_boxes]
        context_results = analyze_object_contexts(image, yolo_boxes, localizer) if analyze_context else []
        prototype_guesses = (
            guess_objects_with_prototypes(image, yolo_boxes, localizer, prototypes)
            if use_prototypes
            else []
        )
        head_guesses = (
            guess_objects_with_head(image, yolo_boxes, localizer, trained_head)
            if train_head
            else []
        )
        # Fall back to the single best box when no candidate list is available.
        candidate_boxes = getattr(localization, "candidate_boxes_xyxy", None) or [localization.box_xyxy]
        best_iou = max(
            (iou(candidate, box) for candidate in candidate_boxes for box in yolo_xyxy),
            default=0.0,
        )
        structure_guess = describe_object_structure(len(candidate_boxes))
        yolo_structure = describe_yolo_structure(len(yolo_boxes))
        structure_agrees = structure_matches_yolo(len(candidate_boxes), len(yolo_boxes))
        class_names = sorted({box.class_name for box in yolo_boxes})
        yolo_label = ", ".join(class_names) if class_names else "none"
        context_patterns = (
            ", ".join(sorted({result.context_pattern for result in context_results}))
            if context_results
            else "not analyzed"
        )
        image_embeddings.append(localization.image_embedding)
        yolo_labels.append(yolo_label)

        results.append(
            {
                "sample": index,
                "file_name": row["file_name"],
                "yolo_objects": yolo_label,
                # Placeholder; replaced after the loop by the nearest-neighbour pass.
                "ijepa_object_guess": "pending",
                "num_yolo_objects": len(yolo_boxes),
                "context_patterns": context_patterns,
                "prototype_agreement": _mean_agreement(prototype_guesses),
                "small_head_agreement": _mean_agreement(head_guesses),
                "ijepa_salient_regions": len(candidate_boxes),
                "ijepa_structure_guess": structure_guess,
                "yolo_structure": yolo_structure,
                "structure_agrees": structure_agrees,
                "best_iou_vs_yolo": round(best_iou, 4),
                "ijepa_score": round(localization.score, 4),
            }
        )
        overlays.append(
            {
                "sample": index,
                "image": draw_yolo_with_heatmap(
                    image,
                    yolo_boxes,
                    localization.heatmap,
                    saliency_threshold=saliency_threshold,
                ),
                "heatmap": localization.heatmap,
                "objects": yolo_label,
                "best_iou": best_iou,
                "candidate_boxes": candidate_boxes,
                "num_yolo_boxes": len(yolo_boxes),
                "structure_guess": structure_guess,
                "yolo_structure": yolo_structure,
                "structure_agrees": structure_agrees,
                "context_results": context_results,
                "prototype_guesses": prototype_guesses,
                "head_guesses": head_guesses,
                "head_train_accuracy": head_train_accuracy,
                "head_train_objects": head_train_objects,
                "head_parameter_count": head_parameter_count,
                "prototype_classes": prototype_class_count,
                "representation_report": build_representation_report(
                    yolo_label,
                    structure_guess,
                    context_results,
                    head_guesses,
                    prototype_guesses,
                    head_parameter_count,
                ),
            }
        )
        progress.progress((index + 1) / total_rows)

    _assign_nearest_neighbor_guesses(results, image_embeddings, yolo_labels)

    return pd.DataFrame(results), overlays
|
|
|
|
def describe_object_structure(salient_regions: int) -> str:
    """Map a count of I-JEPA salient regions to a scene-structure label.

    Counts of 1 or fewer give ``"single focus"``, 2-3 give ``"multi-region"``,
    and anything larger gives ``"distributed/group pattern"``.
    """
    if salient_regions > 3:
        return "distributed/group pattern"
    return "multi-region" if salient_regions > 1 else "single focus"
|
|
|
|
def describe_yolo_structure(yolo_objects: int) -> str:
    """Map a count of YOLO boxes to a scene-structure label.

    Counts of 1 or fewer give ``"single object"``, 2-3 give
    ``"multiple objects"``, and anything larger gives ``"group of objects"``.
    """
    if yolo_objects > 3:
        return "group of objects"
    return "multiple objects" if yolo_objects > 1 else "single object"
|
|
|
|
def structure_matches_yolo(salient_regions: int, yolo_objects: int) -> bool:
    """Report whether both counts fall on the same side of the single/multi split.

    Each count is "multi" when it is greater than 1; the result is true when
    both are multi or both are not.
    """
    return (salient_regions > 1) == (yolo_objects > 1)
|
|
|
|
| def build_representation_report( |
| yolo_objects: str, |
| scene_structure: str, |
| context_results, |
| head_guesses, |
| prototype_guesses, |
| head_parameter_count: int | None, |
| ) -> str: |
| head_summary = "not trained" |
| if head_guesses: |
| labels = sorted({guess.head_guess for guess in head_guesses}) |
| confidence = sum(guess.confidence for guess in head_guesses) / len(head_guesses) |
| head_summary = f"{', '.join(labels)} (avg confidence {confidence:.2f})" |
|
|
| prototype_summary = "not run" |
| if prototype_guesses: |
| labels = sorted({guess.ijepa_guess for guess in prototype_guesses}) |
| prototype_summary = ", ".join(labels) |
|
|
| context_summary = "not analyzed" |
| if context_results: |
| patterns = sorted({result.context_pattern for result in context_results}) |
| strengths = sorted({result.context_strength for result in context_results}) |
| context_summary = f"{', '.join(patterns)} | strength: {', '.join(strengths)}" |
|
|
| parameter_text = f"{head_parameter_count:,} trainable params" if head_parameter_count else "no trainable head" |
| return ( |
| f"YOLO reference: {yolo_objects}. " |
| f"Tiny classifier: {head_summary}. " |
| f"Prototype match: {prototype_summary}. " |
| f"Scene structure: {scene_structure}. " |
| f"Context: {context_summary}. " |
| f"Head size: {parameter_text}." |
| ) |
|
|
|
|
# Page-level configuration: browser tab title and the wide layout.
st.set_page_config(page_title="JEPA-demo", layout="wide")


# Page title and a one-line description of what the demo compares.
st.title("JEPA-demo")
st.caption(
    "I-JEPA patch-representation probe for obstacle localization, compared against YOLO-format "
    "dataset labels as the benchmark."
)


# Framing paragraph explaining the roles of YOLO, frozen I-JEPA and the small head.
st.markdown(
    """
**Representation-first operational vision:** YOLO provides object-level grounding, while
frozen I-JEPA provides scene/context representations. A tiny logistic-regression head can
test whether those representations are already enough to say what an object likely is.
"""
)


# Glossary of the terms used in the tables and reports; shown expanded by default.
with st.expander("What is what?", expanded=True):
    st.markdown(
        """
- **YOLO**: benchmark labels and boxes from the dataset.
- **I-JEPA saliency**: orange overlay showing where frozen I-JEPA has strong patch-level representation activity.
- **Prototype match**: compares an object crop with average I-JEPA embeddings per YOLO class.
- **Tiny classifier**: logistic regression trained on frozen I-JEPA crop embeddings.
- **Isolated object**: one clear focus point, relatively separate from the scene.
- **Group-like scene**: several similar objects or people forming a cluster.
- **Context-heavy surroundings**: the area around the object contributes strongly to the scene meaning.
- **Multi-region visual structure**: several separate visually important regions, not necessarily one group.
"""
    )
|
|
# Sidebar: all experiment controls. The values bound here are passed to
# run_localization_probe when the "Run" button is pressed.
with st.sidebar:
    st.header("Experiment")
    dataset_name = st.text_input("Dataset", value=DEFAULT_OBSTACLE_DATASET)
    # The split is fixed to DEFAULT_SPLIT; no widget exposes it.
    split = DEFAULT_SPLIT
    model_name = st.text_input("I-JEPA model", value=DEFAULT_MODEL)
    max_samples = st.slider("Max samples", min_value=1, max_value=25, value=3, step=1)
    # Converted with int() at the call site before being used as the sampling seed.
    random_seed = st.number_input("Sample seed", min_value=0, value=7, step=1)
    sample_mode = st.selectbox(
        "Sample mode",
        ["balanced single/multiple", "minimum object filter"],
    )
    # NOTE(review): hard-coded, so the "minimum object filter" mode always filters
    # with a minimum of 1 -- confirm that no widget is intended here.
    min_objects = 1
    saliency_threshold = st.slider(
        "Saliency overlay threshold",
        min_value=0.4,
        max_value=0.95,
        value=0.7,
        step=0.05,
    )
    analyze_context = st.checkbox("Analyze object context", value=True)
    use_prototypes = st.checkbox("Match with class prototypes", value=True)
    prototype_samples = st.slider("Prototype reference images", min_value=4, max_value=80, value=12, step=4)
    # Training the head is opt-in (unchecked by default).
    train_head = st.checkbox("Train lightweight classifier", value=False)
    head_train_samples = st.slider(
        "Classifier training images",
        min_value=8,
        max_value=160,
        value=40,
        step=8,
    )
    # True only on the rerun triggered by clicking the button.
    run = st.button("Run", type="primary")
|
|
# Legend for the overlay colours and the result columns; collapsed by default.
with st.expander("How to read the signals", expanded=False):
    st.markdown(
        """
- **Green boxes**: YOLO-format benchmark labels from the dataset.
- **Orange overlay**: I-JEPA patch saliency, showing where representation activity is strongest.
- **Object guess**: nearest labeled example in I-JEPA embedding space.
- **Context pattern**: rough signal for isolated objects, nearby objects, or group/crowd context.
- **Context strength**: how strongly object/context/scene embeddings relate.
- **Scene structure**: single focus, multi-region, or distributed/group pattern.
"""
    )
|
|
# Run the probe when the "Run" button was clicked, and keep its output in the
# session state so later reruns can still render the last successful result.
if run:
    try:
        with st.status("Running I-JEPA localization probe...", expanded=True):
            probe_arguments = {
                "dataset_name": dataset_name,
                "split": split,
                "model_name": model_name,
                "max_samples": max_samples,
                "min_objects": min_objects,
                "sample_mode": sample_mode,
                "saliency_threshold": saliency_threshold,
                "analyze_context": analyze_context,
                "use_prototypes": use_prototypes,
                "prototype_samples": prototype_samples,
                "train_head": train_head,
                "head_train_samples": head_train_samples,
                "random_seed": int(random_seed),
            }
            results, overlays = run_localization_probe(**probe_arguments)
        st.session_state["results"] = results
        st.session_state["overlays"] = overlays
        st.success("Probe completed.")
    except Exception as exc:
        # Top-level UI boundary: show the failure in the page instead of crashing the app.
        failure_summary = f"{type(exc).__name__}: {exc}"
        st.error(failure_summary)
        st.code(traceback.format_exc(limit=8))
|
|
# Read the last stored probe output; both keys are absent until a run succeeds.
results = st.session_state.get("results")
overlays = st.session_state.get("overlays", [])


if results is not None and not results.empty:
    # Headline metrics. Prototype/head figures are read from the first overlay
    # because the probe stores the same run-level values on every overlay.
    col1, col2, col3, col4 = st.columns(4)
    col1.metric("Samples", len(results))
    col2.metric("YOLO objects", int(results["num_yolo_objects"].sum()))
    prototype_classes = overlays[0].get("prototype_classes") if overlays else None
    # A falsy value (None or 0) is displayed as "off".
    col3.metric("Prototype classes", prototype_classes if prototype_classes else "off")
    head_objects = overlays[0].get("head_train_objects") if overlays else None
    col4.metric("Classifier train crops", head_objects if head_objects else "off")
    if overlays and overlays[0].get("head_parameter_count"):
        st.caption(
            f"Tiny classifier size: {overlays[0]['head_parameter_count']:,} trainable parameters "
            "on top of frozen I-JEPA."
        )


    # Full per-sample results table.
    st.subheader("Run details")
    st.dataframe(results, width="stretch")


    # One section per sample: report text, annotated image and per-object tables.
    st.subheader("Representation reports")
    for item in overlays:
        st.markdown(f"**Sample {item['sample']} - objects: {item['objects']}**")
        st.info(item["representation_report"])
        # Image gets two thirds of the width, details one third.
        image_col, detail_col = st.columns([2, 1])
        with image_col:
            st.image(
                item["image"],
                caption=f"Green: YOLO benchmark labels. Orange: I-JEPA saliency. Best IoU proxy: {item['best_iou']:.3f}",
                width="stretch",
            )
        with detail_col:
            st.metric("YOLO boxes", item["num_yolo_boxes"])
            st.metric("I-JEPA salient regions", len(item["candidate_boxes"]))
            st.metric("Scene structure match", "yes" if item["structure_agrees"] else "no")
            st.write("YOLO structure")
            st.info(item["yolo_structure"])
            st.write("I-JEPA structure guess")
            st.info(item["structure_guess"])
            # Each optional table is shown only when its analysis produced rows.
            if item["context_results"]:
                st.write("Object context")
                st.dataframe(
                    pd.DataFrame(
                        [
                            {
                                "object": result.object_index,
                                "class": result.class_name,
                                "pattern": result.context_pattern,
                                "strength": result.context_strength,
                                # Sample-level value, repeated on every object row.
                                "scene_structure": item["structure_guess"],
                                "object_context_similarity": round(
                                    result.object_context_similarity, 3
                                ),
                                "scene_context_similarity": round(result.scene_context_similarity, 3),
                            }
                            for result in item["context_results"]
                        ]
                    ),
                    width="stretch",
                    hide_index=True,
                )
            if item["prototype_guesses"]:
                st.write("YOLO vs I-JEPA prototype")
                st.dataframe(
                    pd.DataFrame(
                        [
                            {
                                "object": guess.object_index,
                                "yolo_label": guess.yolo_label,
                                "ijepa_guess": guess.ijepa_guess,
                                "agreement": "yes" if guess.agreement else "no",
                                "similarity": round(guess.similarity, 3),
                            }
                            for guess in item["prototype_guesses"]
                        ]
                    ),
                    width="stretch",
                    hide_index=True,
                )
            if item["head_guesses"]:
                st.write("YOLO vs small head")
                if item["head_train_accuracy"] is not None:
                    st.caption(
                        f"Head train accuracy: {item['head_train_accuracy']:.0%} "
                        f"on {item['head_train_objects']} object crops. "
                        f"Trainable parameters: {item['head_parameter_count']:,}"
                    )
                st.dataframe(
                    pd.DataFrame(
                        [
                            {
                                "object": guess.object_index,
                                "yolo_label": guess.yolo_label,
                                "head_guess": guess.head_guess,
                                "agreement": "yes" if guess.agreement else "no",
                                "confidence": round(guess.confidence, 3),
                            }
                            for guess in item["head_guesses"]
                        ]
                    ),
                    width="stretch",
                    hide_index=True,
                )
else:
    # Nothing stored yet (or the stored table is empty): prompt the user to run the probe.
    st.info("Run a small sample first. The first I-JEPA model download can take a while.")
|
|