| import json |
|
|
| from pymongo.mongo_client import MongoClient |
| from pymongo.server_api import ServerApi |
| from fastapi import Request, Response |
| from . import log_module, security, settings, chat_functions |
| from datetime import timezone, datetime, timedelta |
| from pydantic import BaseModel, Field, PrivateAttr |
| from typing import Any, List, Self |
| import tiktoken |
| import uuid |
|
|
# Shared MongoDB client (Stable API version '1') plus handles to the ChatDB
# collections this module reads and writes.
client: MongoClient = MongoClient(settings.DB_URI, server_api=ServerApi('1'))


class __DB:
    """Namespace holding the ChatDB collections used in this module."""

    user = client["ChatDB"]["users"]
    sess = client["ChatDB"]["sessions"]
    personalities = client["ChatDB"]["personalityCores"]


DB: __DB = __DB()
|
|
# Fixed UTC-04:00 offset (used e.g. for User.created timestamps).
# NOTE(review): a fixed offset does not follow daylight-saving changes; if a
# named zone is intended, zoneinfo.ZoneInfo would be needed — confirm.
tz: timezone = timezone(-timedelta(hours=4))
# Tokenizer for the configured chat model; used by count_tokens_on_message.
encoding = tiktoken.encoding_for_model(model_name=settings.GPT_MODEL)
|
|
def count_tokens_on_message(args: list) -> int:
    """Return the total token count of the strings in *args*.

    Each entry is tokenized with the module-level ``encoding``; falsy entries
    (``None``, ``""``) are skipped and contribute nothing.
    """
    return sum(len(encoding.encode(text)) for text in args if text)
|
|
def get_personality_core(personality):
    """Return the prompt of the personality core named *personality*.

    Every ``{date}`` placeholder in the stored prompt is replaced with the
    current local date formatted as ``YYYY-MM-DD``.

    Raises:
        LookupError: if no personality core with that name exists.
    """
    found = DB.personalities.find_one({"name": personality})
    if found is None:
        raise LookupError(f"Personality core not found: {personality!r}")
    # %d is day-of-month; %D would expand to "%m/%d/%y" and yield e.g.
    # "2024-05-05/14/24" instead of "2024-05-14".
    today = datetime.now().strftime("%Y-%m-%d")
    return found["prompt"].replace("{date}", today)
|
|
def get_all_personality_cores():
    """Return the ``name`` of every personality core, in cursor order."""
    # Only `name` is needed; projecting avoids transferring each core's full
    # prompt text. Documents without `name` still raise KeyError, as before.
    cursor = DB.personalities.find({}, {"_id": 0, "name": 1})
    return [doc["name"] for doc in cursor]
| |
class Configs(BaseModel):
    """Per-user chat settings.

    ``assistantPrompt`` is derived, not user-supplied: ``__init__`` always
    overwrites it with the prompt of the personality core named by
    ``assistant`` (via ``get_personality_core``).
    """

    temperature: float = 0.5
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    useTool: bool = True
    assistant: str = "Chatsito clasico"
    # Placeholder only: __init__ always replaces it, so no DB lookup is
    # needed (or performed) when the class is defined at import time.
    assistantPrompt: str = ""

    def __init__(self, *args, **kwargs):
        super(Configs, self).__init__(*args, **kwargs)
        # Re-resolve on every construction so the prompt follows `assistant`
        # and has the current date substituted for "{date}".
        self.assistantPrompt = get_personality_core(self.assistant)
| |
| |
| |
| |
class Message(BaseModel):
    """A plain chat message with a running token estimate.

    ``_tokens`` starts as the token count of ``content`` and grows by one per
    chunk read in ``consume_stream``; ``_tokensOutput`` counts only those
    streamed chunks.
    """

    role: str
    content: str = ""
    _tokens: int = 1
    _tokensOutput: int = 0  # chunks received via consume_stream
    _thread = None
    _running: bool = False

    def __init__(self, *args, **kwargs):
        super(Message, self).__init__(*args, **kwargs)
        self._tokens = count_tokens_on_message([self.content])

    def consume_stream(self, stream):
        """Append streamed content deltas to ``content``.

        Stops at the first chunk whose finish_reason is "stop" or whose delta
        carries no content. ``_running`` is True only while consuming.
        """
        self._running = True
        try:
            for chunk in stream:
                self._tokens += 1
                self._tokensOutput += 1
                choice = chunk.choices[0]
                if choice.finish_reason == "stop" or not choice.delta.content:
                    return
                self.content += choice.delta.content
        finally:
            # Reset even on the early return / an exception mid-stream; the
            # original left _running stuck at True on the normal stop path.
            self._running = False

    def get_tokens(self):
        """Return ``(_tokens, _tokensOutput)``: the running estimate and the streamed-chunk count."""
        return self._tokens, self._tokensOutput
