Adán Flores Ramírez
Frida Command interpreter to Gradio
a3643ce
#!/usr/bin/env python3
import json
import os
from enum import Enum
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from pathlib import Path
from typing import ClassVar, Dict, Optional
from pydantic import BaseModel, ValidationError
from types import SimpleNamespace
from termcolor import colored
from datetime import datetime
# Embeddings stay on unless DISABLE_EMBEDDINGS is set to "true" (compared case-insensitively).
EMBEDDINGS_ENABLED = not (os.environ.get("DISABLE_EMBEDDINGS", "false").lower() == "true")
if EMBEDDINGS_ENABLED:
    # Only imported when enabled, so disabling embeddings skips importing the Chroma adapter entirely.
    from command_interpreter.embeddings.chroma_adapter import ChromaAdapter
class MetadataProfile(str, Enum):
    """Metadata profiles, keyed by collection name.

    The ``str`` mixin makes every member compare equal to (and hash like) its
    plain string value, so ``"items"`` and ``MetadataProfile.ITEMS`` are
    interchangeable as dictionary keys.
    """

    # Household objects.
    ITEMS = "items"
    # Places in the house.
    LOCATIONS = "locations"
    # Human actions.
    ACTIONS = "actions"
    # Team knowledge.
    TEC_KNOWLEDGE = "tec_knowledge"
class MetadataModel(BaseModel):
    """Pydantic schema for the metadata stored alongside each Chroma document.

    Every field is optional. ``PROFILES`` maps a ``MetadataProfile`` to default
    field values (currently only a ``context`` suffix) which ``with_profile``
    applies before any caller-supplied overrides.
    """

    shelve: Optional[str] = ""
    category: Optional[str] = None
    context: Optional[str] = ""
    result: Optional[str] = None
    status: Optional[int] = None
    timestamp: Optional[str] = None
    subarea: Optional[str] = None
    embeddings: Optional[list] = None
    items_inside: Optional[str] = None
    action: Optional[str] = None
    command: Optional[str] = None

    # Per-profile defaults; declared ClassVar so pydantic does not treat it as a field.
    PROFILES: ClassVar[Dict[MetadataProfile, Dict[str, str]]] = {
        MetadataProfile.ITEMS: {"context": " item for household use"},
        MetadataProfile.LOCATIONS: {"context": " house locations"},
        MetadataProfile.ACTIONS: {"context": " human actions"},
        MetadataProfile.TEC_KNOWLEDGE: {"context": " team knowledge"},
    }

    @classmethod
    def with_profile(
        cls, profile: MetadataProfile = MetadataProfile.ITEMS, **overrides
    ):
        """Instantiate the model from ``profile``'s defaults, overridden by ``overrides``.

        Because ``MetadataProfile`` is a str enum, a plain string such as
        ``"items"`` also selects a profile; a value with no entry in
        ``PROFILES`` contributes no defaults.
        """
        values = dict(cls.PROFILES.get(profile, {}))
        values.update(overrides)
        return cls(**values)
