"""
Video Intelligence Platform — Akinator Tree Refinement.

Decision-tree-style interactive narrowing of search results: each level
splits the remaining candidates on the most discriminative visual attribute.
"""
import math
import re
from collections import Counter, defaultdict
from typing import Dict, List, Optional, Tuple

from .query_engine import QueryResult
from .index_store import VideoIndex
from .gemini_client import GeminiClient


class AkinatorNode:
    """A node in the refinement tree.

    Attributes:
        results: Candidate results still in play at this node. This is a
            private copy, so later changes to the caller's list do not leak
            into the refinement session.
        split_attribute: Attribute the question at this node asks about
            (None until a question has been generated).
        split_value: Answer that led to this node from its parent
            (None for the root).
        question: Question text shown to the user at this node, if any.
        children: Child nodes keyed by the answer that selects them.
    """

    def __init__(self, results: List["QueryResult"],
                 split_attribute: Optional[str] = None,
                 split_value: Optional[str] = None,
                 question: Optional[str] = None):
        # Copy: the original stored the caller's list by reference, so a
        # later mutation of that list silently changed this node's results.
        self.results = list(results)
        self.split_attribute = split_attribute
        self.split_value = split_value
        self.question = question
        self.children: Dict[str, "AkinatorNode"] = {}

    @property
    def is_leaf(self) -> bool:
        """True when no child branch has been attached to this node."""
        return not self.children

    @property
    def count(self) -> int:
        """Number of candidate results at this node."""
        return len(self.results)

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(count={self.count}, "
                f"split_attribute={self.split_attribute!r}, "
                f"split_value={self.split_value!r})")