| |
class ToolCallsInputFunction(BaseModel):
    """Function name and JSON-encoded argument string of a requested tool call."""

    name: str = Field(default="")
    arguments: str = Field(default="")


class ToolCallsInput(BaseModel):
    """A single tool call requested by the model, identified by ``id``."""

    id: str
    function: ToolCallsInputFunction
    type: str = Field(default="function")
|
|
class ToolCallsOutput(BaseModel):
    """Result of one executed tool call, kept in the chat history.

    ``_tokens`` holds the token count of ``content``, computed on construction.
    """

    role: str
    tool_call_id: str
    name: str
    content: str
    _tokens: int = 0

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._tokens = count_tokens_on_message([self.content])
| |
|
|
class MessageTool(BaseModel):
    """Assistant turn that requests tool calls, assembled from a completion stream.

    The constructor drains the stream into ``tool_calls``; ``exec`` then runs
    each requested tool and keeps the results in ``_outputs`` (which
    ``Chat.append`` adds to the history).
    """

    role: str
    content: str = ""
    tool_calls: list[ToolCallsInput]
    _tokens: int = 0  # one per stream chunk consumed
    _outputs: list[ToolCallsOutput] = PrivateAttr(default_factory=list)

    def __init__(self, **kwargs):
        """Build the message by consuming a tool-call stream.

        Keyword Args:
            stream: iterator of streamed completion chunks.
            chunk: first chunk, already read from ``stream`` by the caller.
            **kwargs: remaining model fields (e.g. ``role``).
        """
        stream = kwargs.pop("stream")
        chunk = kwargs.pop("chunk")
        kwargs["tool_calls"] = []
        super(MessageTool, self).__init__(**kwargs)

        while chunk is not None:
            choice = chunk.choices[0]
            self._tokens += 1

            if choice.finish_reason:
                break

            if choice.delta.tool_calls:
                self._absorb_tool_call_delta(choice.delta.tool_calls[0])

            # Default of None: an exhausted stream ends the loop instead of
            # letting StopIteration escape the constructor.
            chunk = next(stream, None)

    def _absorb_tool_call_delta(self, tool_call) -> None:
        """Fold one streamed tool-call delta into ``tool_calls``.

        A delta with an ``id`` opens a new call; subsequent deltas carry only
        argument fragments, which are appended to the most recent call.
        """
        function = tool_call.function
        if tool_call.id:
            # exclude_none: if the SDK leaves name/arguments as None, the ""
            # defaults apply instead of failing str validation.
            fields = function.model_dump(exclude_none=True) if function is not None else {}
            self.tool_calls.append(
                ToolCallsInput(id=tool_call.id, function=ToolCallsInputFunction(**fields))
            )
        elif function is not None and function.arguments and self.tool_calls:
            self.tool_calls[-1].function.arguments += function.arguments

    def exec(self, gid):
        """Run each requested tool and append its result to ``_outputs``.

        Args:
            gid: value forwarded to every tool callback.

        Raises:
            KeyError: a tool name missing from ``chat_functions.function_callbacks``.
            json.JSONDecodeError: a call whose arguments are not valid JSON.
        """
        for call in self.tool_calls:
            callback = chat_functions.function_callbacks[call.function.name]
            # Parameterless tools may stream no argument text; treat "" as {}.
            arguments = json.loads(call.function.arguments or "{}")
            self._outputs.append(ToolCallsOutput(
                role="tool",
                tool_call_id=call.id,
                name=call.function.name,
                content=callback(arguments, gid),
            ))
|
|
|
|
class Chat(BaseModel):
    """Conversation history plus a running token estimate.

    ``tokens`` starts at 3 and, on construction, adds each message's
    ``_tokens`` plus 3.
    NOTE(review): ``append`` does not add the extra 3 per message, so the
    estimate is computed differently for appended messages — confirm intent.
    """

    messages: List[Message|MessageTool|ToolCallsOutput]
    tokens: int = 3

    def __init__(self: Self, *args: list, **kwargs: dict):
        # Incoming messages are raw mappings; each becomes a plain Message.
        kwargs["messages"] = [Message(**raw) for raw in kwargs.pop("messages")]
        super(Chat, self).__init__(*args, **kwargs)
        self.tokens += sum(msg._tokens + 3 for msg in self.messages)

    def append(self: Self, message: Message|MessageTool):
        """Add *message* (and, for tool turns, its tool outputs) to the history."""
        self.messages.append(message)
        self.tokens += message._tokens
        if isinstance(message, Message):
            return
        # Anything else is treated as a MessageTool: its executed outputs
        # follow it in the history.
        for output in message._outputs:
            self.tokens += output._tokens
            self.messages.append(output)

    @classmethod
    def new_msg(cls: Self, role: str, stream):
        """Create a Message with *role* and fill it from *stream*."""
        msg = Message(role=role)
        msg.consume_stream(stream)
        return msg

    @classmethod
    def new_func(cls: Self, role: str, stream, chunk):
        """Create a MessageTool from *stream*, starting at the already-read *chunk*."""
        tool_message = MessageTool(role=role, stream=stream, chunk=chunk)
        return tool_message
