Spaces:
Paused
Paused
| #### What this does #### | |
| # On success + failure, log events to lunary.ai | |
| import importlib | |
| import traceback | |
| from datetime import datetime, timezone | |
| import packaging | |
| # convert to {completion: xx, prompt: xx} | |
def parse_usage(usage):
    """Reduce a usage record to ``{"completion": ..., "prompt": ...}``.

    Each value is read from ``completion_tokens`` / ``prompt_tokens`` and
    falls back to 0 when the key is not present in ``usage``.
    """
    # (output key, source key) pairs, in the order the result is built.
    token_fields = (
        ("completion", "completion_tokens"),
        ("prompt", "prompt_tokens"),
    )
    return {
        label: (usage[source_key] if source_key in usage else 0)
        for label, source_key in token_fields
    }
def parse_tool_calls(tool_calls):
    """Serialize tool calls into plain dictionaries.

    Each tool call may be an attribute-style object (``tool_call.type``) or a
    plain dict (``tool_call["type"]``); both shapes are accepted.

    Args:
        tool_calls: Iterable of tool calls, or None.

    Returns:
        None when ``tool_calls`` is None, otherwise a list of dicts shaped as
        ``{"type", "id", "function": {"name", "arguments"}}``.

    Raises:
        AttributeError: if an attribute-style tool call lacks a required field.
    """
    if tool_calls is None:
        return None

    def read_field(container, key):
        # Dict inputs are looked up by key (missing keys become None);
        # anything else is treated as an object with attributes.
        if isinstance(container, dict):
            return container.get(key)
        return getattr(container, key)

    def clean_tool_call(tool_call):
        function = read_field(tool_call, "function")
        return {
            "type": read_field(tool_call, "type"),
            "id": read_field(tool_call, "id"),
            "function": {
                "name": read_field(function, "name"),
                "arguments": read_field(function, "arguments"),
            },
        }

    return [clean_tool_call(tool_call) for tool_call in tool_calls]
def parse_messages(input):
    """Normalize one message or a list of messages for logging.

    Strings are returned unchanged. Mapping-like messages are reduced to
    ``role`` and ``content`` (plus ``tool_calls`` when set); a wrapper that
    contains a ``"message"`` key is unwrapped first. A list holding exactly
    one item yields that single cleaned item rather than a list.

    Returns None when ``input`` is None.
    """
    if input is None:
        return None

    def clean_message(message):
        # Peel off any {"message": ...} wrappers until a string or a bare
        # message is reached.
        while not isinstance(message, str) and "message" in message:
            message = message["message"]

        # Strings are passed through as they are.
        if isinstance(message, str):
            return message

        cleaned = {
            "role": message.get("role"),
            "content": message.get("content"),
        }

        # tool_calls is only included when it is set.
        raw_tool_calls = message.get("tool_calls")
        if raw_tool_calls:
            cleaned["tool_calls"] = parse_tool_calls(raw_tool_calls)

        return cleaned

    if not isinstance(input, list):
        return clean_message(input)

    if len(input) == 1:
        return clean_message(input[0])

    return [clean_message(item) for item in input]