class AkinatorRefiner:
    """
    Interactive tree-based refinement of search results.

    Like Akinator: asks discriminative questions to narrow down
    which video moments the user is looking for.

    Algorithm:
    1. Start with all candidate results
    2. Extract attributes from each candidate (detection labels + caption cues)
    3. Score each attribute by information gain: the expected number of bits
       of uncertainty about the target removed by asking about it
    4. Split on the attribute with the highest gain
    5. Ask the user which branch to follow
    6. Repeat until results are small enough, no attribute is informative,
       or the user is satisfied

    Caption cues are matched as whole words or phrases (an optional plural
    "s"/"es" suffix is allowed). A single cue table drives both attribute
    extraction and answer filtering, so every offered option matches the
    results it was derived from.
    """

    # Splits whose information gain (in bits) is below this are not asked.
    _MIN_GAIN = 0.01
    # Weight of the balance tie-breaker added to an informative split's gain.
    _BALANCE_WEIGHT = 0.1
    # Number of attribute values offered when Gemini gives no usable options.
    _MAX_OPTIONS = 5

    # Attributes considered, in the order used to break ties between splits.
    _ATTRIBUTE_NAMES = (
        "object_type",
        "location",
        "time_of_day",
        "dominant_color",
        "people_density",
        "action",
    )

    # Attributes for which a result takes at most one value (first match wins).
    _EXCLUSIVE_ATTRIBUTES = frozenset({"people_density"})

    # Caption-derived attributes: attribute -> ordered (value, cue words) pairs.
    _CAPTION_CUES = {
        "location": (
            ("indoor", ("indoor", "inside", "room", "bedroom", "bathroom")),
            ("outdoor", ("outdoor", "outside", "street")),
        ),
        "time_of_day": (
            ("night", ("night", "nighttime", "dark", "darkness", "evening")),
            ("day", ("day", "daytime", "daylight", "bright", "sunny",
                     "morning", "afternoon")),
        ),
        "dominant_color": (
            ("red", ("red",)),
            ("blue", ("blue",)),
            ("green", ("green",)),
            ("white", ("white",)),
            ("black", ("black",)),
            ("yellow", ("yellow",)),
            ("brown", ("brown",)),
            ("gray", ("gray", "grey")),
            ("orange", ("orange",)),
            ("pink", ("pink",)),
            ("purple", ("purple",)),
        ),
        "people_density": (
            ("many", ("crowd", "crowded", "group", "many people",
                      "several people")),
            ("few", ("person", "man", "woman", "individual")),
            ("none", ("empty", "no one")),
        ),
        "action": (
            ("walking", ("walking",)),
            ("running", ("running",)),
            ("sitting", ("sitting",)),
            ("standing", ("standing",)),
            ("driving", ("driving",)),
            ("talking", ("talking",)),
            ("eating", ("eating",)),
        ),
    }

    def __init__(self, index: "VideoIndex", gemini: "GeminiClient",
                 threshold: int = 10):
        self.index = index  # stored; not read by any method of this class
        self.gemini = gemini
        # Refinement stops once this many or fewer results remain.
        self.threshold = threshold
        self.current_node: Optional["AkinatorNode"] = None
        self.history: List[Dict] = []

    def start(self, results: List["QueryResult"], query: str) -> Dict:
        """
        Start the Akinator refinement process.

        Returns:
            {"status": "refining" | "done",
             "count": int,
             "question": str (if refining),
             "options": list (if refining),
             "results": list (if done)}
        """
        self.history = []
        self.current_node = AkinatorNode(results=results)

        if len(results) <= self.threshold:
            return {
                "status": "done",
                "count": len(results),
                "results": [r.to_dict() for r in results],
            }

        return self._generate_next_question(query)

    def answer(self, choice: str, query: str) -> Dict:
        """
        Process user's answer and narrow down results.

        Args:
            choice: User's selected option. It is matched case-insensitively,
                and surrounding whitespace is ignored.
            query: Original query for context.

        Returns:
            Same format as start(), plus "history" (a snapshot of the
            question/answer log). Returns {"status": "error", ...} when no
            session is active.
        """
        parent = self.current_node
        if parent is None:
            return {"status": "error", "message": "No active refinement session"}

        filtered = self._filter_by_choice(
            parent.results,
            parent.split_attribute,
            choice,
        )

        self.history.append({
            "question": parent.question,
            "answer": choice,
            "remaining": len(filtered),
        })

        child = AkinatorNode(results=filtered, split_value=choice)
        # Link the branch so the session actually records a decision tree.
        parent.children[choice] = child
        self.current_node = child

        if len(filtered) <= self.threshold:
            return self._finish(filtered)

        return self._generate_next_question(query)

    def _finish(self, results: List["QueryResult"],
                message: Optional[str] = None) -> Dict:
        """Build a "done" response carrying a snapshot of the history."""
        response = {
            "status": "done",
            "count": len(results),
            "results": [r.to_dict() for r in results],
        }
        if message is not None:
            response["message"] = message
        # Copy so responses already handed out do not change as the session
        # continues.
        response["history"] = list(self.history)
        return response

    def _generate_next_question(self, query: str) -> Dict:
        """Pick the most informative attribute and phrase a question about it.

        Returns a "done" response when no attribute splits the current results
        usefully. Otherwise it records the split on ``current_node`` and
        returns a "refining" response.
        """
        node = self.current_node
        results = node.results
        frame_ids = [r.frame_id for r in results]

        attributes = self._extract_attributes(results, frame_ids)
        if not attributes:
            return self._finish(results, "No more attributes to split on")

        best_attr, best_gain = self._find_best_split(results, attributes)
        if best_attr is None or best_gain < self._MIN_GAIN:
            return self._finish(
                results, "Attributes are too uniform to split further"
            )

        question, options = self._build_question(
            query, best_attr, attributes[best_attr]
        )

        node.split_attribute = best_attr
        node.question = question

        return {
            "status": "refining",
            "count": len(results),
            "attribute": best_attr,
            "question": question,
            "options": options,
            "history": list(self.history),
        }

    def _build_question(self, query: str, attribute: str,
                        values: List[str]) -> Tuple[str, List[str]]:
        """Phrase a question about *attribute* and choose its options.

        Gemini is asked first. Its options are kept only if they name real
        values of the attribute, because anything else could never be matched
        by _filter_by_choice. Falls back to a generic question and to the
        first few attribute values.
        """
        default_question = f"Which {attribute}?"
        try:
            data = self.gemini.generate_refinement_question(
                query, {attribute: values}
            )
        except Exception:
            # Best-effort: any exception from the client degrades to the
            # generic question instead of aborting the session.
            data = None
        if not isinstance(data, dict):
            data = {}

        question = data.get("question")
        if not isinstance(question, str) or not question.strip():
            question = default_question

        canonical = {value.lower(): value for value in values}
        options: List[str] = []
        proposed = data.get("options")
        if isinstance(proposed, (list, tuple)):
            for option in proposed:
                value = canonical.get(str(option).strip().lower())
                if value is not None and value not in options:
                    options.append(value)
        # Fewer than two matchable options cannot discriminate anything.
        if len(options) < 2:
            options = values[: self._MAX_OPTIONS]
        return question, options

    def _extract_attributes(self, results: List["QueryResult"],
                            frame_ids: List[int]) -> Dict[str, List[str]]:
        """
        Extract splittable attributes from results.

        Combines detection labels (``object_type``) with caption-derived
        attributes. It uses exactly the matching of _get_attribute_value, so
        every reported value can later be filtered on.

        Args:
            results: Candidate results.
            frame_ids: Unused; kept for interface compatibility.

        Returns:
            A mapping from attribute name to its sorted distinct values. Only
            attributes with at least two distinct values are included, since
            fewer cannot split anything.
        """
        seen: Dict[str, set] = {name: set() for name in self._ATTRIBUTE_NAMES}
        for result in results:
            for name, values in seen.items():
                values.update(self._get_attribute_value(result, name))
        return {
            name: sorted(values)
            for name, values in seen.items()
            if len(values) >= 2
        }

    def _find_best_split(self, results: List["QueryResult"],
                         attributes: Dict[str, List[str]]
                         ) -> Tuple[Optional[str], float]:
        """
        Find the attribute with the highest information gain.

        The score is the gain in bits (see _split_gain). Informative splits
        (gain >= _MIN_GAIN) also get a small bonus favouring balanced branches.
        The bonus is never added to uninformative splits, so it cannot push
        them past the "too uniform" threshold.

        Returns:
            (best attribute or None, its score). The result is (None, 0.0) for
            empty results and (None, -1.0) when *attributes* is empty.
        """
        total = len(results)
        if total == 0:
            return None, 0.0

        best_attr: Optional[str] = None
        best_score = -1.0

        for attr_name, attr_values in attributes.items():
            allowed = set(attr_values)
            value_counts: Counter = Counter()
            unlabeled = 0
            for result in results:
                # Set intersection: a result counts once per value, even with
                # repeated detections of the same label.
                matched = allowed.intersection(
                    self._get_attribute_value(result, attr_name)
                )
                if matched:
                    value_counts.update(matched)
                else:
                    unlabeled += 1

            score = self._split_gain(total, value_counts, unlabeled)
            if score >= self._MIN_GAIN and len(value_counts) >= 2:
                counts = value_counts.values()
                score += min(counts) / max(counts) * self._BALANCE_WEIGHT

            if score > best_score:
                best_attr, best_score = attr_name, score

        return best_attr, best_score

    @staticmethod
    def _split_gain(total: int, value_counts: Counter, unlabeled: int) -> float:
        """Expected bits of uncertainty removed by asking about an attribute.

        The model assumes the target is one of *total* equally likely
        candidates.

        - Answering value v leaves ``value_counts[v]`` candidates, with
          answers weighted by how many candidates carry them.
        - Candidates with no value cannot be reached through any option, so
          they are modelled as an answer that keeps all *total* candidates
          (no progress).

        When every candidate has exactly one value, the result equals the
        entropy of the split.
        """
        weight = sum(value_counts.values()) + unlabeled
        if total <= 1 or weight == 0:
            return 0.0
        remaining_bits = (
            sum(count * math.log2(count) for count in value_counts.values())
            + unlabeled * math.log2(total)
        ) / weight
        return max(0.0, math.log2(total) - remaining_bits)

    @staticmethod
    def _caption_has(caption: str, cues) -> bool:
        """True if any cue occurs in *caption* as a whole word/phrase.

        A plural "s"/"es" suffix is allowed, so "rooms" matches "room" while
        "covered" no longer matches "red".
        """
        return any(
            re.search(r"\b" + re.escape(cue) + r"(?:e?s)?\b", caption)
            for cue in cues
        )

    def _get_attribute_value(self, result: "QueryResult",
                             attr_name: str) -> List[str]:
        """Get the value(s) of an attribute for a result (lowercase labels).

        ``object_type`` comes from the detection labels. Every other attribute
        comes from caption cues in ``_CAPTION_CUES``; exclusive attributes
        return at most one value. Unknown attribute names yield ``[]``.
        """
        if attr_name == "object_type":
            return [label.lower() for label in (result.detections or [])]

        cue_table = self._CAPTION_CUES.get(attr_name)
        if cue_table is None:
            return []

        caption = result.caption.lower() if result.caption else ""
        labels = [
            label for label, cues in cue_table
            if self._caption_has(caption, cues)
        ]
        if attr_name in self._EXCLUSIVE_ATTRIBUTES:
            return labels[:1]
        return labels

    def _filter_by_choice(self, results: List["QueryResult"],
                          attribute: Optional[str],
                          choice: str) -> List["QueryResult"]:
        """Filter results that match the user's chosen attribute value.

        Matching is case-insensitive and ignores surrounding whitespace. If
        nothing matches, or there is no attribute or choice, all results are
        returned rather than an empty set.
        """
        wanted = str(choice).strip().lower() if choice is not None else ""
        if not attribute or not wanted:
            return results

        matches = [
            result for result in results
            if wanted in {value.lower()
                          for value in self._get_attribute_value(result, attribute)}
        ]
        return matches if matches else results
|
|