Spaces:
Runtime error
Runtime error
| """ | |
| .. warning:: | |
| Beta Feature! | |
| **Cache** provides an optional caching layer for LLMs. | |
| Cache is useful for two reasons: | |
| - It can save you money by reducing the number of API calls you make to the LLM | |
| provider if you're often requesting the same completion multiple times. | |
| - It can speed up your application by reducing the number of API calls you make | |
| to the LLM provider. | |
| Cache directly competes with Memory. See documentation for Pros and Cons. | |
| **Class hierarchy:** | |
| .. code-block:: | |
| BaseCache --> <name>Cache # Examples: InMemoryCache, RedisCache, GPTCache | |
| """ | |
| from __future__ import annotations | |
| import hashlib | |
| import inspect | |
| import json | |
| import logging | |
| import uuid | |
| import warnings | |
| from datetime import timedelta | |
| from functools import lru_cache | |
| from typing import ( | |
| TYPE_CHECKING, | |
| Any, | |
| Callable, | |
| Dict, | |
| List, | |
| Optional, | |
| Tuple, | |
| Type, | |
| Union, | |
| cast, | |
| ) | |
| from sqlalchemy import Column, Integer, Row, String, create_engine, select | |
| from sqlalchemy.engine.base import Engine | |
| from sqlalchemy.orm import Session | |
| try: | |
| from sqlalchemy.orm import declarative_base | |
| except ImportError: | |
| from sqlalchemy.ext.declarative import declarative_base | |
| from langchain_core.caches import RETURN_VAL_TYPE, BaseCache | |
| from langchain_core.embeddings import Embeddings | |
| from langchain_core.load.dump import dumps | |
| from langchain_core.load.load import loads | |
| from langchain_core.outputs import ChatGeneration, Generation | |
| from langchain.llms.base import LLM, get_prompts | |
| from langchain.utils import get_from_env | |
| from langchain.vectorstores.redis import Redis as RedisVectorstore | |
# Module-level logger, named after the module's import path (``__name__``)
# rather than its file path so it participates in the dotted logger hierarchy
# and can be configured through a parent logger.
logger = logging.getLogger(__name__)
| if TYPE_CHECKING: | |
| import momento | |
| from cassandra.cluster import Session as CassandraSession | |
| def _hash(_input: str) -> str: | |
| """Use a deterministic hashing approach.""" | |
| return hashlib.md5(_input.encode()).hexdigest() | |
| def _dump_generations_to_json(generations: RETURN_VAL_TYPE) -> str: | |
| """Dump generations to json. | |
| Args: | |
| generations (RETURN_VAL_TYPE): A list of language model generations. | |
| Returns: | |
| str: Json representing a list of generations. | |
| Warning: would not work well with arbitrary subclasses of `Generation` | |
| """ | |
| return json.dumps([generation.dict() for generation in generations]) | |
| def _load_generations_from_json(generations_json: str) -> RETURN_VAL_TYPE: | |
| """Load generations from json. | |
| Args: | |
| generations_json (str): A string of json representing a list of generations. | |
| Raises: | |
| ValueError: Could not decode json string to list of generations. | |
| Returns: | |
| RETURN_VAL_TYPE: A list of generations. | |
| Warning: would not work well with arbitrary subclasses of `Generation` | |
| """ | |
| try: | |
| results = json.loads(generations_json) | |
| return [Generation(**generation_dict) for generation_dict in results] | |
| except json.JSONDecodeError: | |
| raise ValueError( | |
| f"Could not decode json to list of generations: {generations_json}" | |
| ) | |
| def _dumps_generations(generations: RETURN_VAL_TYPE) -> str: | |
| """ | |
| Serialization for generic RETURN_VAL_TYPE, i.e. sequence of `Generation` | |
| Args: | |
| generations (RETURN_VAL_TYPE): A list of language model generations. | |
| Returns: | |
| str: a single string representing a list of generations. | |
| This function (+ its counterpart `_loads_generations`) rely on | |
| the dumps/loads pair with Reviver, so are able to deal | |
| with all subclasses of Generation. | |
| Each item in the list can be `dumps`ed to a string, | |
| then we make the whole list of strings into a json-dumped. | |
| """ | |
| return json.dumps([dumps(_item) for _item in generations]) | |
| def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]: | |
| """ | |
| Deserialization of a string into a generic RETURN_VAL_TYPE | |
| (i.e. a sequence of `Generation`). | |
| See `_dumps_generations`, the inverse of this function. | |
| Args: | |
| generations_str (str): A string representing a list of generations. | |
| Compatible with the legacy cache-blob format | |
| Does not raise exceptions for malformed entries, just logs a warning | |
| and returns none: the caller should be prepared for such a cache miss. | |
| Returns: | |
| RETURN_VAL_TYPE: A list of generations. | |
| """ | |
| try: | |
| generations = [loads(_item_str) for _item_str in json.loads(generations_str)] | |
| return generations | |
| except (json.JSONDecodeError, TypeError): | |
| # deferring the (soft) handling to after the legacy-format attempt | |
| pass | |
| try: | |
| gen_dicts = json.loads(generations_str) | |
| # not relying on `_load_generations_from_json` (which could disappear): | |
| generations = [Generation(**generation_dict) for generation_dict in gen_dicts] | |
| logger.warning( | |
| f"Legacy 'Generation' cached blob encountered: '{generations_str}'" | |
| ) | |
| return generations | |
| except (json.JSONDecodeError, TypeError): | |
| logger.warning( | |
| f"Malformed/unparsable cached blob encountered: '{generations_str}'" | |
| ) | |
| return None | |
class InMemoryCache(BaseCache):
    """Cache that keeps every entry in a plain in-process dictionary."""

    def __init__(self) -> None:
        """Start with no cached entries."""
        self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Return the entry stored for (prompt, llm_string), or None if absent."""
        cache_key = (prompt, llm_string)
        return self._cache.get(cache_key)

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Store ``return_val`` under the (prompt, llm_string) key."""
        cache_key = (prompt, llm_string)
        self._cache[cache_key] = return_val

    def clear(self, **kwargs: Any) -> None:
        """Drop all entries by rebinding the store to a fresh dictionary."""
        self._cache = dict()
# Declarative base whose metadata registers the cache table defined below.
Base = declarative_base()


class FullLLMCache(Base):  # type: ignore
    """SQLite table for full LLM Cache (all generations)."""

    __tablename__ = "full_llm_cache"
    # prompt, llm and idx together form the composite primary key.
    prompt = Column(String, primary_key=True)
    llm = Column(String, primary_key=True)
    # Position of a generation within the list cached for one (prompt, llm) pair.
    idx = Column(Integer, primary_key=True)
    # Serialized generation stored at that position.
    response = Column(String)
