| | import os |
| | from literalai import AsyncLiteralClient |
| | from datetime import datetime, timedelta, timezone |
| | from modules.config.constants import COOLDOWN_TIME, TOKENS_LEFT, REGEN_TIME |
| | from typing_extensions import TypedDict |
| | import tiktoken |
| | from typing import Any, Generic, List, Literal, Optional, TypeVar, Union |
| |
|
# Type variables for generic type shapes. `Field` parameterises Filter and
# OrderBy; NOTE(review): `Operators` and `Value` are not referenced in the
# visible code — confirm they are still needed.
Field, Operators, Value = TypeVar("Field"), TypeVar("Operators"), TypeVar("Value")
| |
|
# Operator vocabularies, grouped (per their names) by the kind of value they
# compare against.
BOOLEAN_OPERATORS = Literal["is", "nis"]
STRING_OPERATORS = Literal["eq", "neq", "ilike", "nilike"]
NUMBER_OPERATORS = Literal["eq", "neq", "gt", "gte", "lt", "lte"]
STRING_LIST_OPERATORS = Literal["in", "nin"]
DATETIME_OPERATORS = Literal["gte", "lte", "gt", "lt"]

# Every operator accepted by at least one of the groups above.
OPERATORS = Union[
    BOOLEAN_OPERATORS, STRING_OPERATORS, NUMBER_OPERATORS, STRING_LIST_OPERATORS, DATETIME_OPERATORS
]
| |
|
| |
|
class Filter(Generic[Field], TypedDict, total=False):
    """A single filter clause: compare ``field`` using ``operator`` against ``value``.

    Declared with ``total=False``, so every key is optional. ``path`` is an
    optional string; NOTE(review): presumably a sub-path into a nested field
    such as ``metadata`` — confirm against the Literal AI API.
    """

    field: Field  # column being filtered (narrowed by the Field type argument)
    operator: OPERATORS  # any operator from the OPERATORS union
    value: Any  # right-hand side of the comparison; type depends on the field
    path: Optional[str]
| |
|
| |
|
class OrderBy(Generic[Field], TypedDict):
    """Sort specification: order results by ``column`` in ``direction``."""

    column: Field  # field to sort on (narrowed by the Field type argument)
    direction: Literal["ASC", "DESC"]  # ascending or descending
| |
|
| |
|
# Field names usable when filtering / ordering threads.
threads_filterable_fields = Literal[
    "id", "createdAt", "name", "stepType", "stepName", "stepOutput", "metadata",
    "tokenCount", "tags", "participantId", "participantIdentifiers", "scoreValue",
    "duration",
]
threads_orderable_fields = Literal["createdAt", "tokenCount"]

# Concrete filter-list and order-by shapes for thread queries.
threads_filters = List[Filter[threads_filterable_fields]]
threads_order_by = OrderBy[threads_orderable_fields]
| |
|
# Field names usable when filtering / ordering steps.
steps_filterable_fields = Literal[
    "id", "name", "input", "output", "participantIdentifier", "startTime",
    "endTime", "metadata", "parentId", "threadId", "error", "tags",
]
steps_orderable_fields = Literal["createdAt"]

# Concrete filter-list and order-by shapes for step queries.
steps_filters = List[Filter[steps_filterable_fields]]
steps_order_by = OrderBy[steps_orderable_fields]
| |
|
# Field names usable when filtering users (no ordering alias is defined here).
users_filterable_fields = Literal[
    "id", "createdAt", "identifier", "lastEngaged", "threadCount", "tokenCount", "metadata",
]
users_filters = List[Filter[users_filterable_fields]]
| |
|
# Field names usable when filtering / ordering scores.
scores_filterable_fields = Literal[
    "id", "createdAt", "participant", "name", "tags", "value", "type", "comment",
]
scores_orderable_fields = Literal["createdAt"]

# Concrete filter-list and order-by shapes for score queries.
scores_filters = List[Filter[scores_filterable_fields]]
scores_order_by = OrderBy[scores_orderable_fields]
| |
|
# Field names usable when filtering / ordering generations.
generation_filterable_fields = Literal[
    "id", "createdAt", "model", "duration", "promptLineage", "promptVersion",
    "tags", "score", "participant", "tokenCount", "error",
]
generation_orderable_fields = Literal[
    "createdAt", "tokenCount", "model", "provider", "participant", "duration",
]

