Technologic101 commited on
Commit
e9ce170
·
1 Parent(s): fe041cb

task: add langgraph retriever tool

Browse files
Files changed (1) hide show
  1. src/tools/design_retriever.py +25 -0
src/tools/design_retriever.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional
2
+ from langchain.tools import BaseTool
3
+ from chains.design_rag import DesignRAG
4
+ from pydantic import Field
5
+ import json
6
+
7
+ class DesignRetrieverTool(BaseTool):
8
+ """Tool for retrieving similar designs based on requirements."""
9
+
10
+ name: str = "design_retriever"
11
+ description: str = "Retrieves similar designs based on style requirements"
12
+ rag: DesignRAG = Field(description="Design RAG system for retrieving similar designs")
13
+
14
+ def __init__(self, rag: DesignRAG):
15
+ """Initialize the tool with a DesignRAG instance."""
16
+ super().__init__(rag=rag)
17
+
18
+ def _run(self, requirements: Dict, num_examples: int = 3) -> str:
19
+ """Sync version - not used but required by BaseTool"""
20
+ raise NotImplementedError("Use async version")
21
+
22
+ async def _arun(self, requirements: Dict, num_examples: int = 3) -> str:
23
+ """Retrieve similar designs based on requirements"""
24
+ print(f"Retrieving {num_examples} similar designs")
25
+ return await self.rag.query_similar_designs(requirements, num_examples)