class SQLAlchemyCache(BaseCache):
    """Cache that uses SQAlchemy as a backend."""

    def __init__(self, engine: Engine, cache_schema: Type[FullLLMCache] = FullLLMCache):
        """Keep the engine and schema, then create every table in the schema's metadata."""
        self.engine = engine
        self.cache_schema = cache_schema
        self.cache_schema.metadata.create_all(self.engine)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Return the cached generations for (prompt, llm_string), ordered by idx.

        Returns None when no row matches.
        """
        schema = self.cache_schema
        query = (
            select(schema.response)
            .where(schema.prompt == prompt)  # type: ignore
            .where(schema.llm == llm_string)
            .order_by(schema.idx)
        )
        with Session(self.engine) as session:
            rows = session.execute(query).fetchall()
            if not rows:
                return None
            try:
                return [loads(row[0]) for row in rows]
            except Exception:
                logger.warning(
                    "Retrieving a cache value that could not be deserialized "
                    "properly. This is likely due to the cache being in an "
                    "older format. Please recreate your cache to avoid this "
                    "error."
                )
                # Older caches stored the raw text directly in the table,
                # so fall back to treating each value as plain text.
                return [Generation(text=row[0]) for row in rows]

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Write one row per generation, keyed by (prompt, llm_string, position)."""
        # Serialize everything up front, before a session is opened.
        pending_rows = []
        for position, generation in enumerate(return_val):
            pending_rows.append(
                self.cache_schema(
                    prompt=prompt,
                    llm=llm_string,
                    idx=position,
                    response=dumps(generation),
                )
            )
        with Session(self.engine) as session, session.begin():
            for pending in pending_rows:
                session.merge(pending)

    def clear(self, **kwargs: Any) -> None:
        """Delete every row of the cache table and commit."""
        with Session(self.engine) as session:
            session.query(self.cache_schema).delete()
            session.commit()
class SQLiteCache(SQLAlchemyCache):
    """Cache that uses SQLite as a backend."""

    def __init__(self, database_path: str = ".langchain.db"):
        """Build a SQLite engine for ``database_path`` and hand it to the parent class."""
        sqlite_url = f"sqlite:///{database_path}"
        super().__init__(create_engine(sqlite_url))
class UpstashRedisCache(BaseCache):
    """Cache that uses Upstash Redis as a backend."""

    def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
        """
        Create a cache on top of an existing Upstash Redis client.

        Parameters:
            redis_: An instance of the Upstash Redis client class
                (``upstash_redis.Redis``) through which all cache
                operations are performed.
            ttl (int, optional): Time-to-live for cached items in seconds.
                When given, every written key is set to expire after this
                many seconds; when omitted, no expiration is set.

        Raises:
            ValueError: If ``upstash_redis`` is not installed or ``redis_``
                is not an ``upstash_redis.Redis`` instance.
        """
        try:
            from upstash_redis import Redis
        except ImportError:
            raise ValueError(
                "Could not import upstash_redis python package. "
                "Please install it with `pip install upstash_redis`."
            )
        if not isinstance(redis_, Redis):
            raise ValueError("Please pass in Upstash Redis object.")
        self.redis = redis_
        self.ttl = ttl

    def _key(self, prompt: str, llm_string: str) -> str:
        """Hash the concatenation of prompt and llm_string into the cache key."""
        return _hash(prompt + llm_string)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Return the generations stored for (prompt, llm_string), or None."""
        # Entries live in a HASH: one field per generation.
        stored = self.redis.hgetall(self._key(prompt, llm_string))
        if not stored:
            return None
        generations = [Generation(text=text) for text in stored.values()]
        return generations or None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Store plain LLM generations; chat generations are skipped with a warning."""
        for candidate in return_val:
            if not isinstance(candidate, Generation):
                raise ValueError(
                    "UpstashRedisCache supports caching of normal LLM generations, "
                    f"got {type(candidate)}"
                )
            if isinstance(candidate, ChatGeneration):
                warnings.warn(
                    "NOTE: Generation has not been cached. UpstashRedisCache does not"
                    " support caching ChatModel outputs."
                )
                return

        # Write to a HASH: field name is the position, value is the text.
        key = self._key(prompt, llm_string)
        fields = {}
        for position, generation in enumerate(return_val):
            fields[str(position)] = generation.text
        self.redis.hset(key=key, values=fields)

        if self.ttl is not None:
            self.redis.expire(key, self.ttl)

    def clear(self, **kwargs: Any) -> None:
        """
        Clear cache. If `asynchronous` is True, flush asynchronously.
        This flushes the *whole* db.
        """
        flush_type = "ASYNC" if kwargs.get("asynchronous", False) else "SYNC"
        self.redis.flushdb(flush_type=flush_type)
class RedisCache(BaseCache):
    """Cache that uses Redis as a backend."""

    def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
        """
        Initialize an instance of RedisCache.

        Parameters:
            redis_ (Any): An instance of a Redis client class
                (e.g., redis.Redis) through which all cache operations
                are performed.
            ttl (int, optional): Time-to-live (TTL) for cached items in seconds.
                If provided, each written key is set to expire after this
                many seconds. If not provided, no expiration is set.

        Raises:
            ValueError: If the ``redis`` package is not installed or
                ``redis_`` is not a ``redis.Redis`` instance.
        """
        try:
            from redis import Redis
        except ImportError:
            raise ValueError(
                "Could not import redis python package. "
                "Please install it with `pip install redis`."
            )
        if not isinstance(redis_, Redis):
            raise ValueError("Please pass in Redis object.")
        self.redis = redis_
        self.ttl = ttl

    def _key(self, prompt: str, llm_string: str) -> str:
        """Compute key from prompt and llm_string"""
        return _hash(prompt + llm_string)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string.

        Returns:
            The cached generations, or None when nothing is stored for the key.
        """
        generations = []
        # Read from a Redis HASH
        results = self.redis.hgetall(self._key(prompt, llm_string))
        if results:
            for _, text in results.items():
                try:
                    generations.append(loads(text))
                except Exception:
                    logger.warning(
                        "Retrieving a cache value that could not be deserialized "
                        "properly. This is likely due to the cache being in an "
                        "older format. Please recreate your cache to avoid this "
                        "error."
                    )
                    # In a previous life we stored the raw text directly
                    # in the table, so assume it's in that format.
                    generations.append(Generation(text=text))
        return generations if generations else None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string.

        Raises:
            ValueError: If any item of ``return_val`` is not a `Generation`.
        """
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "RedisCache only supports caching of normal LLM generations, "
                    f"got {type(gen)}"
                )
        # Write to a Redis HASH
        key = self._key(prompt, llm_string)

        with self.redis.pipeline() as pipe:
            pipe.hset(
                key,
                mapping={
                    str(idx): dumps(generation)
                    for idx, generation in enumerate(return_val)
                },
            )
            if self.ttl is not None:
                pipe.expire(key, self.ttl)

            pipe.execute()

    def clear(self, **kwargs: Any) -> None:
        """Clear cache. If `asynchronous` is True, flush asynchronously.

        Any other keyword arguments are forwarded to ``flushdb`` unchanged.
        """
        # Pop `asynchronous` from a copy before forwarding: passing it both
        # explicitly and inside **kwargs raises
        # "TypeError: got multiple values for keyword argument".
        flush_kwargs = dict(kwargs)
        asynchronous = flush_kwargs.pop("asynchronous", False)
        self.redis.flushdb(asynchronous=asynchronous, **flush_kwargs)
