Spaces:
Sleeping
Sleeping
update hybrid_retriever_tool and add pydantic as a dependency
Browse files
tools/hybrid_retriever_tool.py
CHANGED
|
@@ -4,30 +4,27 @@ from sentence_transformers import SentenceTransformer
|
|
| 4 |
from tavily import TavilyClient
|
| 5 |
from openai import OpenAI
|
| 6 |
from crewai_tools import RagTool
|
|
|
|
| 7 |
import os
|
| 8 |
|
| 9 |
class HybridRetrieverTool(RagTool):
|
| 10 |
name: str = "Hybrid Retriever Tool"
|
| 11 |
description: str = "Combines BM25 keyword scoring with semantic similarity for hybrid retrieval"
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
|
| 17 |
-
self.tavily = TavilyClient(api_key=os.getenv("
|
| 18 |
self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
| 19 |
|
| 20 |
def _build_corpus(self, topic):
|
| 21 |
"""Fetch up-to-date search results."""
|
| 22 |
results = self.tavily.search(query=topic, max_results=30)
|
| 23 |
-
corpus = []
|
| 24 |
-
for r in results.get("results", []):
|
| 25 |
-
content = r.get("content") or ""
|
| 26 |
-
if len(content.strip()) > 0:
|
| 27 |
-
corpus.append(content)
|
| 28 |
return corpus
|
| 29 |
|
| 30 |
-
def _run(self, query: str, top_k=8) -> str:
|
| 31 |
"""
|
| 32 |
Run hybrid search: BM25 + semantic similarity.
|
| 33 |
"""
|
|
@@ -37,7 +34,7 @@ class HybridRetrieverTool(RagTool):
|
|
| 37 |
|
| 38 |
# Lexical relevance
|
| 39 |
bm25 = BM25Okapi([doc.split() for doc in corpus])
|
| 40 |
-
bm25_scores = np.array(bm25.
|
| 41 |
|
| 42 |
# semantic relevance
|
| 43 |
emb_corpus = self.embedder.encode(corpus, convert_to_numpy=True, normalize_embeddings=True)
|
|
|
|
| 4 |
from tavily import TavilyClient
|
| 5 |
from openai import OpenAI
|
| 6 |
from crewai_tools import RagTool
|
| 7 |
+
from pydantic import Field
|
| 8 |
import os
|
| 9 |
|
| 10 |
class HybridRetrieverTool(RagTool):
|
| 11 |
name: str = "Hybrid Retriever Tool"
|
| 12 |
description: str = "Combines BM25 keyword scoring with semantic similarity for hybrid retrieval"
|
| 13 |
+
alpha: float = Field(default=0.6, description="Weight between semantic and lexical scores")
|
| 14 |
+
|
| 15 |
+
def __init__(self, **data):
|
| 16 |
+
super().__init__(**data)
|
| 17 |
self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
|
| 18 |
+
self.tavily = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
|
| 19 |
self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
| 20 |
|
| 21 |
def _build_corpus(self, topic):
|
| 22 |
"""Fetch up-to-date search results."""
|
| 23 |
results = self.tavily.search(query=topic, max_results=30)
|
| 24 |
+
corpus = [r.get("content", "").strip() for r in results.get("results", []) if r.get("content")]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
return corpus
|
| 26 |
|
| 27 |
+
def _run(self, query: str, top_k: int = 8) -> str:
|
| 28 |
"""
|
| 29 |
Run hybrid search: BM25 + semantic similarity.
|
| 30 |
"""
|
|
|
|
| 34 |
|
| 35 |
# Lexical relevance
|
| 36 |
bm25 = BM25Okapi([doc.split() for doc in corpus])
|
| 37 |
+
bm25_scores = np.array(bm25.get_scores(query.split()))
|
| 38 |
|
| 39 |
# semantic relevance
|
| 40 |
emb_corpus = self.embedder.encode(corpus, convert_to_numpy=True, normalize_embeddings=True)
|