Spaces:
Paused
Paused
| """ | |
| Send logs to Argilla for annotation | |
| """ | |
| import asyncio | |
| import json | |
| import os | |
| import random | |
| import types | |
| from typing import Any, Dict, List, Optional | |
| import httpx | |
| from pydantic import BaseModel # type: ignore | |
| import litellm | |
| from litellm._logging import verbose_logger | |
| from litellm.integrations.custom_batch_logger import CustomBatchLogger | |
| from litellm.integrations.custom_logger import CustomLogger | |
| from litellm.llms.custom_httpx.http_handler import ( | |
| get_async_httpx_client, | |
| httpxSpecialProvider, | |
| ) | |
| from litellm.types.integrations.argilla import ( | |
| SUPPORTED_PAYLOAD_FIELDS, | |
| ArgillaCredentialsObject, | |
| ArgillaItem, | |
| ) | |
| from litellm.types.utils import StandardLoggingPayload | |
def is_serializable(value):
    """Return False for coroutines, plain functions, generators and pydantic models.

    Any other value is reported as serializable; no actual serialization is
    attempted.
    """
    # Runtime-only objects are checked first, then pydantic models.
    if isinstance(
        value, (types.CoroutineType, types.FunctionType, types.GeneratorType)
    ):
        return False
    return not isinstance(value, BaseModel)