|
|
class Session(BaseModel):
    """Per-client session bound to a browser fingerprint and public key.

    ``hashed`` is always computed in ``__init__`` as
    ``security.sha256(guid + fprint)``; ``guid`` is generated when absent.
    """

    gid: str
    fprint: str
    hashed: str
    guid: str
    public_key: str = ""
    # NOTE(review): evaluated once at class definition, so every Session not
    # given an explicit challenge shares this one value for the process
    # lifetime. find_from_data does not pass a challenge, so validate_signature
    # appears to rely on that shared value while add_challenge/check_challenge
    # are disabled — revisit before re-enabling challenge rotation.
    challenge: str = str(uuid.uuid4())
    data: dict = {}
    configs: Configs | None = None

    def __init__(self, **kwargs):
        kwargs["guid"] = kwargs.get("guid", str(uuid.uuid4()))
        kwargs["hashed"] = security.sha256(kwargs["guid"] + kwargs["fprint"])
        super(Session, self).__init__(**kwargs)

    def validate_signature(self: Self, signature: str):
        """Check *signature* over ``challenge`` with ``public_key``; 401 on failure."""
        valid = security.validate_signature(self.public_key, signature, self.challenge)
        if not valid:
            security.raise_401("Cannot validate Session signature")
        return True

    @classmethod
    def find_from_data(cls: Self, request: Request, data: dict) -> Self:
        """Rebuild a session from the request's token cookie and request *data*.

        *data* is expected to carry ``fingerprint`` (and ``public_key`` on the
        ``/getToken`` path when the cookie has none). Fails with 401 when the
        cookie lacks gid/guid, lacks a public key outside ``/getToken``, or
        its stored fingerprint hash does not match.
        """
        cookie_data: dict = security.token_from_cookie(request)

        if "gid" not in cookie_data or "guid" not in cookie_data:
            log_module.logger().error("Cookie without session needed data")
            security.raise_401("gid or guid not in cookie")

        if not (public_key := cookie_data.get("public_key", None)):
            # Only /getToken may register a public key from the request body.
            if request.scope["path"] != "/getToken":
                log_module.logger(cookie_data["gid"]).error(f"User without public key saved | {json.dumps(cookie_data)}")
                security.raise_401("the user must have a public key saved in token")
            else:
                log_module.logger(cookie_data['gid']).info("API public key set for user")
                public_key = data["public_key"]
        else:
            cls.check_challenge(data["fingerprint"], cookie_data["challenge"])

        session: Self = cls(
            gid=cookie_data["gid"],
            fprint=data["fingerprint"],
            guid=cookie_data["guid"],
            public_key=public_key,
            data=data,
            configs=Configs(**cookie_data["configs"]),
        )

        # The cookie stores sha256(guid + fingerprint); recomputing it from the
        # request's fingerprint detects a cookie replayed from another client.
        if session.hashed != cookie_data["fprint"]:
            log_module.logger(session.gid).error(f"Fingerprint didnt match | {json.dumps(cookie_data)}")
            security.raise_401("Fingerprint didnt match")

        session.add_challenge()

        return session

    def create_cookie_token(self: Self):
        """Encode this session's identity, key, challenge and configs via security.create_jwt_token."""
        return security.create_jwt_token({
            "gid": self.gid,
            "guid": self.guid,
            "fprint": self.hashed,
            "public_key": self.public_key,
            "challenge": self.challenge,
            # configs defaults to None (e.g. sessions built in User.find_or_create);
            # an empty dict lets find_from_data rebuild default Configs instead
            # of raising AttributeError here.
            "configs": self.configs.model_dump() if self.configs is not None else {},
        })

    def create_cookie(self: Self, response: Response):
        """Set the session token as the ``token`` cookie on *response* for 24 hours."""
        jwt = self.create_cookie_token()
        security.set_cookie(response, "token", jwt, {"hours": 24})

    def update_usage(self: Self, message: Message):
        # NOTE(review): User.update_usage feeds this value to a Mongo $inc, so
        # callers presumably pass a token count, not a Message — confirm.
        User.update_usage(self.gid, message)

    def add_challenge(self: Self):
        # NOTE(review): challenge rotation is disabled by this early return;
        # the two lines below are unreachable.
        return True
        self.challenge = str(uuid.uuid4())
        DB.sess.insert_one(self.model_dump(include={"fprint", "challenge"}))

    @staticmethod
    def check_challenge(fprint: str, challenge: str):
        # NOTE(review): challenge verification is disabled by this early
        # return; the lookup below is unreachable.
        return True
        found = DB.sess.find_one_and_delete({"fprint": fprint})
        if not found or found["challenge"] != challenge:
            security.raise_401("Check challenge failed")