class Embeddings:
    """Front end to the Chroma vector store used by the command interpreter.

    On construction it connects through ``ChromaAdapter`` and builds every
    collection that does not exist yet (see ``build_embeddings``). When
    ``EMBEDDINGS_ENABLED`` is false, ``chroma_adapter`` is ``None`` and most
    methods return early with an empty or failed result instead of touching
    the database.
    """

    def __init__(self):
        """Create the Chroma adapter and seed collections (skipped when disabled)."""
        if not EMBEDDINGS_ENABLED:
            print(colored("⚠️ Embeddings disabled (DISABLE_EMBEDDINGS=true). Query features unavailable.", "yellow"))
            self.chroma_adapter = None
            return
        self.chroma_adapter = ChromaAdapter()
        self.build_embeddings()

    # ------------------------------------------------------------------
    # Writing entries
    # ------------------------------------------------------------------

    @staticmethod
    def _decode_metadata(raw):
        """Normalize request metadata to a list of dicts (falsy input is returned unchanged).

        Accepts a JSON string/bytes, an already-decoded list, or a single
        non-empty dict (wrapped into a one-element list).
        """
        if isinstance(raw, (str, bytes, bytearray)):
            raw = json.loads(raw) if raw else raw
        if isinstance(raw, dict) and raw:
            return [raw]
        return raw

    def add_entry_callback(self, request):
        """Add one or more documents (with optional metadata) to a collection.

        ``request`` must provide ``document`` (a string or a list of strings)
        and ``collection`` (target collection name). ``metadata`` is optional:
        a JSON string encoding a list of dicts (or a single dict), an already
        decoded list/dict, or a falsy value meaning "no metadata".

        Each metadata entry is validated through ``MetadataModel.with_profile``
        so the collection's profile ``context`` is filled in. Each document is
        passed through ``clean``; the cleaned text is kept as
        ``metadata["original_name"]`` and the context is appended to the text
        that gets embedded. Errors are printed, not raised.
        """
        if not EMBEDDINGS_ENABLED or self.chroma_adapter is None:
            print(colored("⚠️ Embeddings disabled. Skipping add_entry.", "yellow"))
            return
        try:
            metadatas_ = self._decode_metadata(getattr(request, "metadata", None))
            documents = (
                request.document
                if isinstance(request.document, list)
                else [request.document]
            )
            metadatas = metadatas_ if metadatas_ else [{} for _ in documents]

            # Normalize and validate all metadata entries using the collection's profile.
            metadata_objects = []
            for meta in metadatas:
                try:
                    metadata_parsed = MetadataModel.with_profile(
                        request.collection, **meta
                    )
                    metadata_objects.append(metadata_parsed.model_dump())
                except Exception as e:
                    # Report the offending entry; the outer handler reports the overall failure.
                    print(colored(f"❌ Database Error: Failed to process metadata entry: {meta} - {str(e)}", "red"))
                    raise

            # clean() works on individual strings, so apply it per document.
            documents = [self.clean(doc) for doc in documents]

            # Inject context into documents and preserve original names.
            for i, (doc, meta) in enumerate(zip(documents, metadata_objects)):
                meta["original_name"] = doc
                context = meta.get("context")
                if context:
                    documents[i] = f"{doc} {context}"

            if request.collection == "closest_items":
                # query_entry_callback deletes this scratch collection after every query,
                # so it has to be recreated before adding to it.
                self.chroma_adapter._get_or_create_collection("closest_items")
            self.chroma_adapter.add_entries(
                request.collection, documents, metadata_objects
            )
            print(colored("💾 Database: Entry added successfully", "blue", attrs=['bold']))
        except Exception as e:
            print(colored(f"❌ Database Error: Failed to add item - {str(e)}", "red", attrs=['bold']))
        return

    # ------------------------------------------------------------------
    # Querying
    # ------------------------------------------------------------------

    @staticmethod
    def _collection_context(collection):
        """Return the profile context suffix for ``collection`` ("" when it has no profile)."""
        try:
            profile = MetadataProfile(collection)
        except ValueError:
            return ""
        print(colored(f"🔍 Database: Querying '{profile.value}' collection", "blue"))
        return MetadataModel.PROFILES[profile]["context"]

    @staticmethod
    def _rows_for_single_query(value):
        """Return the per-hit rows of one field of a single-query result.

        ChromaDB's query result holds one inner list per query text (the
        ``[[...]]`` shape); since each call here sends exactly one query, that
        outer list is unwrapped. Already-flat values are returned as a list,
        and ``None``/empty values become ``[]``.
        """
        if not value:
            return []
        if isinstance(value[0], list):
            return value[0]
        return list(value)

    @classmethod
    def _format_hits(cls, results_raw, include_distance=True):
        """Turn one single-query result into a list of hit dicts.

        Each hit is ``{"document": ..., "metadata": {...}}`` plus ``"distance"``
        when ``include_distance`` is true (``None`` if no distance was returned
        for that hit). ``document`` is the stored ``original_name`` when present,
        i.e. the text without the injected context. Missing metadata becomes ``{}``.
        """
        docs = cls._rows_for_single_query(results_raw.get("documents"))
        metas = cls._rows_for_single_query(results_raw.get("metadatas"))
        distances = cls._rows_for_single_query(results_raw.get("distances"))
        hits = []
        for idx, (doc, meta) in enumerate(zip(docs, metas)):
            if isinstance(meta, list):
                meta = meta[0] if meta else {}
            if not isinstance(meta, dict):
                meta = {}
            hit = {"document": meta.get("original_name", doc), "metadata": meta}
            if include_distance:
                hit["distance"] = distances[idx] if idx < len(distances) else None
            hits.append(hit)
        return hits

    def query_entry_callback(self, request):
        """Run every query string in ``request.query`` against ``request.collection``.

        ``request`` must provide ``query`` (list of strings), ``collection`` and
        ``topk``. For collections with a metadata profile, the profile's
        ``context`` suffix is appended to each query, mirroring how documents
        are stored. The scratch ``closest_items`` collection is deleted after
        the query.

        Returns:
            tuple[list[str], bool]: one JSON string per query, each encoding
            ``{"query": ..., "results": [hit, ...]}`` (see ``_format_hits``),
            and a success flag. On error the list is empty and the flag False.
            NOTE(review): when embeddings are disabled this returns a
            ``SimpleNamespace(grouped_results=[], ungrouped_results=[])``
            instead; kept unchanged for existing callers.
        """
        if not EMBEDDINGS_ENABLED or self.chroma_adapter is None:
            print(colored("⚠️ Embeddings disabled. Returning empty query results.", "yellow"))
            return SimpleNamespace(grouped_results=[], ungrouped_results=[])
        # Initialized up front so the error path still returns a well-formed tuple.
        results = []
        success = False
        try:
            context = self._collection_context(request.collection)
            # The command history is read back in timestamp order, so distances are not reported.
            include_distance = request.collection != "command_history"
            grouped_results = []
            for query in request.query:
                results_raw = self.chroma_adapter.query(
                    request.collection, [query + context], request.topk
                )
                grouped_results.append(
                    {
                        "query": query,
                        "results": self._format_hits(results_raw, include_distance),
                    }
                )
            results = [json.dumps(entry) for entry in grouped_results]
            success = bool(grouped_results)
            if any(group["results"] for group in grouped_results):
                print(colored("✅ Database: Query successful", "blue", attrs=['bold']))
            else:
                print(colored("⚠️ Database: No matching items found", "yellow", attrs=['bold']))
        except Exception as e:
            success = False
            message = f"Failed to query items: {str(e)}"
            print(colored(f"❌ Database Error: {message}", "red", attrs=['bold']))
        if request.collection == "closest_items":
            self.chroma_adapter.delete_collection("closest_items")
        return results, success

    # ------------------------------------------------------------------
    # Building collections
    # ------------------------------------------------------------------

    def build_embeddings_callback(self, request, response):
        """Build all collections, or drop and rebuild them when ``request.rebuild`` is set.

        Fills ``response.success`` / ``response.message`` and returns ``response``.
        """
        if not EMBEDDINGS_ENABLED or self.chroma_adapter is None:
            response.success = False
            response.message = "Embeddings are disabled"
            return response
        try:
            if request.rebuild:
                print(colored("🔄 Database: Rebuilding embeddings", "blue"))
                self.chroma_adapter.remove_all_collections()
            self.build_embeddings()
            response.success = True
            response.message = "Embeddings built successfully"
            print(colored("💾 Database: Build request handled successfully", "blue", attrs=['bold']))
        except Exception as e:
            response.success = False
            response.message = f"Error while building embeddings: {str(e)}"
            print(colored(f"❌ Database Error: {response.message}", "red", attrs=['bold']))
        return response

    def build_embeddings(self):
        """
        Method to build embeddings for household use.
        Reads JSON files from the designated dataframes folder,
        and for each file whose collection does not exist yet:
        - Reads documents and (if available) metadata.
        - Gets or creates a corresponding collection.
        - Adds entries to the collection via the add_entries method,
          which will process documents and metadata (adding "original_name",
          appending "context", and cleaning metadata) automatically.
        Afterwards rebuilds "locations" and ensures "command_history" exists.

        Raises:
            FileNotFoundError: if the dataframes folder is missing or has no JSON files.
        """
        if not EMBEDDINGS_ENABLED or self.chroma_adapter is None:
            return
        dataframes_folder = Path(__file__).resolve().parent / "../embeddings/dataframes"
        if not dataframes_folder.is_dir():
            raise FileNotFoundError(
                f"The folder {dataframes_folder} does not exist or is not a directory."
            )
        # Sorted so collections are built in a deterministic order.
        dataframes = sorted(
            file.resolve()
            for file in dataframes_folder.iterdir()
            if file.suffix == ".json"
        )
        if not dataframes:
            raise FileNotFoundError(
                f"No JSON files found in the folder {dataframes_folder}."
            )
        for file in dataframes:
            collection_name = self.chroma_adapter._sanitize_collection_name(file.stem)
            if collection_name in self.chroma_adapter.list_collections():
                # Already built; a full rebuild goes through build_embeddings_callback(rebuild=True).
                continue
            with open(file, "r", encoding="utf-8") as f:
                data = json.load(f)
            documents = []
            metadatas_ = []
            for entry in data:
                # add_basics records the raw text as metadata["original_name"].
                # NOTE(review): its context-augmented text is discarded and the raw
                # document is stored; presumably add_entries appends context itself
                # (per the docstring above) - confirm.
                _, metadata = self.add_basics(entry["document"], entry.get("metadata") or {})
                metadatas_.append(metadata)
                documents.append(entry["document"])
            self.chroma_adapter._get_or_create_collection(collection_name)
            self.chroma_adapter.add_entries(collection_name, documents, metadatas_)
        self.add_locations()
        self.chroma_adapter._get_or_create_collection("command_history")
        return

    def add_command_history(self, command, result, status):
        """Record an executed command in the "command_history" collection.

        ``command.action`` is the stored document; ``str(command)``, ``result``,
        ``status`` and the current local timestamp (ISO format) go into metadata.
        """
        if not EMBEDDINGS_ENABLED or self.chroma_adapter is None:
            return
        metadata = [
            {
                "command": str(command),
                "result": result,
                "status": status,
                "timestamp": datetime.now().isoformat(),
            }
        ]
        request = SimpleNamespace(
            document=[command.action],
            metadata=json.dumps(metadata),
            collection="command_history",
        )
        self.add_entry_callback(request)

    def add_locations(self):
        """(Re)build the "locations" collection from ``maps/areas.json``.

        Each area/subarea pair becomes the document "<area> <subarea>"; the
        "safe_place" subarea is stored with an empty subarea name. The parsed
        map is also kept in ``self.areas``.
        """
        if self.chroma_adapter is None:
            return
        collection_name = "locations"
        if collection_name in self.chroma_adapter.list_collections():
            self.chroma_adapter.delete_collection(collection_name)
        script_dir = os.path.dirname(os.path.abspath(__file__))
        file_path = os.path.join(script_dir, "maps", "areas.json")
        with open(file_path, "r", encoding="utf-8") as file:
            self.areas = json.load(file)
        areas_document = []
        areas_metadatas = []
        for area, subareas in self.areas.items():
            for subarea in subareas:
                if subarea == "safe_place":
                    subarea = ""
                areas_document.append(area + " " + subarea)
                areas_metadatas.append(
                    {"context": "house locations", "area": area, "subarea": subarea}
                )
        self.chroma_adapter._get_or_create_collection(collection_name)
        self.chroma_adapter.add_entries(collection_name, areas_document, areas_metadatas)
        return

    def add_basics(self, documents, metadatas):
        """Record ``documents`` as ``metadatas["original_name"]`` and append any context.

        Mutates ``metadatas`` in place. Returns ``(document_with_context, metadatas)``;
        the document is unchanged when there is no truthy ``context``.
        """
        metadatas["original_name"] = documents
        context = metadatas.get("context", "")
        documents = f"{documents} {context}" if context else documents
        return documents, metadatas

    def clean(self, documents):
        """Flatten a string that encodes a list (e.g. "['a', 'b']") into "a b".

        Anything else - non-strings, or strings that fail to parse as a
        list - is returned unchanged.
        """
        if (
            isinstance(documents, str)
            and documents.strip().startswith("[")
            and documents.strip().endswith("]")
        ):
            try:
                parsed = json.loads(documents.replace("'", '"'))  # Handle single quotes
                if isinstance(parsed, list):
                    cleaned = " ".join(str(x) for x in parsed)
                    print("document after cleaning:", cleaned)
                    return cleaned
            except json.JSONDecodeError:
                pass  # Leave it as-is if it fails to parse
        return documents

    def print_all_collections(self):
        """Print every collection and its documents/metadata (debugging aid)."""
        if self.chroma_adapter is None:
            print("No collections found.")
            return
        try:
            collection_obj = self.chroma_adapter.client.list_collections()
            if not collection_obj:
                print("No collections found.")
                return
            for obj in collection_obj:
                # ChromaDB 0.6 returns collection names; other versions return Collection objects.
                name = getattr(obj, "name", obj)
                print(f"--- Collection: '{name}' ---")
                try:
                    collection = self.chroma_adapter.client.get_collection(name=name)
                    results = collection.get(include=["documents", "metadatas"])
                    docs = results.get("documents") or []
                    metas = results.get("metadatas") or []
                    if not docs:
                        print("(Empty collection)")
                        continue
                    for idx, (doc, meta) in enumerate(zip(docs, metas)):
                        print(f"[{idx}] Document: {doc}")
                        print(f" Metadata: {meta}")
                except Exception as e:
                    print(f"Failed to access collection '{name}': {str(e)}")
        except Exception as e:
            print(f"Failed to list collections: {str(e)}")

    def _query_(self, query: str, collection: str, top_k: int = 1) -> list[dict]:
        """Query one collection and return its hit dicts (see ``_format_hits``).

        For "command_history" the hits are ordered newest-first by their
        ISO-format ``timestamp`` metadata and truncated to ``top_k``.
        Returns ``[]`` when embeddings are disabled or the query failed.
        """
        if not EMBEDDINGS_ENABLED or self.chroma_adapter is None:
            return []
        # The callback expects a sequence of query strings.
        request = SimpleNamespace(query=[query], collection=collection, topk=top_k)
        results, _success = self.query_entry_callback(request)
        if collection == "command_history":
            print(colored("🔍 Database: Querying command history", "blue"))
        if not results:
            return []
        results_list = json.loads(results[0])["results"]
        if collection == "command_history":
            results_list = sorted(
                results_list,
                key=lambda hit: hit["metadata"].get("timestamp") or "",
                reverse=True,
            )[:top_k]
        return results_list

    def find_closest(self, documents: list, query: str, top_k: int = 1) -> str:
        """
        Find which of ``documents`` is semantically closest to ``query``.

        The documents are loaded into the scratch "closest_items" collection,
        which the query step deletes again afterwards.

        Args:
            documents: the documents to search among (a string or list of strings)
            query: the query to search for
            top_k: number of neighbours to retrieve
        Returns:
            str: the original text of the best match, or "" if nothing matched
        """
        add_request = SimpleNamespace(
            document=documents, metadata=None, collection="closest_items"
        )
        self.add_entry_callback(add_request)
        results = self._query_(query, "closest_items", top_k)
        closest = self.get_name(results)
        print(colored(f"🎯 Database: find_closest result for '{query}': {str(closest)}", "blue", attrs=['bold']))
        return closest

    def delete_collection(self, collection_name: str):
        """
        Deletes a collection from the ChromaDB.
        Args:
            collection_name (str): The name of the collection to delete.
        """
        if not EMBEDDINGS_ENABLED or self.chroma_adapter is None:
            return
        try:
            self.chroma_adapter.delete_collection(collection_name)
            print(colored(f"🗑️ Database: Collection '{collection_name}' deleted successfully", "blue", attrs=['bold']))
        except Exception as e:
            print(colored(f"❌ Database Error: Failed to delete collection '{collection_name}' - {str(e)}", "red", attrs=['bold']))

    def query_item(self, query: str, top_k: int = 1) -> list[dict]:
        return self._query_(query, "items", top_k)

    def query_location(self, query: str, top_k: int = 1) -> list[dict]:
        return self._query_(query, "locations", top_k)

    def query_command_history(self, query: str, top_k: int = 1) -> list[dict]:
        return self._query_(query, "command_history", top_k)

    def query_tec_knowledge(self, query: str, top_k: int = 1) -> list[dict]:
        return self._query_(query, "tec_knowledge", top_k)

    def query_frida_knowledge(self, query: str, top_k: int = 1) -> list[dict]:
        return self._query_(query, "frida_knowledge", top_k)

    def query_roborregos_knowledge(self, query: str, top_k: int = 1) -> list[dict]:
        return self._query_(query, "roborregos_knowledge", top_k)

    def get_metadata_key(self, query_result, field: str):
        """
        Extracts ``field`` from the metadata of every hit in a query result.
        Args:
            query_result (list[dict]): hits as returned by ``_query_``
            field (str): metadata key to read
        Returns:
            list: one value per hit ("" when the key or the metadata is missing),
            or an empty list if ``query_result`` is malformed
        """
        try:
            key_list = []
            for result in query_result:
                metadata = result["metadata"]
                if isinstance(metadata, list):
                    metadata = metadata[0] if metadata else {}
                if not isinstance(metadata, dict):
                    metadata = {}
                key_list.append(metadata.get(field, ""))
            return key_list
        except (IndexError, KeyError, TypeError) as e:
            print(f"Failed to extract '{field}' from metadata: {str(e)}")
            return []

    def get_subarea(self, query_result):
        result = self.get_metadata_key(query_result, "subarea")
        return result[0] if result else ""

    def get_area(self, query_result):
        result = self.get_metadata_key(query_result, "area")
        return result[0] if result else ""

    def get_context(self, query_result):
        return self.get_metadata_key(query_result, "context")

    def get_command(self, query_result):
        return self.get_metadata_key(query_result, "command")

    def get_result(self, query_result):
        return self.get_metadata_key(query_result, "result")

    def get_status(self, query_result):
        return self.get_metadata_key(query_result, "status")

    def get_name(self, query_result):
        result = self.get_metadata_key(query_result, "original_name")
        return result[0] if result else ""
def main():
    """Manual smoke test: record a command, query it, drop the history, query again."""
    embeddings = Embeddings()
    soda_request = "get me a soda"

    def show_history_match():
        # Query the history and echo the hits plus the stored original name.
        matches = embeddings.query_command_history(soda_request)
        match_name = embeddings.get_name(matches)
        embeddings.get_context(matches)  # evaluated but not displayed
        print("Success:", matches)
        print("Name:", match_name)

    embeddings.add_command_history(
        command=SimpleNamespace(action=soda_request),
        result='Succesfull',
        status=1,
    )
    show_history_match()
    embeddings.delete_collection("command_history")
    show_history_match()


if __name__ == "__main__":
    main()