# Concrete filter-list and order-by shapes for generation queries.
generations_filters = List[Filter[generation_filterable_fields]]
generations_order_by = OrderBy[generation_orderable_fields]
| |
|
# Module-level async Literal AI client. The API key is read from the
# LITERAL_API_KEY_LOGGING environment variable (None when unset).
literal_client = AsyncLiteralClient(api_key=os.environ.get("LITERAL_API_KEY_LOGGING"))
| |
|
| |
|
| | |
def convert_to_dict(user_info):
    """Return *user_info* as a mapping where possible.

    Dicts (including dict subclasses) are returned as-is. Other objects that
    have an instance ``__dict__`` are replaced by that ``__dict__`` itself —
    not a copy — so mutating the result mutates the original object.
    Anything else is returned unchanged.
    """
    if not isinstance(user_info, dict) and hasattr(user_info, "__dict__"):
        return user_info.__dict__
    return user_info
| |
|
| |
|
def get_time():
    """Return the current time as a timezone-aware UTC ISO-8601 string."""
    now_utc = datetime.now(tz=timezone.utc)
    return now_utc.isoformat()
| |
|
| |
|
async def get_user_details(user_email_id):
    """Fetch the Literal AI user with identifier *user_email_id*, creating it if needed.

    Returns whatever ``literal_client.api.get_or_create_user`` returns.
    """
    return await literal_client.api.get_or_create_user(identifier=user_email_id)
| |
|
| |
|
async def update_user_info(user_info):
    """Persist a user's ``id``, ``identifier`` and ``metadata`` to Literal AI.

    *user_info* may be a dict or an object with ``__dict__``. It is normalised
    with ``convert_to_dict`` and must then contain the ``"id"``,
    ``"identifier"`` and ``"metadata"`` keys (a missing key raises KeyError).
    """
    record = convert_to_dict(user_info)
    payload = {
        "id": record["id"],
        "identifier": record["identifier"],
        "metadata": record["metadata"],
    }
    await literal_client.api.update_user(**payload)
| |
|
| |
|
async def check_user_cooldown(user_info, current_time):
    """Report whether a user is currently in a usage cooldown.

    Args:
        user_info: Literal AI user, as a dict or an object with ``__dict__``.
            Its ``metadata`` mapping is read for ``tokens_left``,
            ``in_cooldown`` and ``last_message_time`` (ISO-8601 string).
        current_time: ISO-8601 timestamp representing "now" (e.g. the output
            of ``get_time()``). Naive values are interpreted as UTC.

    Returns:
        ``(True, cooldown_end_iso)`` while fewer than ``COOLDOWN_TIME`` seconds
        have passed since ``last_message_time``; otherwise ``(False, None)``.
        Also ``(False, None)`` when no ``last_message_time`` is recorded.

    Side effects:
        When an expired cooldown is detected, ``metadata["in_cooldown"]`` is
        set to ``False`` and ``reset_tokens_for_user`` is awaited, which
        regenerates tokens and persists the user.
    """

    def _as_utc(timestamp):
        # Interpret naive timestamps as UTC, but *convert* aware ones:
        # replace(tzinfo=...) would keep the wall-clock time and silently
        # shift the instant for any non-UTC offset.
        parsed = datetime.fromisoformat(timestamp)
        if parsed.tzinfo is None:
            return parsed.replace(tzinfo=timezone.utc)
        return parsed.astimezone(timezone.utc)

    # Normalise up front so dict inputs work on every path (the fast path
    # previously used attribute access and failed for dicts).
    user_info = convert_to_dict(user_info)
    metadata = user_info["metadata"]

    # Fast path: the user still has tokens and is not flagged as cooling down.
    if metadata.get("tokens_left", 0) > 0 and not metadata.get("in_cooldown", False):
        return False, None

    last_message_time_str = metadata.get("last_message_time")
    if last_message_time_str is None:
        # No previous message recorded, so there is no cooldown to measure
        # (previously this crashed in fromisoformat(None)).
        return False, None

    last_message_time = _as_utc(last_message_time_str)
    now = _as_utc(current_time)
    elapsed_time_in_seconds = (now - last_message_time).total_seconds()

    cooldown_end_time_iso = (
        last_message_time + timedelta(seconds=COOLDOWN_TIME)
    ).isoformat()
    print(f"Cooldown end time (ISO): {cooldown_end_time_iso}")

    if elapsed_time_in_seconds < COOLDOWN_TIME:
        return True, cooldown_end_time_iso

    # Cooldown is over: clear the flag. reset_tokens_for_user refills the
    # token budget and persists the metadata, including this flag.
    metadata["in_cooldown"] = False
    await reset_tokens_for_user(user_info)
    return False, None