class RedisSemanticCache(BaseCache):
    """Cache that uses Redis as a vector-store backend."""

    # TODO - implement a TTL policy in Redis

    DEFAULT_SCHEMA = {
        "content_key": "prompt",
        "text": [
            {"name": "prompt"},
        ],
        "extra": [{"name": "return_val"}, {"name": "llm_string"}],
    }

    def __init__(
        self, redis_url: str, embedding: Embeddings, score_threshold: float = 0.2
    ):
        """Set up a semantic cache backed by Redis vector indexes.

        Args:
            redis_url (str): URL to connect to Redis.
            embedding (Embedding): Embedding provider for semantic encoding and search.
            score_threshold (float, 0.2): Passed as ``distance_threshold`` to the
                similarity search performed on lookup.

        Example:

        .. code-block:: python

            from langchain.globals import set_llm_cache

            from langchain.cache import RedisSemanticCache
            from langchain.embeddings import OpenAIEmbeddings

            set_llm_cache(RedisSemanticCache(
                redis_url="redis://localhost:6379",
                embedding=OpenAIEmbeddings()
            ))

        """
        # One vector-store client per index name, created lazily.
        self._cache_dict: Dict[str, RedisVectorstore] = {}
        self.redis_url = redis_url
        self.embedding = embedding
        self.score_threshold = score_threshold

    def _index_name(self, llm_string: str) -> str:
        """Derive the index name from the hash of ``llm_string``."""
        return f"cache:{_hash(llm_string)}"

    def _get_llm_cache(self, llm_string: str) -> RedisVectorstore:
        """Return the vector-store client for ``llm_string``, creating it on first use."""
        index_name = self._index_name(llm_string)

        # Reuse the client already built for this llm string, if any.
        if index_name in self._cache_dict:
            return self._cache_dict[index_name]

        # Otherwise attach to the existing index, or create one if that fails.
        try:
            store = RedisVectorstore.from_existing_index(
                embedding=self.embedding,
                index_name=index_name,
                redis_url=self.redis_url,
                schema=cast(Dict, self.DEFAULT_SCHEMA),
            )
        except ValueError:
            store = RedisVectorstore(
                embedding=self.embedding,
                index_name=index_name,
                redis_url=self.redis_url,
                index_schema=cast(Dict, self.DEFAULT_SCHEMA),
            )
            # Embed a probe string only to learn the vector dimension.
            probe_vector = self.embedding.embed_query(text="test")
            store._create_index_if_not_exist(dim=len(probe_vector))

        self._cache_dict[index_name] = store
        return store

    def clear(self, **kwargs: Any) -> None:
        """Clear semantic cache for a given llm_string.

        Requires ``llm_string`` in ``kwargs``. Only an index whose client is
        currently held by this instance is dropped.
        """
        index_name = self._index_name(kwargs["llm_string"])
        if index_name not in self._cache_dict:
            return
        self._cache_dict[index_name].drop_index(
            index_name=index_name, delete_documents=True, redis_url=self.redis_url
        )
        del self._cache_dict[index_name]

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Return the generations cached for the closest stored prompt, or None."""
        llm_cache = self._get_llm_cache(llm_string)
        # Ask for the single nearest stored prompt within the threshold.
        matches = llm_cache.similarity_search(
            query=prompt,
            k=1,
            distance_threshold=self.score_threshold,
        )
        generations: List = []
        for document in matches or []:
            try:
                generations.extend(loads(document.metadata["return_val"]))
            except Exception:
                logger.warning(
                    "Retrieving a cache value that could not be deserialized "
                    "properly. This is likely due to the cache being in an "
                    "older format. Please recreate your cache to avoid this "
                    "error."
                )
                # Older caches stored the plain JSON dump of the generations,
                # so fall back to that format.
                generations.extend(
                    _load_generations_from_json(document.metadata["return_val"])
                )
        return generations or None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Index ``prompt`` with its serialized generations stored as metadata."""
        for candidate in return_val:
            if not isinstance(candidate, Generation):
                raise ValueError(
                    "RedisSemanticCache only supports caching of "
                    f"normal LLM generations, got {type(candidate)}"
                )
        llm_cache = self._get_llm_cache(llm_string)
        serialized = dumps(list(return_val))
        entry_metadata = {
            "llm_string": llm_string,
            "prompt": prompt,
            "return_val": serialized,
        }
        llm_cache.add_texts(texts=[prompt], metadatas=[entry_metadata])
