Spaces:
Paused
Paused
| from typing import Any, Dict, List, Literal, Optional, Union | |
| import litellm | |
| from litellm._logging import verbose_proxy_logger | |
| from litellm.caching.caching import DualCache | |
| from litellm.integrations.custom_guardrail import CustomGuardrail | |
| from litellm.proxy._types import UserAPIKeyAuth | |
| from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata | |
class myCustomGuardrail(CustomGuardrail):
    """
    Example guardrail built on ``CustomGuardrail``.

    - ``async_pre_call_hook`` masks the word ``litellm`` in string message content.
    - ``async_moderation_hook`` rejects a request whose string message content
      contains ``litellm``.
    - ``async_post_call_success_hook`` rejects a response whose content contains
      ``coffee``.

    Only message content that is a plain ``str`` is inspected; any other content
    value is left untouched.
    """

    def __init__(
        self,
        **kwargs,
    ):
        # store kwargs as optional_params
        self.optional_params = kwargs

        super().__init__(**kwargs)

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
        ],
    ) -> Optional[Union[Exception, str, dict]]:
        """
        Runs before the LLM API call
        Runs on only Input
        Use this if you want to MODIFY the input

        Every occurrence of `litellm` (in any letter case) inside a string
        message content is replaced with `********`. The messages in `data`
        are modified in place and `data` is returned.
        """
        import re

        # In this guardrail, if a user inputs `litellm` we will mask it and then send it to the LLM
        _messages = data.get("messages")
        if _messages:
            for message in _messages:
                _content = message.get("content")
                if not isinstance(_content, str):
                    continue
                # Detection was always case-insensitive, so the replacement must
                # be too; otherwise "LiteLLM" is detected but forwarded unmasked.
                # re.ASCII keeps case folding to plain ASCII letters.
                _masked = re.sub(
                    "litellm",
                    "********",
                    _content,
                    flags=re.IGNORECASE | re.ASCII,
                )
                if _masked != _content:
                    message["content"] = _masked

        verbose_proxy_logger.debug(
            "async_pre_call_hook: Message after masking %s", _messages
        )

        return data

    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "responses",
        ],
    ):
        """
        Runs in parallel to LLM API call
        Runs on only Input

        This can NOT modify the input, only used to reject or accept a call before going to LLM API

        Raises:
            ValueError: if any string message content contains `litellm`
                (case-insensitive).
        """
        # this works the same as async_pre_call_hook, but just runs in parallel as the LLM API Call
        # In this guardrail, if a user inputs `litellm` we will reject the call.
        _messages = data.get("messages")
        if _messages:
            for message in _messages:
                _content = message.get("content")
                if isinstance(_content, str) and "litellm" in _content.lower():
                    raise ValueError("Guardrail failed words - `litellm` detected")

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        """
        Runs on response from LLM API call

        It can be used to reject a response

        If a response contains the word "coffee" -> we will raise an exception
        (the match is case-sensitive). Only `litellm.ModelResponse` objects and
        their `litellm.Choices` entries are inspected.

        Raises:
            ValueError: if a choice's message content contains "coffee".
        """
        verbose_proxy_logger.debug(
            "async_post_call_success_hook response: %s", response
        )
        if not isinstance(response, litellm.ModelResponse):
            return

        for choice in response.choices:
            if not isinstance(choice, litellm.Choices):
                continue
            verbose_proxy_logger.debug(
                "async_post_call_success_hook choice: %s", choice
            )
            _content = choice.message.content
            if isinstance(_content, str) and "coffee" in _content:
                raise ValueError("Guardrail failed Coffee Detected")