class LunaryLogger:
    """Forwards LLM call events to lunary.ai through the ``lunary`` package."""

    def __init__(self):
        """Import ``lunary`` and check that the installed version is supported.

        Raises:
            ImportError: if ``lunary`` cannot be imported (or its package
                metadata cannot be found), or if its version is below 0.1.43.
        """
        # Import the submodules explicitly: a bare ``import importlib`` or
        # ``import packaging`` does not guarantee that ``importlib.metadata``
        # and ``packaging.version`` are reachable as attributes.
        from importlib import metadata as importlib_metadata

        from packaging.version import Version

        try:
            import lunary

            installed_version = importlib_metadata.version("lunary")
        except ImportError as err:
            message = (
                "Lunary not installed. Please install it using 'pip install lunary'"
            )
            print(message)  # noqa
            raise ImportError(message) from err

        # Checked outside the ``try`` so that an outdated install is not also
        # reported as "not installed".
        if Version(installed_version) < Version("0.1.43"):
            message = "Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'"
            print(message)  # noqa
            raise ImportError(message)

        self.lunary_client = lunary

    @staticmethod
    def _serializable_param(name, value):
        """Return ``value`` unchanged if it is a primitive or the ``tools`` param, else ``str(value)``."""
        if name == "tools" or isinstance(value, (str, int, bool, float)):
            return value
        try:
            return str(value)
        except Exception:
            # Keep the raw value if it cannot be stringified.
            return value

    def log_event(
        self,
        kwargs,
        type,
        event,
        run_id,
        model,
        print_verbose,
        extra=None,
        input=None,
        user_id=None,
        response_obj=None,
        start_time=None,
        end_time=None,
        error=None,
    ):
        """Send a "start" event followed by an ``event`` event for one run.

        Logging is best-effort: any exception raised while building or sending
        the events is reported through ``print_verbose`` and swallowed.

        Args:
            kwargs: Call kwargs; ``litellm_params`` (for ``metadata``) and
                ``optional_params`` are read from it.
            type: Event type, forwarded to ``track_event``.
            event: Name of the second event, forwarded to ``track_event``.
            run_id: Run identifier shared by both events.
            model: Model name; used as the run name.
            print_verbose: Callable taking a message string.
            extra: Optional dict of extra params; it is not mutated.
            input: Messages sent with the request (see ``parse_messages``).
            user_id: Optional user identifier.
            response_obj: Optional response; ``usage`` and ``choices`` are read
                when present.
            start_time: Start datetime; defaults to the current UTC time.
            end_time: End datetime; defaults to the current UTC time.
            error: Optional error; sent as ``{"stack": error}``.
        """
        try:
            print_verbose(f"Lunary Logging - Logging request for model {model}")

            # Resolve missing timestamps per call; a ``datetime.now()`` default
            # in the signature would be frozen at definition time.
            now = datetime.now(timezone.utc)
            if start_time is None:
                start_time = now
            if end_time is None:
                end_time = now

            litellm_params = kwargs.get("litellm_params") or {}
            optional_params = kwargs.get("optional_params") or {}
            metadata = litellm_params.get("metadata") or {}
            tags = metadata.get("tags", None)

            # Work on a fresh dict so the caller's ``extra`` is never mutated.
            params = {**(extra or {}), **optional_params}
            params.pop("extra_body", None)
            params.pop("user", None)
            # ``extra_headers`` may be present but None.
            extra_headers = params.pop("extra_headers", None) or {}
            template_id = extra_headers.get("Template-Id", None)

            # keep only serializable types
            params = {
                name: self._serializable_param(name, value)
                for name, value in params.items()
            }

            usage = None
            output = None
            if response_obj:
                if "usage" in response_obj:
                    usage = parse_usage(response_obj["usage"])
                if "choices" in response_obj:
                    output = response_obj["choices"]

            error_obj = {"stack": error} if error else None

            self.lunary_client.track_event(  # type: ignore
                type,
                "start",
                run_id,
                parent_run_id=metadata.get("parent_run_id", None),
                user_id=user_id,
                name=model,
                input=parse_messages(input),
                timestamp=start_time.astimezone(timezone.utc).isoformat(),
                template_id=template_id,
                metadata=metadata,
                runtime="litellm",
                tags=tags,
                params=params,
            )

            self.lunary_client.track_event(  # type: ignore
                type,
                event,
                run_id,
                timestamp=end_time.astimezone(timezone.utc).isoformat(),
                runtime="litellm",
                error=error_obj,
                output=parse_messages(output),
                token_usage=usage,
            )
        except Exception:
            # Deliberately best-effort: a logging failure must not break the caller.
            print_verbose(f"Lunary Logging Error - {traceback.format_exc()}")