class ArgillaLogger(CustomBatchLogger):
    """Batch logger that sends LiteLLM request/response records to Argilla.

    Records are built from the ``standard_logging_object`` found in the
    callback kwargs, mapped to Argilla fields through
    ``litellm.argilla_transformation_object``, queued on ``self.log_queue``
    and sent to the dataset's ``/records/bulk`` endpoint.

    ``log_queue``, ``flush_interval``, ``periodic_flush`` and ``flush_queue``
    are not defined in this class; they are expected to come from the base
    class.
    """

    def __init__(
        self,
        argilla_api_key: Optional[str] = None,
        argilla_dataset_name: Optional[str] = None,
        argilla_base_url: Optional[str] = None,
        **kwargs,
    ):
        """Validate configuration, resolve credentials and start periodic flushing.

        Args:
            argilla_api_key: API key; falls back to ``ARGILLA_API_KEY``.
            argilla_dataset_name: Dataset name; falls back to
                ``ARGILLA_DATASET_NAME`` and then ``"litellm-completion"``.
            argilla_base_url: Server URL; falls back to ``ARGILLA_BASE_URL``
                and then ``"http://localhost:6900/"``.
            **kwargs: Forwarded to the base class constructor.

        Raises:
            Exception: If ``litellm.argilla_transformation_object`` is unset or
                invalid, or if no API key can be resolved.
        """
        if litellm.argilla_transformation_object is None:
            raise Exception(
                "'litellm.argilla_transformation_object' is required, to log your payload to Argilla."
            )
        self.validate_argilla_transformation_object(
            litellm.argilla_transformation_object
        )
        self.argilla_transformation_object = litellm.argilla_transformation_object
        self.default_credentials = self.get_credentials_from_env(
            argilla_api_key=argilla_api_key,
            argilla_dataset_name=argilla_dataset_name,
            argilla_base_url=argilla_base_url,
        )
        self.sampling_rate: float = self._read_sampling_rate()
        self.async_httpx_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        self.flush_lock = asyncio.Lock()
        super().__init__(**kwargs, flush_lock=self.flush_lock)
        # NOTE(review): applied after super().__init__() so that a default
        # assigned by the base class (not visible here) cannot overwrite the
        # configured value -- confirm against CustomBatchLogger.
        _batch_size = (
            os.getenv("ARGILLA_BATCH_SIZE", None) or litellm.argilla_batch_size
        )
        if _batch_size:
            self.batch_size = int(_batch_size)
        # Requires a running event loop; asyncio.create_task raises otherwise.
        asyncio.create_task(self.periodic_flush())

    @staticmethod
    def _read_sampling_rate() -> float:
        """Parse ``ARGILLA_SAMPLING_RATE`` into a float in ``[0, 1]``.

        Returns 1.0 (log everything) when the variable is unset, is not a
        number, or lies outside ``[0, 1]``.
        """
        raw_rate = os.getenv("ARGILLA_SAMPLING_RATE")
        if raw_rate is None:
            return 1.0
        try:
            rate = float(raw_rate.strip())
        except ValueError:
            verbose_logger.warning(
                "Invalid ARGILLA_SAMPLING_RATE=%r, defaulting to 1.0", raw_rate
            )
            return 1.0
        # The comparison is also False for NaN, which therefore falls back too.
        if not 0.0 <= rate <= 1.0:
            verbose_logger.warning(
                "ARGILLA_SAMPLING_RATE=%r is outside [0, 1], defaulting to 1.0",
                raw_rate,
            )
            return 1.0
        return rate

    def _should_skip_for_sampling(self) -> bool:
        """Draw a random sample and report whether this event should be dropped."""
        random_sample = random.random()
        if random_sample > self.sampling_rate:
            verbose_logger.info(
                "Skipping Argilla logging. Sampling rate={}, random_sample={}".format(
                    self.sampling_rate, random_sample
                )
            )
            return True
        return False

    def validate_argilla_transformation_object(
        self, argilla_transformation_object: Dict[str, Any]
    ):
        """Check that the transformation object is a dict of supported payload fields.

        Raises:
            Exception: If the object is not a dict, or any of its values is not
                in ``SUPPORTED_PAYLOAD_FIELDS``.
        """
        if not isinstance(argilla_transformation_object, dict):
            raise Exception(
                "'argilla_transformation_object' must be a dictionary, to log your payload to Argilla."
            )
        for v in argilla_transformation_object.values():
            if v not in SUPPORTED_PAYLOAD_FIELDS:
                raise Exception(
                    f"All values in argilla_transformation_object must be a key in SUPPORTED_PAYLOAD_FIELDS, {v} is not a valid key."
                )

    def get_credentials_from_env(
        self,
        argilla_api_key: Optional[str],
        argilla_dataset_name: Optional[str],
        argilla_base_url: Optional[str],
    ) -> ArgillaCredentialsObject:
        """Resolve API key, base URL and dataset from arguments, env vars and defaults.

        After resolving the values, the dataset is looked up by name on the
        server. When the lookup returns at least one item, the first item's
        ``id`` is stored as ``ARGILLA_DATASET_NAME``; otherwise the name itself
        is kept.

        Raises:
            Exception: If no API key is provided by argument or environment.
        """
        _credentials_api_key = argilla_api_key or os.getenv("ARGILLA_API_KEY")
        if not _credentials_api_key:
            raise Exception("Invalid Argilla API Key given. _credentials_api_key=None.")

        # Strip the trailing slash so the f-string URLs built from this value
        # do not contain "//" (the built-in default ends with "/").
        _credentials_base_url = (
            argilla_base_url
            or os.getenv("ARGILLA_BASE_URL")
            or "http://localhost:6900/"
        ).rstrip("/")

        _credentials_dataset_name = (
            argilla_dataset_name
            or os.getenv("ARGILLA_DATASET_NAME")
            or "litellm-completion"
        )

        dataset_response = litellm.module_level_client.get(
            url=f"{_credentials_base_url}/api/v1/me/datasets?name={_credentials_dataset_name}",
            headers={"X-Argilla-Api-Key": _credentials_api_key},
        )
        json_response = dataset_response.json()
        if (
            "items" in json_response
            and isinstance(json_response["items"], list)
            and len(json_response["items"]) > 0
        ):
            _credentials_dataset_name = json_response["items"][0]["id"]

        return ArgillaCredentialsObject(
            ARGILLA_API_KEY=_credentials_api_key,
            ARGILLA_BASE_URL=_credentials_base_url,
            ARGILLA_DATASET_NAME=_credentials_dataset_name,
        )

    def get_chat_messages(
        self, payload: StandardLoggingPayload
    ) -> List[Dict[str, Any]]:
        """Return the payload's ``messages`` as a list of dicts.

        A single dict is wrapped in a list.

        Raises:
            Exception: If ``messages`` is missing/None, or is neither a dict
                nor a non-empty list whose first element is a dict.
        """
        payload_messages = payload.get("messages", None)
        if payload_messages is None:
            raise Exception("No chat messages found in payload.")
        if (
            isinstance(payload_messages, list)
            and len(payload_messages) > 0
            and isinstance(payload_messages[0], dict)
        ):
            return payload_messages
        elif isinstance(payload_messages, dict):
            return [payload_messages]
        else:
            raise Exception(f"Invalid chat messages format: {payload_messages}")

    def get_str_response(self, payload: StandardLoggingPayload) -> str:
        """Extract the response text from the payload.

        A string response is returned as is. For a dict response the content
        of the first choice's message is returned, or ``""`` when choices,
        message or content are missing or empty.

        Raises:
            Exception: If the response is missing/None or has another type.
        """
        response = payload.get("response", None)
        if response is None:
            raise Exception("No response found in payload.")
        if isinstance(response, str):
            return response
        elif isinstance(response, dict):
            # `or` fallbacks cover keys that are present but None/empty, which
            # dict.get defaults alone do not handle.
            choices = response.get("choices") or [{}]
            message = choices[0].get("message") or {}
            return message.get("content") or ""
        else:
            raise Exception(f"Invalid response format: {response}")

    def _prepare_log_data(
        self, kwargs, response_obj, start_time, end_time
    ) -> Optional[ArgillaItem]:
        """Build an Argilla record from ``kwargs["standard_logging_object"]``.

        Each key of the transformation object becomes a field of the record;
        its value names the payload entry to copy ("messages" and "response"
        are normalised through the dedicated helpers).

        Raises:
            Exception: If the payload is missing, or its messages/response
                cannot be extracted.
        """
        payload: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )
        if payload is None:
            raise Exception("Error logging request payload. Payload=none.")

        argilla_message = self.get_chat_messages(payload)
        argilla_response = self.get_str_response(payload)
        argilla_item: ArgillaItem = {"fields": {}}
        for k, v in self.argilla_transformation_object.items():
            if v == "messages":
                argilla_item["fields"][k] = argilla_message
            elif v == "response":
                argilla_item["fields"][k] = argilla_response
            else:
                argilla_item["fields"][k] = payload.get(v, None)
        return argilla_item

    def _send_batch(self):
        """Synchronously POST the queued records to the bulk endpoint.

        The queue is cleared only when the server answers with a status below
        300. Errors are logged, never raised.
        """
        if not self.log_queue:
            return

        argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"]
        argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"]
        url = f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk"
        argilla_api_key = self.default_credentials["ARGILLA_API_KEY"]
        headers = {"X-Argilla-Api-Key": argilla_api_key}

        try:
            response = litellm.module_level_client.post(
                url=url,
                # Same {"items": [...]} envelope as async_send_batch.
                json={"items": self.log_queue},
                headers=headers,
            )
            if response.status_code >= 300:
                verbose_logger.error(
                    f"Argilla Error: {response.status_code} - {response.text}"
                )
            else:
                verbose_logger.debug(
                    f"Batch of {len(self.log_queue)} runs successfully created"
                )
                self.log_queue.clear()
        except Exception:
            verbose_logger.exception("Argilla Layer Error - Error sending batch.")

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Queue a record for a successful call and send the batch when full.

        Errors are logged, never raised.
        """
        try:
            if self._should_skip_for_sampling():
                return  # Skip logging

            verbose_logger.debug(
                "Argilla Sync Layer Logging - kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            if data is None:
                return

            self.log_queue.append(data)
            verbose_logger.debug(
                f"Argilla, event added to queue. Will flush in {self.flush_interval} seconds..."
            )
            if len(self.log_queue) >= self.batch_size:
                self._send_batch()
        except Exception:
            verbose_logger.exception("Argilla Layer Error - log_success_event error")

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Queue a record for a successful call, letting custom loggers filter it.

        Every ``CustomLogger`` in ``litellm.callbacks`` may rewrite the record
        through ``async_dataset_hook``; a hook returning ``None`` drops it.
        Errors are logged, never raised.
        """
        try:
            if self._should_skip_for_sampling():
                return  # Skip logging

            verbose_logger.debug(
                "Argilla Async Layer Logging - kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            payload: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object", None
            )
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)

            ## ALLOW CUSTOM LOGGERS TO MODIFY / FILTER DATA BEFORE LOGGING
            for callback in litellm.callbacks:
                if isinstance(callback, CustomLogger):
                    try:
                        if data is None:
                            break
                        data = await callback.async_dataset_hook(data, payload)
                    except NotImplementedError:
                        # Callback does not implement the hook; keep data as is.
                        pass

            if data is None:
                return

            self.log_queue.append(data)
            verbose_logger.debug(
                "Argilla logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()
        except Exception:
            verbose_logger.exception(
                "Argilla Layer Error - error logging async success event."
            )

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Queue a record for a failed call and flush when the batch is full.

        Errors, including a payload without a usable response, are logged and
        never raised.
        """
        if self._should_skip_for_sampling():
            return  # Skip logging

        verbose_logger.info("Argilla Failure Event Logging!")
        try:
            data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
            if data is None:
                return

            self.log_queue.append(data)
            verbose_logger.debug(
                "Argilla logging: queue length %s, batch size %s",
                len(self.log_queue),
                self.batch_size,
            )
            if len(self.log_queue) >= self.batch_size:
                await self.flush_queue()
        except Exception:
            verbose_logger.exception(
                "Argilla Layer Error - error logging async failure event."
            )

    async def async_send_batch(self):
        """PUT the queued records to the dataset's ``/records/bulk`` endpoint.

        Sends ``{"items": self.log_queue}`` as a JSON body. The queue is not
        cleared here.

        Returns: None

        Raises: Does not raise an exception, will only verbose_logger.exception()
        """
        if not self.log_queue:
            return

        argilla_api_base = self.default_credentials["ARGILLA_BASE_URL"]
        argilla_dataset_name = self.default_credentials["ARGILLA_DATASET_NAME"]
        url = f"{argilla_api_base}/api/v1/datasets/{argilla_dataset_name}/records/bulk"
        argilla_api_key = self.default_credentials["ARGILLA_API_KEY"]
        headers = {"X-Argilla-Api-Key": argilla_api_key}

        try:
            response = await self.async_httpx_client.put(
                url=url,
                data=json.dumps(
                    {
                        "items": self.log_queue,
                    }
                ),
                headers=headers,
                # NOTE(review): if this is forwarded to httpx it is in seconds
                # (~16.7 hours); 60000 looks like milliseconds -- confirm.
                timeout=60000,
            )
            response.raise_for_status()

            if response.status_code >= 300:
                verbose_logger.error(
                    f"Argilla Error: {response.status_code} - {response.text}"
                )
            else:
                verbose_logger.debug(
                    "Batch of %s runs successfully created", len(self.log_queue)
                )
        except httpx.HTTPStatusError:
            verbose_logger.exception("Argilla HTTP Error")
        except Exception:
            verbose_logger.exception("Argilla Layer Error")