# Spaces: Runtime error
from __future__ import annotations

import uuid
from typing import Any, Dict, List

from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings.base import Embeddings
from pinecone import Index  # import doesnt work on plane wifi
from pydantic import BaseModel

from reworkd_platform.settings import settings
from reworkd_platform.timer import timed_function
from reworkd_platform.web.api.memory.memory import AgentMemory
# Presumably the length of the vectors produced by the OpenAI embedding
# model used below — TODO confirm; nothing in this module reads it.
OPENAI_EMBEDDING_DIM = 1536
class Row(BaseModel):
    """A single vector record, serialised with ``.dict()`` for ``Index.upsert``."""

    # Unique identifier of the record (a UUID4 string when built by add_tasks).
    id: str
    # The embedding vector itself.
    values: List[float]
    # Free-form payload stored next to the vector; add_tasks puts the
    # original task text under the "text" key.
    metadata: Dict[str, Any] = {}
class QueryResult(BaseModel):
    """One match returned by a similarity query, without the vector values."""

    # Identifier of the matched record.
    id: str
    # Similarity score reported for the match.
    score: float
    # Metadata stored with the matched record.
    metadata: Dict[str, Any] = {}
class PineconeMemory(AgentMemory):
    """
    Wrapper around pinecone

    Stores task texts as embedding vectors in the Pinecone index named by
    ``settings.pinecone_index_name`` and retrieves previously stored tasks
    that score above a similarity threshold for a query text.

    The embedding client is only created in ``__enter__``, so ``add_tasks``
    and ``get_similar_tasks`` must be called inside a ``with`` block;
    otherwise ``self.embeddings`` does not exist.
    """

    def __init__(self, index_name: str, namespace: str = ""):
        """
        Bind to the configured index and pick the namespace to operate in.

        :param index_name: Not used to select the index (that always comes
            from settings); it is only the fallback namespace.
        :param namespace: Namespace for all operations. When empty,
            ``index_name`` is used instead.
        """
        self.index = Index(settings.pinecone_index_name)
        self.namespace = namespace or index_name

    def __enter__(self) -> AgentMemory:
        """Create the OpenAI embedding client and return this instance."""
        self.embeddings: Embeddings = OpenAIEmbeddings(
            client=None,  # Meta private value but mypy will complain its missing
            openai_api_key=settings.openai_api_key,
        )
        return self

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        """Nothing to release; exceptions from the ``with`` body propagate."""
        pass

    def reset_class(self) -> None:
        """Delete every vector stored in this instance's namespace."""
        self.index.delete(delete_all=True, namespace=self.namespace)

    def add_tasks(self, tasks: List[str]) -> List[str]:
        """
        Embed the given task texts and upsert them into the index.

        :param tasks: Task texts to store.
        :return: The generated ids of the stored rows, in the same order as
            ``tasks``; an empty list when there is nothing to store.
        :raises ValueError: If the embedding call returns a different number
            of vectors than there are tasks.
        """
        if not tasks:
            return []

        embeds = self.embeddings.embed_documents(tasks)
        if len(tasks) != len(embeds):
            raise ValueError("Embeddings and tasks are not the same length")

        # Lengths were verified above, so zip pairs every task with its vector.
        rows = [
            Row(id=str(uuid.uuid4()), values=vector, metadata={"text": task})
            for task, vector in zip(tasks, embeds)
        ]
        self.index.upsert(
            vectors=[row.dict() for row in rows], namespace=self.namespace
        )
        return [row.id for row in rows]

    def get_similar_tasks(
        self, text: str, score_threshold: float = 0.95
    ) -> List[QueryResult]:
        """
        Return stored tasks whose similarity to ``text`` exceeds a threshold.

        :param text: Query text; it is embedded before querying the index.
        :param score_threshold: Matches must score strictly above this value.
        :return: At most five matches (the query's ``top_k``) that pass the
            threshold, each with id, score and metadata.
        """
        vector = self.embeddings.embed_query(text)
        results = self.index.query(
            vector=vector,
            top_k=5,
            include_metadata=True,
            # Only id, score and metadata are read below, so there is no
            # reason to ship the full vectors back with every match.
            include_values=False,
            namespace=self.namespace,
        )

        # A response without a ``matches`` attribute is treated as no matches.
        matches = getattr(results, "matches", [])
        return [
            QueryResult(id=match.id, score=match.score, metadata=match.metadata)
            for match in matches
            if match.score > score_threshold
        ]

    # NOTE(review): indentation was lost in the captured source, so treating
    # this as a member of the class is an assumption — confirm. It takes no
    # ``self``; as a member it has to be a staticmethod to be callable on
    # instances as well as on the class.
    @staticmethod
    def should_use() -> bool:
        """Report whether this backend should be used; hard-coded to False."""
        return False