class GPTCache(BaseCache):
    """Cache that uses GPTCache as a backend."""

    def __init__(
        self,
        init_func: Union[
            Callable[[Any, str], None], Callable[[Any], None], None
        ] = None,
    ):
        """Initialize by passing in init function (default: `None`).

        Args:
            init_func (Optional[Callable[[Any], None]]): function used to
                initialize each new `GPTCache` object. It may take either the
                cache object alone or the cache object plus the llm string.
                (default: `None`)

        Example:
        .. code-block:: python

            # Initialize GPTCache with a custom init function
            import gptcache
            from gptcache.processor.pre import get_prompt
            from gptcache.manager.factory import manager_factory
            from langchain.globals import set_llm_cache

            # Avoid multiple caches using the same file,
            # causing different llm model caches to affect each other

            def init_gptcache(cache_obj: gptcache.Cache, llm: str):
                cache_obj.init(
                    pre_embedding_func=get_prompt,
                    data_manager=manager_factory(
                        manager="map",
                        data_dir=f"map_cache_{llm}"
                    ),
                )

            set_llm_cache(GPTCache(init_gptcache))

        """
        try:
            import gptcache  # noqa: F401
        except ImportError:
            raise ImportError(
                "Could not import gptcache python package. "
                "Please install it with `pip install gptcache`."
            )

        self.init_gptcache_func: Union[
            Callable[[Any, str], None], Callable[[Any], None], None
        ] = init_func
        # One gptcache object per llm string, created on demand.
        self.gptcache_dict: Dict[str, Any] = {}

    def _new_gptcache(self, llm_string: str) -> Any:
        """Create, initialize and register a gptcache object for ``llm_string``."""
        from gptcache import Cache
        from gptcache.manager.factory import get_data_manager
        from gptcache.processor.pre import get_prompt

        fresh_cache = Cache()
        initializer = self.init_gptcache_func
        if initializer is None:
            # No custom initializer: use a data manager keyed on the llm string.
            fresh_cache.init(
                pre_embedding_func=get_prompt,
                data_manager=get_data_manager(data_path=llm_string),
            )
        elif len(inspect.signature(initializer).parameters) == 2:
            # Two-parameter initializers also receive the llm string.
            initializer(fresh_cache, llm_string)  # type: ignore[call-arg]
        else:
            initializer(fresh_cache)  # type: ignore[call-arg]

        self.gptcache_dict[llm_string] = fresh_cache
        return fresh_cache

    def _get_gptcache(self, llm_string: str) -> Any:
        """Get a cache object.

        When the corresponding llm model cache does not exist, it will be created."""
        existing = self.gptcache_dict.get(llm_string, None)
        if existing:
            return existing
        return self._new_gptcache(llm_string)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up the cache data.

        The cache object is selected by ``llm_string``; the entry is then
        fetched from it by ``prompt``. Returns None when nothing is stored.
        """
        from gptcache.adapter.api import get

        cached_json = get(prompt, cache_obj=self._get_gptcache(llm_string))
        if not cached_json:
            return None
        return [Generation(**fields) for fields in json.loads(cached_json)]

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache.

        The cache object is selected by ``llm_string``; ``return_val`` is then
        stored in it as JSON under ``prompt``.
        """
        for candidate in return_val:
            if not isinstance(candidate, Generation):
                raise ValueError(
                    "GPTCache only supports caching of normal LLM generations, "
                    f"got {type(candidate)}"
                )
        from gptcache.adapter.api import put

        target_cache = self._get_gptcache(llm_string)
        as_dicts = [generation.dict() for generation in return_val]
        put(prompt, json.dumps(as_dicts), cache_obj=target_cache)
        return None

    def clear(self, **kwargs: Any) -> None:
        """Flush every known gptcache object, then forget them all."""
        from gptcache import Cache

        for registered in self.gptcache_dict.values():
            cast(Cache, registered).flush()

        self.gptcache_dict.clear()
def _ensure_cache_exists(cache_client: momento.CacheClient, cache_name: str) -> None:
    """Create the named cache, treating "already exists" as success.

    Raises:
        SdkException: Momento service or network error
        Exception: Unexpected response
    """
    from momento.responses import CreateCache

    response = cache_client.create_cache(cache_name)
    if isinstance(response, (CreateCache.Success, CreateCache.CacheAlreadyExists)):
        return None
    if isinstance(response, CreateCache.Error):
        raise response.inner_exception
    raise Exception(f"Unexpected response cache creation: {response}")
| def _validate_ttl(ttl: Optional[timedelta]) -> None: | |
| if ttl is not None and ttl <= timedelta(seconds=0): | |
| raise ValueError(f"ttl must be positive but was {ttl}.") | |
class MomentoCache(BaseCache):
    """Cache that uses Momento as a backend. See https://gomomento.com/"""

    def __init__(
        self,
        cache_client: momento.CacheClient,
        cache_name: str,
        *,
        ttl: Optional[timedelta] = None,
        ensure_cache_exists: bool = True,
    ):
        """Instantiate a prompt cache using Momento as a backend.

        Note: to instantiate the cache client passed to MomentoCache,
        you must have a Momento account. See https://gomomento.com/.

        Args:
            cache_client (CacheClient): The Momento cache client.
            cache_name (str): The name of the cache to use to store the data.
            ttl (Optional[timedelta], optional): The time to live for the cache items.
                Defaults to None, ie use the client default TTL.
            ensure_cache_exists (bool, optional): Create the cache if it doesn't
                exist. Defaults to True.

        Raises:
            ImportError: Momento python package is not installed.
            TypeError: cache_client is not of type momento.CacheClient
            ValueError: ttl is non-null and not positive
        """
        try:
            from momento import CacheClient
        except ImportError:
            raise ImportError(
                "Could not import momento python package. "
                "Please install it with `pip install momento`."
            )
        if not isinstance(cache_client, CacheClient):
            raise TypeError("cache_client must be a momento.CacheClient object.")
        _validate_ttl(ttl)
        if ensure_cache_exists:
            _ensure_cache_exists(cache_client, cache_name)

        self.cache_client = cache_client
        self.cache_name = cache_name
        self.ttl = ttl

    # Alternate constructor: it takes `cls`, so it must be a classmethod for
    # `MomentoCache.from_client_params(cache_name, ttl)` to bind correctly.
    @classmethod
    def from_client_params(
        cls,
        cache_name: str,
        ttl: timedelta,
        *,
        configuration: Optional[momento.config.Configuration] = None,
        api_key: Optional[str] = None,
        auth_token: Optional[str] = None,  # for backwards compatibility
        **kwargs: Any,
    ) -> MomentoCache:
        """Construct cache from CacheClient parameters.

        Args:
            cache_name (str): The name of the cache to use to store the data.
            ttl (timedelta): Used both as the client's default TTL and as the
                TTL of this cache's items.
            configuration: Momento client configuration; when None,
                ``Configurations.Laptop.v1()`` is used.
            api_key (Optional[str]): API key, used when no auth token is found.
            auth_token (Optional[str]): Legacy name for the credential.
            **kwargs: Forwarded to the ``MomentoCache`` constructor.

        Raises:
            ImportError: Momento python package is not installed.
        """
        try:
            from momento import CacheClient, Configurations, CredentialProvider
        except ImportError:
            raise ImportError(
                "Could not import momento python package. "
                "Please install it with `pip install momento`."
            )
        if configuration is None:
            configuration = Configurations.Laptop.v1()

        # Try checking `MOMENTO_AUTH_TOKEN` first for backwards compatibility
        # NOTE(review): a MOMENTO_AUTH_TOKEN found here takes precedence over
        # an explicitly passed `api_key` — confirm this is intended.
        try:
            api_key = auth_token or get_from_env("auth_token", "MOMENTO_AUTH_TOKEN")
        except ValueError:
            api_key = api_key or get_from_env("api_key", "MOMENTO_API_KEY")
        credentials = CredentialProvider.from_string(api_key)
        cache_client = CacheClient(configuration, credentials, default_ttl=ttl)
        return cls(cache_client, cache_name, ttl=ttl, **kwargs)

    def __key(self, prompt: str, llm_string: str) -> str:
        """Compute cache key from prompt and associated model and settings.

        Args:
            prompt (str): The prompt run through the language model.
            llm_string (str): The language model version and settings.

        Returns:
            str: The cache key.
        """
        return _hash(prompt + llm_string)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Lookup llm generations in cache by prompt and associated model and settings.

        Args:
            prompt (str): The prompt run through the language model.
            llm_string (str): The language model version and settings.

        Raises:
            SdkException: Momento service or network error

        Returns:
            Optional[RETURN_VAL_TYPE]: A list of language model generations,
                or None on a miss or when the stored list is empty.
        """
        from momento.responses import CacheGet

        generations: RETURN_VAL_TYPE = []

        get_response = self.cache_client.get(
            self.cache_name, self.__key(prompt, llm_string)
        )
        if isinstance(get_response, CacheGet.Hit):
            value = get_response.value_string
            generations = _load_generations_from_json(value)
        elif isinstance(get_response, CacheGet.Miss):
            pass
        elif isinstance(get_response, CacheGet.Error):
            raise get_response.inner_exception
        return generations if generations else None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Store llm generations in cache.

        Args:
            prompt (str): The prompt run through the language model.
            llm_string (str): The language model string.
            return_val (RETURN_VAL_TYPE): A list of language model generations.

        Raises:
            ValueError: An item of ``return_val`` is not a `Generation`.
            SdkException: Momento service or network error
            Exception: Unexpected response
        """
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "Momento only supports caching of normal LLM generations, "
                    f"got {type(gen)}"
                )
        key = self.__key(prompt, llm_string)
        value = _dump_generations_to_json(return_val)
        set_response = self.cache_client.set(self.cache_name, key, value, self.ttl)
        from momento.responses import CacheSet

        if isinstance(set_response, CacheSet.Success):
            pass
        elif isinstance(set_response, CacheSet.Error):
            raise set_response.inner_exception
        else:
            raise Exception(f"Unexpected response: {set_response}")

    def clear(self, **kwargs: Any) -> None:
        """Clear the cache.

        Raises:
            SdkException: Momento service or network error
        """
        from momento.responses import CacheFlush

        flush_response = self.cache_client.flush_cache(self.cache_name)
        if isinstance(flush_response, CacheFlush.Success):
            pass
        elif isinstance(flush_response, CacheFlush.Error):
            raise flush_response.inner_exception
