""" Video Intelligence Platform — Akinator Tree Refinement Decision-tree style interactive narrowing of search results. Each level splits on the most discriminative visual attribute. """ import math from typing import List, Dict, Optional, Tuple from collections import Counter, defaultdict from .query_engine import QueryResult from .index_store import VideoIndex from .gemini_client import GeminiClient class AkinatorNode: """A node in the refinement tree.""" def __init__(self, results: List[QueryResult], split_attribute: Optional[str] = None, split_value: Optional[str] = None, question: Optional[str] = None): self.results = results self.split_attribute = split_attribute self.split_value = split_value self.question = question self.children: Dict[str, "AkinatorNode"] = {} @property def is_leaf(self) -> bool: return len(self.children) == 0 @property def count(self) -> int: return len(self.results) class AkinatorRefiner: """ Interactive tree-based refinement of search results. Like Akinator: asks discriminative questions to narrow down which video moments the user is looking for. Algorithm: 1. Start with all candidate results 2. Extract attributes from each candidate (from detections + captions) 3. Compute information gain for each attribute 4. Split on the attribute with highest information gain 5. Ask the user which branch to follow 6. Repeat until results are small enough or user is satisfied """ def __init__(self, index: VideoIndex, gemini: GeminiClient, threshold: int = 10): self.index = index self.gemini = gemini self.threshold = threshold # Stop refining when results ≤ this self.current_node: Optional[AkinatorNode] = None self.history: List[Dict] = [] def start(self, results: List[QueryResult], query: str) -> Dict: """ Start the Akinator refinement process. Returns: {"status": "refining" | "done", "count": int, "question": str (if refining), "options": list (if refining), "results": list (if done)} """ self.history = [] self.current_node = AkinatorNode(results=results) if len(results) <= self.threshold: return { "status": "done", "count": len(results), "results": [r.to_dict() for r in results], } # Get attributes and find best split return self._generate_next_question(query) def answer(self, choice: str, query: str) -> Dict: """ Process user's answer and narrow down results. Args: choice: User's selected option query: Original query for context Returns: Same format as start() """ if self.current_node is None: return {"status": "error", "message": "No active refinement session"} # Filter results based on choice filtered = self._filter_by_choice( self.current_node.results, self.current_node.split_attribute, choice ) self.history.append({ "question": self.current_node.question, "answer": choice, "remaining": len(filtered), }) self.current_node = AkinatorNode( results=filtered, split_value=choice, ) if len(filtered) <= self.threshold: return { "status": "done", "count": len(filtered), "results": [r.to_dict() for r in filtered], "history": self.history, } return self._generate_next_question(query) def _generate_next_question(self, query: str) -> Dict: """Generate the next discriminative question.""" results = self.current_node.results frame_ids = [r.frame_id for r in results] # Get available attributes attributes = self._extract_attributes(results, frame_ids) if not attributes: return { "status": "done", "count": len(results), "results": [r.to_dict() for r in results], "message": "No more attributes to split on", "history": self.history, } # Find best split by information gain best_attr, best_gain = self._find_best_split(results, attributes) if best_attr is None or best_gain < 0.01: return { "status": "done", "count": len(results), "results": [r.to_dict() for r in results], "message": "Attributes are too uniform to split further", "history": self.history, } # Generate natural language question via Gemini try: question_data = self.gemini.generate_refinement_question( query, {best_attr: attributes[best_attr]} ) except Exception: question_data = { "attribute": best_attr, "question": f"Which {best_attr}?", "options": attributes[best_attr][:5], } self.current_node.split_attribute = best_attr self.current_node.question = question_data.get("question", f"Which {best_attr}?") return { "status": "refining", "count": len(results), "attribute": best_attr, "question": question_data.get("question", f"Which {best_attr}?"), "options": question_data.get("options", attributes[best_attr][:5]), "history": self.history, } def _extract_attributes(self, results: List[QueryResult], frame_ids: List[int]) -> Dict[str, List[str]]: """ Extract splittable attributes from results. Combines detection labels + caption-derived attributes. """ attributes = defaultdict(set) for result in results: # From detections for det in result.detections: attributes["object_type"].add(det.lower()) # From caption analysis caption = result.caption.lower() if result.caption else "" # Location if "indoor" in caption or "inside" in caption or "room" in caption: attributes["location"].add("indoor") if "outdoor" in caption or "outside" in caption or "street" in caption: attributes["location"].add("outdoor") # Time of day if any(w in caption for w in ["night", "dark", "evening"]): attributes["time_of_day"].add("night") if any(w in caption for w in ["day", "bright", "sunny", "morning", "afternoon"]): attributes["time_of_day"].add("day") # Colors for color in ["red", "blue", "green", "white", "black", "yellow", "brown", "gray", "orange", "pink", "purple"]: if color in caption: attributes["dominant_color"].add(color) # People count if any(w in caption for w in ["crowd", "group", "many people", "several people"]): attributes["people_density"].add("many") elif any(w in caption for w in ["person", "man", "woman", "individual"]): attributes["people_density"].add("few") elif "empty" in caption or "no one" in caption: attributes["people_density"].add("none") # Action for action in ["walking", "running", "sitting", "standing", "driving", "talking", "eating"]: if action in caption: attributes["action"].add(action) # Only keep attributes with 2+ unique values (otherwise they can't split) return { k: sorted(list(v)) for k, v in attributes.items() if len(v) >= 2 } def _find_best_split(self, results: List[QueryResult], attributes: Dict[str, List[str]]) -> Tuple[Optional[str], float]: """ Find the attribute with highest information gain (like a decision tree). """ best_attr = None best_gain = -1.0 total = len(results) if total == 0: return None, 0.0 # Current entropy current_entropy = math.log2(total) if total > 1 else 0 for attr_name, attr_values in attributes.items(): # Count how many results match each value value_counts = Counter() for result in results: matched_values = self._get_attribute_value(result, attr_name) for v in matched_values: if v in attr_values: value_counts[v] += 1 # Calculate weighted entropy after split weighted_entropy = 0 for value, count in value_counts.items(): if count > 0: p = count / total entropy = -p * math.log2(p) if p > 0 and p < 1 else 0 weighted_entropy += (count / total) * entropy gain = current_entropy - weighted_entropy # Prefer attributes that create more balanced splits balance_bonus = 0 if len(value_counts) >= 2: counts = list(value_counts.values()) min_c, max_c = min(counts), max(counts) if max_c > 0: balance_bonus = min_c / max_c * 0.1 adjusted_gain = gain + balance_bonus if adjusted_gain > best_gain: best_gain = adjusted_gain best_attr = attr_name return best_attr, best_gain def _get_attribute_value(self, result: QueryResult, attr_name: str) -> List[str]: """Get the value(s) of an attribute for a result.""" caption = result.caption.lower() if result.caption else "" if attr_name == "object_type": return [d.lower() for d in result.detections] elif attr_name == "location": values = [] if any(w in caption for w in ["indoor", "inside", "room"]): values.append("indoor") if any(w in caption for w in ["outdoor", "outside", "street"]): values.append("outdoor") return values elif attr_name == "time_of_day": values = [] if any(w in caption for w in ["night", "dark", "evening"]): values.append("night") if any(w in caption for w in ["day", "bright", "sunny"]): values.append("day") return values elif attr_name == "dominant_color": return [c for c in ["red", "blue", "green", "white", "black", "yellow", "brown", "gray", "orange", "pink", "purple"] if c in caption] elif attr_name == "people_density": if any(w in caption for w in ["crowd", "group", "many"]): return ["many"] elif any(w in caption for w in ["person", "man", "woman"]): return ["few"] return ["none"] elif attr_name == "action": return [a for a in ["walking", "running", "sitting", "standing", "driving", "talking", "eating"] if a in caption] return [] def _filter_by_choice(self, results: List[QueryResult], attribute: str, choice: str) -> List[QueryResult]: """Filter results that match the user's chosen attribute value.""" filtered = [] for r in results: values = self._get_attribute_value(r, attribute) if choice.lower() in [v.lower() for v in values]: filtered.append(r) # If filtering removed everything (edge case), return all return filtered if filtered else results