| """ |
| QueryAgent - Intelligent memory retrieval with NLP |
| UPDATED MODULE - Added DistilBERT intent classification |
| """ |
|
|
import re
from datetime import datetime, timedelta

import numpy as np
|
|
|
|
class QueryAgent:
    """Answer natural-language questions against the memories of a memory agent.

    Retrieval combines an optional time window parsed from the question,
    cosine similarity between text embeddings of the question and of each
    memory's description, and (optionally) image-search hits supplied by the
    memory agent. Query intent is classified with a zero-shot DistilBERT
    pipeline when available, otherwise with keyword matching.

    Memories are dicts with at least ``"id"``, ``"timestamp"``
    (``"%Y-%m-%d %H:%M:%S"``) and ``"description"`` keys; ``"text_embedding"``
    and ``"importance"`` are filled in lazily by :meth:`ask`.
    """

    # Zero-shot candidate labels mapped to the short intent names returned by
    # classify_intent. Dict insertion order is the label order passed to the model.
    _INTENT_LABELS = {
        "temporal query about time": "temporal",
        "object detection query": "object",
        "action or activity query": "action",
        "text reading query": "text",
        "general scene description": "general",
    }

    # Keyword fallback, checked in order (first match wins). The leading \b
    # prevents matches inside other words ("read" in "already", "thing" in
    # "anything", "time" in "sometimes") while still allowing suffixes
    # ("objects", "seen", "reading").
    _FALLBACK_PATTERNS = (
        ("temporal", re.compile(r"\b(?:when|time|yesterday|today|morning|evening)")),
        ("text", re.compile(r"\b(?:read|text|written|says)")),
        ("object", re.compile(r"\b(?:person|object|thing|see)")),
        ("action", re.compile(r"\b(?:doing|action|activity)")),
    )

    # Format of memory["timestamp"] strings.
    _TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self, memory_agent):
        """Create a query agent bound to *memory_agent*.

        Args:
            memory_agent: Memory store exposing ``recall_all()``,
                ``search_by_image(embedding, k=...)`` and a ``text_model``
                whose ``encode()`` is reused here for query embeddings.
        """
        self.memory_agent = memory_agent
        self.model = memory_agent.text_model

        # Stays None if the pipeline cannot be loaded; classify_intent then
        # falls back to keyword matching.
        self.intent_classifier = None
        self._init_intent_classifier()

    def _init_intent_classifier(self):
        """Try to load a zero-shot DistilBERT pipeline into ``intent_classifier``.

        Best effort: any failure (transformers missing, model download error,
        ...) is reported and leaves ``intent_classifier`` as None.
        """
        try:
            from transformers import pipeline
            self.intent_classifier = pipeline(
                "zero-shot-classification",
                model="typeform/distilbert-base-uncased-mnli"
            )
            print("[QueryAgent] DistilBERT intent classifier loaded")
        except Exception as e:
            print(f"[QueryAgent] DistilBERT not available: {e}")
            print("[QueryAgent] Using fallback keyword matching")

    def classify_intent(self, question):
        """Classify the intent of *question*.

        Uses the zero-shot classifier when loaded; on absence or on any
        classifier error, falls back to :meth:`_fallback_intent`.

        Returns:
            str: one of ``"temporal"``, ``"object"``, ``"action"``,
            ``"text"`` or ``"general"``.
        """
        if self.intent_classifier is None:
            return self._fallback_intent(question)

        try:
            result = self.intent_classifier(question, list(self._INTENT_LABELS))
            # The pipeline returns labels ordered by descending score.
            top_label = result["labels"][0]
        except Exception as e:
            print(f"[QueryAgent] Intent classification error: {e}")
            return self._fallback_intent(question)

        return self._INTENT_LABELS.get(top_label, "general")

    def _fallback_intent(self, question):
        """Keyword-based intent detection used when the classifier is unavailable.

        Keywords are matched at the start of a word, case-insensitively; the
        first matching category in ``_FALLBACK_PATTERNS`` wins.
        """
        q = question.lower()
        for intent, pattern in self._FALLBACK_PATTERNS:
            if pattern.search(q):
                return intent
        return "general"

    def cosine_similarity(self, a, b):
        """Return the cosine similarity of vectors *a* and *b*.

        Returns 0.0 when either vector has zero norm instead of producing nan.
        """
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        if denom == 0:
            return 0.0
        return float(np.dot(a, b) / denom)

    @staticmethod
    def extract_time_window(question, now=None):
        """Extract a time constraint from *question*.

        Args:
            question: Natural-language query; matched case-insensitively.
            now: Reference time (naive datetime). Defaults to ``datetime.now()``.

        Returns:
            None if no known phrase is present; a ``datetime`` lower bound
            (memories at or after it match); or a ``(start, end)`` tuple
            (``start <= t < end``). Phrases are checked in a fixed order and
            the first match wins.
        """
        now = datetime.now() if now is None else now
        q = question.lower()
        # Truncate to midnight including microseconds so window edges are exact.
        today = now.replace(hour=0, minute=0, second=0, microsecond=0)
        yesterday = today - timedelta(days=1)

        if "last hour" in q:
            return now - timedelta(hours=1)
        if "last 30 minutes" in q:
            return now - timedelta(minutes=30)
        if "recent" in q:  # also covers "recently"
            return now - timedelta(hours=2)
        if "today" in q:
            return today
        if "yesterday" in q:
            return (yesterday, today)
        if "this morning" in q:
            return (today.replace(hour=6), today.replace(hour=12))
        if "this evening" in q:
            return (today.replace(hour=18), today.replace(hour=22))
        if "last evening" in q:
            return (yesterday.replace(hour=18), yesterday.replace(hour=22))
        if "last night" in q:
            return (yesterday.replace(hour=22), today.replace(hour=6))
        return None

    def ask(self, question, threshold=0.45, use_image_search=False,
            query_embedding=None, top_k=5):
        """Answer *question* from stored memories.

        Args:
            question: Natural-language query.
            threshold: Minimum cosine similarity for a text match. Image hits
                are not subject to this threshold.
            use_image_search: Also query ``memory_agent.search_by_image``.
            query_embedding: Embedding passed to ``search_by_image``
                (presumably an image embedding — confirm with the memory
                agent); image search is skipped when it is None.
            top_k: Number of image-search candidates requested and maximum
                number of answers returned.

        Returns:
            str: one line per matching memory (best first), or a fixed
            "don't recall" message when nothing matches.
        """
        memories = self.memory_agent.recall_all()
        if not memories:
            return "I don't have any memories yet."

        # Currently informational only: the intent is logged, not used for ranking.
        intent = self.classify_intent(question)
        print(f"[QueryAgent] Detected intent: {intent}")

        filtered = self._filter_by_time(memories, self.extract_time_window(question))
        if not filtered:
            return "I don't recall anything from that time."

        results = []
        if use_image_search and query_embedding is not None:
            results.extend(self._image_hits(query_embedding, filtered, top_k))

        # Image hits take precedence; a text hit for the same memory is dropped.
        seen_ids = {r["memory"]["id"] for r in results}
        for hit in self._text_hits(question, filtered, threshold):
            if hit["memory"]["id"] not in seen_ids:
                results.append(hit)

        if not results:
            return "I don't recall anything related to that."

        results.sort(
            key=lambda r: (r["score"], r["memory"].get("importance", 1)),
            reverse=True,
        )

        return "\n".join(
            f"At {r['memory']['timestamp']}, {r['memory']['description']} "
            f"(confidence {r['score']:.2f})"
            for r in results[:top_k]
        )

    @classmethod
    def _filter_by_time(cls, memories, window):
        """Return the memories whose timestamp falls inside *window*.

        *window* is the value returned by :meth:`extract_time_window`.
        Timestamps are only parsed when a window is present.
        """
        if window is None:
            return list(memories)

        start, end = window if isinstance(window, tuple) else (window, None)
        kept = []
        for m in memories:
            t = datetime.strptime(m["timestamp"], cls._TIMESTAMP_FORMAT)
            if t >= start and (end is None or t < end):
                kept.append(m)
        return kept

    def _image_hits(self, query_embedding, candidates, k):
        """Return image-search hits restricted to *candidates*, matched by memory id.

        The candidate dict (not the search result's copy) is used so later
        steps see the same object the text search annotates.
        """
        by_id = {m["id"]: m for m in candidates}
        hits = []
        for res in self.memory_agent.search_by_image(query_embedding, k=k):
            # pop() also drops repeated ids within the image results.
            mem = by_id.pop(res["memory"]["id"], None)
            if mem is not None:
                hits.append({"memory": mem, "score": res["similarity"], "source": "image"})
        return hits

    def _text_hits(self, question, candidates, threshold):
        """Return text-similarity hits with score >= *threshold*.

        Side effect: sets ``"text_embedding"`` and a default ``"importance"``
        of 1 on every candidate that lacks them. Whether this cache persists
        across queries depends on ``recall_all`` returning the same dicts.
        """
        query_vec = self.model.encode(question)
        hits = []
        for m in candidates:
            if "text_embedding" not in m:
                m["text_embedding"] = self.model.encode(m["description"]).tolist()
            m.setdefault("importance", 1)

            sim = self.cosine_similarity(query_vec, np.array(m["text_embedding"]))
            if sim >= threshold:
                hits.append({"memory": m, "score": sim, "source": "text"})
        return hits
|
|