CASSANDRA_CACHE_DEFAULT_TABLE_NAME = "langchain_llm_cache"
CASSANDRA_CACHE_DEFAULT_TTL_SECONDS = None


class CassandraCache(BaseCache):
    """
    Cache that uses Cassandra / Astra DB as a backend.

    Everything is kept in a single Cassandra table whose primary key is made
    of two lookup keys:
    - prompt, a string
    - llm_string, a deterministic str representation of the model parameters
      (needed to prevent same-prompt-different-model collisions)
    Both keys are hashed before being sent to the table.
    """

    def __init__(
        self,
        session: Optional[CassandraSession] = None,
        keyspace: Optional[str] = None,
        table_name: str = CASSANDRA_CACHE_DEFAULT_TABLE_NAME,
        ttl_seconds: Optional[int] = CASSANDRA_CACHE_DEFAULT_TTL_SECONDS,
        skip_provisioning: bool = False,
    ):
        """
        Initialize with a ready session and a keyspace name.

        Args:
            session (cassandra.cluster.Session): an open Cassandra session
            keyspace (str): the keyspace to use for storing the cache
            table_name (str): name of the Cassandra table to use as cache
            ttl_seconds (optional int): time-to-live for cache entries
                (default: None, i.e. forever)
            skip_provisioning (bool): forwarded unchanged to the underlying
                cassio table.
        """
        try:
            from cassio.table import ElasticCassandraTable
        except ImportError:  # also covers ModuleNotFoundError, its subclass
            raise ValueError(
                "Could not import cassio python package. "
                "Please install it with `pip install cassio`."
            )

        self.session = session
        self.keyspace = keyspace
        self.table_name = table_name
        self.ttl_seconds = ttl_seconds

        self.kv_cache = ElasticCassandraTable(
            session=self.session,
            keyspace=self.keyspace,
            table=self.table_name,
            keys=["llm_string", "prompt"],
            primary_key_type=["TEXT", "TEXT"],
            ttl_seconds=self.ttl_seconds,
            skip_provisioning=skip_provisioning,
        )

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Return the cached generations for (prompt, llm_string), or None."""
        item = self.kv_cache.get(
            llm_string=_hash(llm_string),
            prompt=_hash(prompt),
        )
        if item is None:
            return None
        # A malformed blob deserializes to None, which is reported as a miss.
        return _loads_generations(item["body_blob"])

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Serialize ``return_val`` and store it under the hashed keys."""
        self.kv_cache.put(
            llm_string=_hash(llm_string),
            prompt=_hash(prompt),
            body_blob=_dumps_generations(return_val),
        )

    def delete_through_llm(
        self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
    ) -> None:
        """
        A wrapper around `delete` with the LLM being passed.
        In case the llm(prompt) calls have a `stop` param, you should pass it here
        """
        llm_params = {**llm.dict(), **{"stop": stop}}
        # The llm_string is the element at index 1 of what `get_prompts` returns.
        llm_string = get_prompts(llm_params, [])[1]
        return self.delete(prompt, llm_string=llm_string)

    def delete(self, prompt: str, llm_string: str) -> None:
        """Evict from cache if there's an entry."""
        return self.kv_cache.delete(
            llm_string=_hash(llm_string),
            prompt=_hash(prompt),
        )

    def clear(self, **kwargs: Any) -> None:
        """Clear cache. This is for all LLMs at once."""
        self.kv_cache.clear()
CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC = "dot"
CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD = 0.85
CASSANDRA_SEMANTIC_CACHE_DEFAULT_TABLE_NAME = "langchain_llm_semantic_cache"
CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS = None
# Number of prompt embeddings memoized per cache instance (see __init__).
CASSANDRA_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE = 16