| |
| |
class User(BaseModel):
    """Chat user stored in the ``users`` collection, keyed by ``gid``.

    ``_session`` and ``_data`` are private, request-scoped state set by the
    ``find_*`` constructors; as private attributes they are not part of
    ``model_dump()`` and so are not written to the database.
    """

    name: str
    tokens: dict = {}  # usage counters nested by year/month/day (see update_usage)
    # default_factory: a plain `datetime.now(tz)` default is evaluated once at
    # import, stamping every new user with the process start time.
    created: datetime = Field(default_factory=lambda: datetime.now(tz))
    approved: datetime | None = None
    description: str = ""
    email: str
    gid: str
    role: str = "on hold"
    # default_factory: build fresh Configs (with a freshly dated prompt) per
    # user rather than copying one instance created at import time.
    configs: Configs = Field(default_factory=Configs)

    _session: Session | None = None
    _data: dict | None = None

    @classmethod
    def find_or_create(cls: Self, data: dict, loginData: dict) -> Self:
        """Load the user with ``data["gid"]``, inserting *data* as a new user if absent.

        A new Session is attached using ``loginData["fp"]`` (fingerprint) and
        ``loginData["pk"]`` (public key).
        """
        found = DB.user.find_one({"gid": data["gid"]})

        user: Self = cls(**found) if found else cls(**data)

        if not found:
            DB.user.insert_one(user.model_dump())

        user._session = Session(gid=user.gid, fprint=loginData["fp"], public_key=loginData["pk"])

        log_module.logger(user.gid).info(f"User {'logged' if found else'created'} | fp: {user._session.fprint}")
        return user

    @classmethod
    def find_from_cookie(cls: Self, request: Request) -> Self:
        """Load the user named by the token cookie; redirect (307) if missing."""
        cookie_data: dict = security.token_from_cookie(request)

        if "gid" not in cookie_data or "guid" not in cookie_data:
            log_module.logger().error("Cookie without needed data")
            security.raise_307("gid or guid not in cookie")

        found: dict = DB.user.find_one({"gid": cookie_data["gid"]})

        if not found:
            log_module.logger(cookie_data["gid"]).error("User not found on DB")
            security.raise_307("User not found on DB")

        user: Self = cls(**found)
        # NOTE(review): the cookie's "fprint" is the stored hash, so this
        # Session's own `hashed` is sha256(guid + hash) — confirm that is intended.
        user._session = Session(
            gid=cookie_data["gid"],
            guid=cookie_data["guid"],
            fprint=cookie_data["fprint"],
            configs=user.configs,
        )
        return user

    @classmethod
    def find_from_data(cls: Self, request: Request, data: dict) -> Self:
        """Load the user for a validated Session built from the cookie and *data*."""
        session: Session = Session.find_from_data(request, data)

        found: dict = DB.user.find_one({"gid": session.gid})

        if not found:
            log_module.logger(session.gid).error("User not found on DB")
            security.raise_307("User not found on DB")

        user: Self = cls(**found)
        user._session = session
        user._data = data
        return user

    def update_description(self: Self, message: str) -> None:
        """Persist *message* as the user's description and update this instance."""
        log_module.logger(self.gid).info("Description Updated")
        DB.user.update_one(
            {"gid": self.gid},
            {"$set": {"description": message}}
        )
        self.description = message

    def can_use(self: Self, activity: str):
        """Return whether this user's role permits *activity* (per security.can_use)."""
        return security.can_use(self.role, activity)

    def update_user(self: Self) -> str:
        """Overwrite the stored user with this instance's fields; return the assistant prompt."""
        log_module.logger(self.gid).info("User Updated")
        DB.user.update_one({"gid": self.gid}, {"$set": self.model_dump()})
        return self.configs.assistantPrompt

    @staticmethod
    def update_usage(gid: str, tokens: int):
        """Add *tokens* to the user's counter for today (``tokens.<yy>.<mm>.<dd>``)."""
        # NOTE(review): day buckets use naive server-local time, while
        # `created` uses `tz` — confirm which clock usage should follow.
        inc_field = datetime.now().strftime("tokens.%y.%m.%d")
        DB.user.update_one({"gid": gid}, {"$inc": {inc_field: tokens}})

    def create_cookie(self: Self):
        """Encode the attached session plus this user's configs via security.create_jwt_token."""
        return security.create_jwt_token({
            "gid": self._session.gid,
            "guid": self._session.guid,
            "fprint": self._session.hashed,
            "public_key": self._session.public_key,
            "challenge": self._session.challenge,
            "configs": self.configs.model_dump()
        })
| |
|
|
| |