| |
|
| |
|
async def reset_tokens_for_user(user_info):
    """Regenerate a user's token balance from the time since their last message.

    Tokens refill linearly at ``max_tokens / REGEN_TIME`` per second. Refilling
    starts from ``metadata["tokens_left_at_last_message"]`` and never exceeds
    the user's ``max_tokens`` (default ``TOKENS_LEFT``). The result is written
    to ``metadata["tokens_left"]`` and the user is persisted via
    ``update_user_info``. If no ``last_message_time`` is recorded, the user is
    left untouched.

    Args:
        user_info: Literal AI user, as a dict or an object with ``__dict__``,
            carrying a ``metadata`` mapping plus the ``id``/``identifier``
            fields that ``update_user_info`` sends.
    """

    def _as_utc(timestamp):
        # Interpret naive timestamps as UTC; convert (not relabel) aware ones.
        parsed = datetime.fromisoformat(timestamp)
        if parsed.tzinfo is None:
            return parsed.replace(tzinfo=timezone.utc)
        return parsed.astimezone(timezone.utc)

    user_info = convert_to_dict(user_info)
    metadata = user_info["metadata"]

    last_message_time_str = metadata.get("last_message_time")
    if last_message_time_str is None:
        # No reference point to regenerate from; avoid fromisoformat(None).
        return

    last_message_time = _as_utc(last_message_time_str)
    current_time = datetime.now(timezone.utc)
    # Clamp at zero so a last_message_time in the future (clock skew) can
    # never *remove* tokens through a negative regeneration amount.
    elapsed_time_in_seconds = max(
        0.0, (current_time - last_message_time).total_seconds()
    )

    max_tokens = metadata.get("max_tokens", TOKENS_LEFT)
    # Cap against the user's own limit rather than the global TOKENS_LEFT, so
    # a user whose max_tokens exceeds TOKENS_LEFT is not truncated first.
    current_tokens = min(metadata.get("tokens_left_at_last_message", 0), max_tokens)

    if current_tokens < max_tokens:
        regeneration_rate_per_second = max_tokens / REGEN_TIME
        tokens_to_regenerate = int(
            elapsed_time_in_seconds * regeneration_rate_per_second
        )
        new_token_count = min(current_tokens + tokens_to_regenerate, max_tokens)

        print(
            f"\n\n Adding {tokens_to_regenerate} tokens to the user, Time elapsed: {elapsed_time_in_seconds} seconds, Tokens after regeneration: {new_token_count}, Tokens before: {current_tokens} \n\n"
        )

        metadata["tokens_left"] = new_token_count

    await update_user_info(user_info)
| |
|
| |
|
async def get_thread_step_info(thread_id):
    """Fetch one step from Literal AI via ``literal_client.api.get_step``.

    NOTE(review): despite the parameter name, the value is passed straight to
    ``get_step`` — confirm callers supply a step id rather than a thread id.
    """
    return await literal_client.api.get_step(thread_id)
| |
|
| |
|
def get_num_tokens(text, model):
    """Count the tiktoken tokens in *text* using the encoding for *model*.

    ``disallowed_special=()`` makes special-token strings that happen to occur
    in *text* (e.g. ``"<|endoftext|>"``) count as ordinary text. Under
    tiktoken's default (``disallowed_special="all"``), ``encode`` raises
    ``ValueError`` on such input, so a single stray string in arbitrary
    content would crash token counting.

    Raises:
        KeyError: if tiktoken cannot map *model* to an encoding.
    """
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(text, disallowed_special=()))
| |
|