class CassandraSemanticCache(BaseCache):
    """
    Cache that uses Cassandra as a vector-store backend for semantic
    (i.e. similarity-based) lookup.

    It uses a single (vector) Cassandra table and stores, in principle,
    cached values from several LLMs, so the LLM's llm_string is part
    of the rows' primary keys.

    The similarity is based on one of several distance metrics (default: "dot").
    If choosing another metric, the default threshold is to be re-tuned accordingly.
    """

    def __init__(
        self,
        session: Optional[CassandraSession],
        keyspace: Optional[str],
        embedding: Embeddings,
        table_name: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TABLE_NAME,
        distance_metric: str = CASSANDRA_SEMANTIC_CACHE_DEFAULT_DISTANCE_METRIC,
        score_threshold: float = CASSANDRA_SEMANTIC_CACHE_DEFAULT_SCORE_THRESHOLD,
        ttl_seconds: Optional[int] = CASSANDRA_SEMANTIC_CACHE_DEFAULT_TTL_SECONDS,
        skip_provisioning: bool = False,
    ):
        """
        Initialize the cache with all relevant parameters.

        Args:
            session (cassandra.cluster.Session): an open Cassandra session
            keyspace (str): the keyspace to use for storing the cache
            embedding (Embedding): Embedding provider for semantic
                encoding and search.
            table_name (str): name of the Cassandra (vector) table
                to use as cache
            distance_metric (str, 'dot'): which measure to adopt for
                similarity searches
            score_threshold (optional float): numeric value to use as
                cutoff for the similarity searches
            ttl_seconds (optional int): time-to-live for cache entries
                (default: None, i.e. forever)
            skip_provisioning (bool): forwarded unchanged to the cassio table
                constructor (presumably skips table creation -- confirm
                against the cassio documentation).

        The default score threshold is tuned to the default metric.
        Tune it carefully yourself if switching to another distance metric.

        Raises:
            ValueError: if the ``cassio`` package cannot be imported.
        """
        # cassio is an optional dependency, hence the function-scope import.
        try:
            from cassio.table import MetadataVectorCassandraTable
        except (ImportError, ModuleNotFoundError):
            raise ValueError(
                "Could not import cassio python package. "
                "Please install it with `pip install cassio`."
            )
        self.session = session
        self.keyspace = keyspace
        self.embedding = embedding
        self.table_name = table_name
        self.distance_metric = distance_metric
        self.score_threshold = score_threshold
        self.ttl_seconds = ttl_seconds

        # The contract for this class has separate lookup and update:
        # in order to spare some embedding calculations we cache them between
        # the two calls.
        # Note: each instance of this class has its own `_get_embedding` with
        # its own lru (the closure is created, and decorated, per instance).
        @lru_cache(maxsize=CASSANDRA_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)
        def _cache_embedding(text: str) -> List[float]:
            return self.embedding.embed_query(text=text)

        self._get_embedding = _cache_embedding
        # The vector dimension is discovered by embedding a sample sentence.
        self.embedding_dimension = self._get_embedding_dimension()

        self.table = MetadataVectorCassandraTable(
            session=self.session,
            keyspace=self.keyspace,
            table=self.table_name,
            primary_key_type=["TEXT"],
            vector_dimension=self.embedding_dimension,
            ttl_seconds=self.ttl_seconds,
            # Only the llm_string hash is used as a search filter (see
            # `lookup_with_id`), so it is the only metadata field allow-listed.
            metadata_indexing=("allow", {"_llm_string_hash"}),
            skip_provisioning=skip_provisioning,
        )

    def _get_embedding_dimension(self) -> int:
        """Return the length of the vectors produced by the embedding provider."""
        return len(self._get_embedding(text="This is a sample sentence."))

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        embedding_vector = self._get_embedding(text=prompt)
        llm_string_hash = _hash(llm_string)
        body = _dumps_generations(return_val)
        metadata = {
            "_prompt": prompt,
            "_llm_string_hash": llm_string_hash,
        }
        # The row id combines both hashes, so storing the same
        # (prompt, llm_string) pair again targets the same row id.
        row_id = f"{_hash(prompt)}-{llm_string_hash}"

        self.table.put(
            body_blob=body,
            vector=embedding_vector,
            row_id=row_id,
            metadata=metadata,
        )

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        hit_with_id = self.lookup_with_id(prompt, llm_string)
        if hit_with_id is None:
            return None
        return hit_with_id[1]

    def lookup_with_id(
        self, prompt: str, llm_string: str
    ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
        """
        Look up based on prompt and llm_string.
        If there are hits, return (document_id, cached_entry)
        """
        prompt_embedding: List[float] = self._get_embedding(text=prompt)
        hits = list(
            self.table.metric_ann_search(
                vector=prompt_embedding,
                metadata={"_llm_string_hash": _hash(llm_string)},
                n=1,
                metric=self.distance_metric,
                metric_threshold=self.score_threshold,
            )
        )
        if not hits:
            return None
        hit = hits[0]
        generations = _loads_generations(hit["body_blob"])
        # this protects against malformed cached items:
        if generations is None:
            return None
        return (hit["row_id"], generations)

    def lookup_with_id_through_llm(
        self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
    ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
        """
        A wrapper around `lookup_with_id` with the LLM being passed.
        In case the llm(prompt) calls have a `stop` param, you should pass it here
        """
        # Index 1 of the `get_prompts` result is the llm_string.
        llm_string = get_prompts(
            {**llm.dict(), **{"stop": stop}},
            [],
        )[1]
        return self.lookup_with_id(prompt, llm_string=llm_string)

    def delete_by_document_id(self, document_id: str) -> None:
        """
        Given this is a "similarity search" cache, an invalidation pattern
        that makes sense is first a lookup to get an ID, and then deleting
        with that ID. This is for the second step.
        """
        self.table.delete(row_id=document_id)

    def clear(self, **kwargs: Any) -> None:
        """Clear the *whole* semantic cache."""
        self.table.clear()
class FullMd5LLMCache(Base):  # type: ignore
    """ORM model of the table holding the full LLM cache (all generations)."""

    __tablename__ = "full_md5_llm_cache"

    # Primary key of the row.
    id = Column(String, primary_key=True)
    # Indexed digest of the prompt.
    prompt_md5 = Column(String, index=True)
    # Indexed llm column.
    llm = Column(String, index=True)
    # Indexed integer position column.
    idx = Column(Integer, index=True)
    # Full prompt text (not indexed).
    prompt = Column(String)
    # Response payload (not indexed).
    response = Column(String)
class SQLAlchemyMd5Cache(BaseCache):
    """Cache that uses SQLAlchemy as a backend.

    Rows are located through an MD5 digest of the prompt (an indexed column);
    the full prompt text is compared as well, so that a digest collision
    cannot return another prompt's generations.
    """

    def __init__(
        self, engine: Engine, cache_schema: Type[FullMd5LLMCache] = FullMd5LLMCache
    ):
        """Initialize by creating all tables.

        Args:
            engine: SQLAlchemy engine used for every cache operation.
            cache_schema: declarative model class describing the cache table.
        """
        self.engine = engine
        self.cache_schema = cache_schema
        self.cache_schema.metadata.create_all(self.engine)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string.

        Returns:
            The deserialized generations ordered by their `idx`, or None when
            no row matches.
        """
        rows = self._search_rows(prompt, llm_string)
        if rows:
            return [loads(row[0]) for row in rows]
        return None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update based on prompt and llm_string.

        Any rows previously stored for the same (prompt, llm_string) are
        removed first, then one row per generation is written.
        """
        self._delete_previous(prompt, llm_string)
        prompt_md5 = self.get_md5(prompt)
        items = [
            self.cache_schema(
                id=str(uuid.uuid1()),
                prompt=prompt,
                prompt_md5=prompt_md5,
                llm=llm_string,
                response=dumps(gen),
                idx=i,
            )
            for i, gen in enumerate(return_val)
        ]
        with Session(self.engine) as session, session.begin():
            for item in items:
                session.merge(item)

    def _delete_previous(self, prompt: str, llm_string: str) -> None:
        """Delete every row stored for (prompt, llm_string)."""
        prompt_md5 = self.get_md5(prompt)
        # A single DELETE statement: the earlier approach selected only the
        # `response` column and handed the resulting Row objects to
        # `session.delete`, which only accepts mapped instances.
        stmt = (
            self.cache_schema.__table__.delete()  # type: ignore
            .where(self.cache_schema.prompt_md5 == prompt_md5)  # type: ignore
            .where(self.cache_schema.llm == llm_string)
            .where(self.cache_schema.prompt == prompt)
        )
        with Session(self.engine) as session, session.begin():
            session.execute(stmt)

    def _search_rows(self, prompt: str, llm_string: str) -> List[Row]:
        """Return the `response` rows for (prompt, llm_string), ordered by idx."""
        prompt_md5 = self.get_md5(prompt)
        stmt = (
            select(self.cache_schema.response)
            .where(self.cache_schema.prompt_md5 == prompt_md5)  # type: ignore
            .where(self.cache_schema.llm == llm_string)
            .where(self.cache_schema.prompt == prompt)
            .order_by(self.cache_schema.idx)
        )
        with Session(self.engine) as session:
            return session.execute(stmt).fetchall()

    def clear(self, **kwargs: Any) -> None:
        """Clear cache."""
        # The DELETE is built from the mapped table (the model class itself
        # has no `delete()` method) and `session.begin()` commits it on exit.
        with Session(self.engine) as session, session.begin():
            session.execute(self.cache_schema.__table__.delete())  # type: ignore

    @staticmethod
    def get_md5(input_string: str) -> str:
        """Return the hexadecimal MD5 digest of `input_string`.

        Declared static because it takes no `self`, while being called as
        `self.get_md5(...)` by the other methods.
        """
        return hashlib.md5(input_string.encode()).hexdigest()
ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_cache"


class AstraDBCache(BaseCache):
    """
    Cache that uses Astra DB as a backend.

    It uses a single collection as a kv store
    The lookup keys, combined in the _id of the documents, are:

    - prompt, a string
    - llm_string, a deterministic str representation of the model parameters.
      (needed to prevent same-prompt-different-model collisions)
    """

    def __init__(
        self,
        *,
        collection_name: str = ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[Any] = None,  # 'astrapy.db.AstraDB' if passed
        namespace: Optional[str] = None,
    ):
        """
        Create an AstraDB cache using a collection for storage.

        Args (only keyword-arguments accepted):
            collection_name (str): name of the Astra DB collection to create/use.
            token (Optional[str]): API token for Astra DB usage.
            api_endpoint (Optional[str]): full URL to the API endpoint,
                such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
            astra_db_client (Optional[Any]): *alternative to token+api_endpoint*,
                you can pass an already-created 'astrapy.db.AstraDB' instance.
            namespace (Optional[str]): namespace (aka keyspace) where the
                collection is created. Defaults to the database's "default namespace".

        Raises:
            ImportError: if the ``astrapy`` package cannot be imported.
            ValueError: if `astra_db_client` is passed together with `token`
                or `api_endpoint`.
        """
        # astrapy is an optional dependency, hence the function-scope import.
        try:
            from astrapy.db import (
                AstraDB as LibAstraDB,
            )
        except (ImportError, ModuleNotFoundError):
            raise ImportError(
                "Could not import a recent astrapy python package. "
                "Please install it with `pip install --upgrade astrapy`."
            )
        # Conflicting-arg checks:
        if astra_db_client is not None:
            if token is not None or api_endpoint is not None:
                raise ValueError(
                    "You cannot pass 'astra_db_client' to AstraDB if passing "
                    "'token' and 'api_endpoint'."
                )

        self.collection_name = collection_name
        self.token = token
        self.api_endpoint = api_endpoint
        self.namespace = namespace

        if astra_db_client is not None:
            self.astra_db = astra_db_client
        else:
            self.astra_db = LibAstraDB(
                token=self.token,
                api_endpoint=self.api_endpoint,
                namespace=self.namespace,
            )
        self.collection = self.astra_db.create_collection(
            collection_name=self.collection_name,
        )

    @staticmethod
    def _make_id(prompt: str, llm_string: str) -> str:
        """Build the document `_id` from the hashed prompt and llm_string.

        Declared static because it takes no `self`, while being called as
        `self._make_id(...)` by the other methods.
        """
        return f"{_hash(prompt)}#{_hash(llm_string)}"

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        doc_id = self._make_id(prompt, llm_string)
        # Only `body_blob` is projected: it is the only field read below.
        item = self.collection.find_one(
            filter={
                "_id": doc_id,
            },
            projection={
                "body_blob": 1,
            },
        )["data"]["document"]
        if item is None:
            return None
        # `_loads_generations` yields None for malformed cached items, which
        # is reported to the caller exactly like a cache miss.
        return _loads_generations(item["body_blob"])

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        doc_id = self._make_id(prompt, llm_string)
        blob = _dumps_generations(return_val)
        self.collection.upsert(
            {
                "_id": doc_id,
                "body_blob": blob,
            },
        )

    def delete_through_llm(
        self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
    ) -> None:
        """
        A wrapper around `delete` with the LLM being passed.
        In case the llm(prompt) calls have a `stop` param, you should pass it here
        """
        # Index 1 of the `get_prompts` result is the llm_string.
        llm_string = get_prompts(
            {**llm.dict(), **{"stop": stop}},
            [],
        )[1]
        return self.delete(prompt, llm_string=llm_string)

    def delete(self, prompt: str, llm_string: str) -> None:
        """Evict from cache if there's an entry."""
        doc_id = self._make_id(prompt, llm_string)
        return self.collection.delete_one(doc_id)

    def clear(self, **kwargs: Any) -> None:
        """Clear cache. This is for all LLMs at once."""
        self.astra_db.truncate_collection(self.collection_name)
ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD = 0.85
# Dedicated name: re-assigning ASTRA_DB_CACHE_DEFAULT_COLLECTION_NAME here would
# overwrite the module-level default of the non-semantic Astra DB cache.
ASTRA_DB_SEMANTIC_CACHE_DEFAULT_COLLECTION_NAME = "langchain_astradb_semantic_cache"
# Number of prompt embeddings memoized per cache instance (see __init__).
ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE = 16


class AstraDBSemanticCache(BaseCache):
    """
    Cache that uses Astra DB as a vector-store backend for semantic
    (i.e. similarity-based) lookup.

    It uses a single (vector) collection and can store
    cached values from several LLMs, so the LLM's 'llm_string' is stored
    in the document metadata.

    You can choose the preferred similarity (or use the API default) --
    remember the threshold might require metric-dependent tuning.
    """

    def __init__(
        self,
        *,
        collection_name: str = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_COLLECTION_NAME,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[Any] = None,  # 'astrapy.db.AstraDB' if passed
        namespace: Optional[str] = None,
        embedding: Embeddings,
        metric: Optional[str] = None,
        similarity_threshold: float = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD,
    ):
        """
        Initialize the cache with all relevant parameters.

        Args (only keyword-arguments accepted):
            collection_name (str): name of the Astra DB collection to create/use.
            token (Optional[str]): API token for Astra DB usage.
            api_endpoint (Optional[str]): full URL to the API endpoint,
                such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
            astra_db_client (Optional[Any]): *alternative to token+api_endpoint*,
                you can pass an already-created 'astrapy.db.AstraDB' instance.
            namespace (Optional[str]): namespace (aka keyspace) where the
                collection is created. Defaults to the database's "default namespace".
            embedding (Embedding): Embedding provider for semantic
                encoding and search.
            metric: the function to use for evaluating similarity of text embeddings.
                Defaults to 'cosine' (alternatives: 'euclidean', 'dot_product')
            similarity_threshold (float, optional): the minimum similarity
                for accepting a (semantic-search) match.

        The default score threshold is tuned to the default metric.
        Tune it carefully yourself if switching to another distance metric.

        Raises:
            ImportError: if the ``astrapy`` package cannot be imported.
            ValueError: if `astra_db_client` is passed together with `token`
                or `api_endpoint`.
        """
        # astrapy is an optional dependency, hence the function-scope import.
        try:
            from astrapy.db import (
                AstraDB as LibAstraDB,
            )
        except (ImportError, ModuleNotFoundError):
            raise ImportError(
                "Could not import a recent astrapy python package. "
                "Please install it with `pip install --upgrade astrapy`."
            )
        # Conflicting-arg checks:
        if astra_db_client is not None:
            if token is not None or api_endpoint is not None:
                raise ValueError(
                    "You cannot pass 'astra_db_client' to AstraDB if passing "
                    "'token' and 'api_endpoint'."
                )

        self.embedding = embedding
        self.metric = metric
        self.similarity_threshold = similarity_threshold

        # The contract for this class has separate lookup and update:
        # in order to spare some embedding calculations we cache them between
        # the two calls.
        # Note: each instance of this class has its own `_get_embedding` with
        # its own lru (the closure is created, and decorated, per instance).
        @lru_cache(maxsize=ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE)
        def _cache_embedding(text: str) -> List[float]:
            return self.embedding.embed_query(text=text)

        self._get_embedding = _cache_embedding
        # The vector dimension is discovered by embedding a sample sentence.
        self.embedding_dimension = self._get_embedding_dimension()

        self.collection_name = collection_name
        self.token = token
        self.api_endpoint = api_endpoint
        self.namespace = namespace

        if astra_db_client is not None:
            self.astra_db = astra_db_client
        else:
            self.astra_db = LibAstraDB(
                token=self.token,
                api_endpoint=self.api_endpoint,
                namespace=self.namespace,
            )
        self.collection = self.astra_db.create_collection(
            collection_name=self.collection_name,
            dimension=self.embedding_dimension,
            metric=self.metric,
        )

    def _get_embedding_dimension(self) -> int:
        """Return the length of the vectors produced by the embedding provider."""
        return len(self._get_embedding(text="This is a sample sentence."))

    @staticmethod
    def _make_id(prompt: str, llm_string: str) -> str:
        """Build the document `_id` from the hashed prompt and llm_string.

        Declared static because it takes no `self`, while being called as
        `self._make_id(...)` by `update`.
        """
        return f"{_hash(prompt)}#{_hash(llm_string)}"

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        doc_id = self._make_id(prompt, llm_string)
        llm_string_hash = _hash(llm_string)
        embedding_vector = self._get_embedding(text=prompt)
        body = _dumps_generations(return_val)

        # `llm_string_hash` is stored alongside the vector because
        # `lookup_with_id` filters on it.
        self.collection.upsert(
            {
                "_id": doc_id,
                "body_blob": body,
                "llm_string_hash": llm_string_hash,
                "$vector": embedding_vector,
            }
        )

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        hit_with_id = self.lookup_with_id(prompt, llm_string)
        if hit_with_id is None:
            return None
        return hit_with_id[1]

    def lookup_with_id(
        self, prompt: str, llm_string: str
    ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
        """
        Look up based on prompt and llm_string.
        If there are hits, return (document_id, cached_entry) for the top hit
        """
        prompt_embedding: List[float] = self._get_embedding(text=prompt)
        llm_string_hash = _hash(llm_string)

        hit = self.collection.vector_find_one(
            vector=prompt_embedding,
            filter={
                "llm_string_hash": llm_string_hash,
            },
            fields=["body_blob", "_id"],
            include_similarity=True,
        )

        # A top hit below the configured threshold is treated as a miss.
        if hit is None or hit["$similarity"] < self.similarity_threshold:
            return None
        generations = _loads_generations(hit["body_blob"])
        # this protects against malformed cached items:
        if generations is None:
            return None
        return (hit["_id"], generations)

    def lookup_with_id_through_llm(
        self, prompt: str, llm: LLM, stop: Optional[List[str]] = None
    ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]:
        """
        A wrapper around `lookup_with_id` with the LLM being passed.
        In case the llm(prompt) calls have a `stop` param, you should pass it here
        """
        # Index 1 of the `get_prompts` result is the llm_string.
        llm_string = get_prompts(
            {**llm.dict(), **{"stop": stop}},
            [],
        )[1]
        return self.lookup_with_id(prompt, llm_string=llm_string)

    def delete_by_document_id(self, document_id: str) -> None:
        """
        Given this is a "similarity search" cache, an invalidation pattern
        that makes sense is first a lookup to get an ID, and then deleting
        with that ID. This is for the second step.
        """
        self.collection.delete_one(document_id)

    def clear(self, **kwargs: Any) -> None:
        """Clear the *whole* semantic cache."""
        self.astra_db.truncate_collection(self.collection_name)