Spaces:
Runtime error
Runtime error
| import json | |
| from pymongo.mongo_client import MongoClient | |
| from pymongo.server_api import ServerApi | |
| from fastapi import Request, Response | |
| from . import log_module, security, settings, chat_functions | |
| from datetime import timezone, datetime, timedelta | |
| from pydantic import BaseModel, Field, PrivateAttr | |
| from typing import Any, List, Self | |
| import tiktoken | |
| import uuid | |
# Shared MongoDB client for this module, pinned to Stable API version "1".
client: MongoClient = MongoClient(
    settings.DB_URI,
    server_api=ServerApi("1"),
)
class __DB:
    """Namespace holding the ChatDB collections used by this module."""

    # Users, keyed by ``gid`` in the queries below.
    user = client["ChatDB"]["users"]
    # Session challenge records.
    sess = client["ChatDB"]["sessions"]
    # Personality cores (``name`` -> ``prompt``).
    personalities = client["ChatDB"]["personalityCores"]
# Module-wide handle on the collection namespace.
DB: __DB = __DB()
# Fixed UTC-4 offset (a fixed offset, so no DST transitions).
tz: timezone = timezone(-timedelta(hours=4))
# Tokenizer used to estimate message token counts.
encoding = tiktoken.get_encoding("o200k_base")
def count_tokens_on_message(args: list) -> int:
    """Return the summed token count of every truthy string in *args*.

    Falsy entries (``None``, ``""``) are skipped instead of being encoded.
    """
    return sum(len(encoding.encode(text)) for text in args if text)
def get_personality_core(personality):
    """Return the stored system prompt for the personality core *personality*.

    Any ``{date}`` placeholder in the prompt is replaced with today's date
    (server-local clock) formatted as ``YYYY-MM-DD``.

    Raises:
        ValueError: if no personality core with that name exists.
    """
    found = DB.personalities.find_one({"name": personality})
    if found is None:
        # Previously this surfaced as an opaque TypeError on None["prompt"].
        raise ValueError(f"Personality core not found: {personality!r}")
    # "%d" is day-of-month; the old "%D" expanded to "%m/%d/%y".
    today = datetime.now().strftime("%Y-%m-%d")
    return found["prompt"].replace("{date}", today)
def get_all_personality_cores():
    """Return the ``name`` of every stored personality core, in cursor order."""
    names = []
    for core in DB.personalities.find({}):
        names.append(core["name"])
    return names
class Configs(BaseModel):
    """Per-user chat generation settings.

    ``assistantPrompt`` is always re-resolved from the personality-core
    collection for ``assistant`` during construction, so any value passed in
    for it is overwritten.
    """

    temperature: float = 0.5
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    useTool: bool = True
    assistant: str = "Chatsito clasico"
    # This class-level default is looked up once, when the class is defined.
    assistantPrompt: str = get_personality_core("Chatsito clasico")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The prompt always follows the selected assistant.
        resolved_prompt = get_personality_core(self.assistant)
        self.assistantPrompt = resolved_prompt
class Message(BaseModel):
    """A chat message whose ``content`` may be filled incrementally from a stream.

    ``_tokens`` starts as the token count of the initial ``content``. It then
    grows by one per consumed stream chunk. ``_tokensOutput`` counts only the
    streamed chunks.
    """

    role: str
    content: str = ""
    _tokens: int = 1
    # Declared so get_tokens() no longer raises AttributeError.
    _tokensOutput: int = 0
    _thread = None
    _running: bool = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._tokens = count_tokens_on_message([self.content])

    def consume_stream(self, stream):
        """Append the text deltas of *stream* to ``content``.

        Stops at the first chunk carrying a ``finish_reason`` or when the
        stream is exhausted. Chunks with no choices or no text (e.g. a
        role-only first chunk) are skipped rather than ending the message
        early. ``_running`` is reset even if iterating the stream raises.
        """
        self._running = True
        try:
            for chunk in stream:
                self._tokens += 1
                self._tokensOutput += 1
                if not chunk.choices:
                    continue
                choice = chunk.choices[0]
                if choice.delta.content:
                    self.content += choice.delta.content
                if choice.finish_reason:
                    break
        finally:
            self._running = False

    def get_tokens(self):
        """Return ``(total_tokens, streamed_chunk_count)`` for this message."""
        return self._tokens, self._tokensOutput
class ToolCallsInputFunction(BaseModel):
    """Name and JSON-encoded argument string of one requested tool call."""

    # Tool name used to look up the callback.
    name: str = ""
    # Accumulated JSON text; may arrive in several streamed fragments.
    arguments: str = ""
class ToolCallsInput(BaseModel):
    """One tool call requested by the model, identified by ``id``."""

    id: str
    # What to call and with which (JSON) arguments.
    function: ToolCallsInputFunction
    type: str = "function"
class ToolCallsOutput(BaseModel):
    """The result message produced for one executed tool call.

    ``_tokens`` holds the token count of ``content``, computed on construction.
    """

    role: str
    tool_call_id: str
    name: str
    content: str
    _tokens: int = 0

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        content_tokens = count_tokens_on_message([self.content])
        self._tokens = content_tokens
class MessageTool(BaseModel):
    """Assistant message carrying streamed tool calls, plus their outputs.

    Construction requires ``stream`` (the chunk iterator) and ``chunk`` (the
    first chunk, already read by the caller). Tool-call fragments are merged
    into ``tool_calls`` until a chunk carries a ``finish_reason`` or the
    stream ends. ``exec`` then runs each call and stores results in
    ``_outputs``.
    """

    role: str
    content: str = ""
    tool_calls: list[ToolCallsInput]
    _tokens: int = 0
    # Per-instance list rather than a shared mutable default.
    _outputs: list[ToolCallsOutput] = PrivateAttr(default_factory=list)

    def __init__(self, **kwargs):
        stream = kwargs.pop("stream")
        chunk = kwargs.pop("chunk")
        kwargs["tool_calls"] = []
        super().__init__(**kwargs)
        self._consume_tool_stream(stream, chunk)

    def _consume_tool_stream(self, stream, chunk) -> None:
        """Read chunks (starting with *chunk*) and merge their tool-call deltas."""
        while chunk is not None:
            self._tokens += 1
            if chunk.choices:
                choice = chunk.choices[0]
                if choice.finish_reason:
                    break
                self._merge_tool_call_delta(choice.delta.tool_calls)
            # next(..., None) ends the loop cleanly on exhaustion instead of
            # leaking StopIteration out of __init__.
            chunk = next(stream, None)

    def _merge_tool_call_delta(self, deltas) -> None:
        """Start a new tool call when a delta has an id; else extend the last one."""
        if not deltas:
            return
        delta = deltas[0]
        fn = delta.function
        if delta.id:
            self.tool_calls.append(
                ToolCallsInput(
                    id=delta.id,
                    function=ToolCallsInputFunction(
                        name=(fn.name if fn else None) or "",
                        arguments=(fn.arguments if fn else None) or "",
                    ),
                )
            )
        elif fn and fn.arguments and self.tool_calls:
            self.tool_calls[-1].function.arguments += fn.arguments

    def exec(self, gid):
        """Run every tool call via ``chat_functions.function_callbacks``.

        Each callback receives the decoded JSON arguments (``{}`` when the
        model sent none) and *gid*. Results are appended to ``_outputs`` as
        ``role="tool"`` messages.
        """
        for call in self.tool_calls:
            raw_args = call.function.arguments
            arguments = json.loads(raw_args) if raw_args else {}
            self._outputs.append(ToolCallsOutput(
                role="tool",
                tool_call_id=call.id,
                name=call.function.name,
                content=chat_functions.function_callbacks[call.function.name](arguments, gid),
            ))
class Chat(BaseModel):
    """Ordered conversation history with a running token estimate.

    ``tokens`` starts at 3. On construction it adds each message's
    ``_tokens`` plus 3. ``append`` adds only the new messages' ``_tokens``.
    """

    messages: List[Message | MessageTool | ToolCallsOutput]
    tokens: int = 3

    def __init__(self: Self, *args: list, **kwargs: dict):
        # Plain dicts are rebuilt as Message; already-built models are kept
        # as-is (unpacking a model with ** would raise TypeError).
        kwargs["messages"] = [
            msg if isinstance(msg, BaseModel) else Message(**msg)
            for msg in kwargs.pop("messages")
        ]
        super().__init__(*args, **kwargs)
        self.tokens += sum(msg._tokens + 3 for msg in self.messages)

    def append(self: Self, message: Message | MessageTool):
        """Add *message*. For a MessageTool, also add its tool outputs after it."""
        # NOTE(review): unlike __init__, no +3 per-message overhead is added
        # here; confirm whether that is intended.
        self.messages.append(message)
        self.tokens += message._tokens
        if isinstance(message, Message):
            return
        for out in message._outputs:
            self.tokens += out._tokens
            self.messages.append(out)

    @classmethod
    def new_msg(cls, role: str, stream):
        """Create a Message with *role* and fill it from *stream*."""
        message = Message(role=role)
        message.consume_stream(stream)
        return message

    @classmethod
    def new_func(cls, role: str, stream, chunk):
        """Build a MessageTool from *stream*, starting at the already-read *chunk*."""
        return MessageTool(role=role, stream=stream, chunk=chunk)
class Session(BaseModel):
    """Client session tied to a user ``gid`` and a client fingerprint.

    ``hashed`` is always recomputed in ``__init__`` as
    ``security.sha256(guid + fprint)``. ``guid`` is generated with uuid4 when
    not supplied.
    """

    gid: str
    fprint: str
    hashed: str
    guid: str
    public_key: str = ""
    # default_factory: a plain ``str(uuid.uuid4())`` default is evaluated once
    # at import and would give every session the same challenge.
    challenge: str = Field(default_factory=lambda: str(uuid.uuid4()))
    data: dict = {}
    configs: Configs | None = None

    def __init__(self, **kwargs):
        kwargs.setdefault("guid", str(uuid.uuid4()))
        kwargs["hashed"] = security.sha256(kwargs["guid"] + kwargs["fprint"])
        super().__init__(**kwargs)

    def validate_signature(self: Self, signature: str):
        """Validate *signature* against ``challenge`` using ``public_key``.

        Returns True on success; otherwise calls ``security.raise_401``.
        """
        valid = security.validate_signature(self.public_key, signature, self.challenge)
        if not valid:
            security.raise_401("Cannot validate Session signature")
        return True

    @classmethod
    def find_from_data(cls, request: Request, data: dict) -> Self:
        """Rebuild and verify a session from the token cookie plus request *data*.

        *data* must carry ``fingerprint``. On ``/getToken`` requests whose
        cookie has no public key, ``data["public_key"]`` is adopted. Failures
        call ``security.raise_401``.
        """
        cookie_data: dict = security.token_from_cookie(request)
        if "gid" not in cookie_data or "guid" not in cookie_data:
            log_module.logger().error("Cookie without session needed data")
            security.raise_401("gid or guid not in cookie")
        # FIX Vuln Code: on /getToken the client-supplied public key is
        # accepted without prior proof of possession.
        if not (public_key := cookie_data.get("public_key", None)):
            if request.scope["path"] != "/getToken":
                log_module.logger(cookie_data["gid"]).error(f"User without public key saved | {json.dumps(cookie_data)}")
                security.raise_401("the user must have a public key saved in token")
            else:
                log_module.logger(cookie_data['gid']).info("API public key set for user")
                public_key = data["public_key"]
        else:
            cls.check_challenge(data["fingerprint"], cookie_data["challenge"])
        session: Self = cls(
            gid=cookie_data["gid"],
            fprint=data["fingerprint"],
            guid=cookie_data["guid"],
            public_key=public_key,
            data=data,
            configs=Configs(**cookie_data["configs"]),
        )
        # The cookie stores the hashed fingerprint under "fprint".
        if session.hashed != cookie_data["fprint"]:
            log_module.logger(session.gid).error(f"Fingerprint didnt match | {json.dumps(cookie_data)}")
            security.raise_401("Fingerprint didnt match")
        session.add_challenge()
        return session

    def create_cookie_token(self: Self):
        """Return a JWT (via ``security.create_jwt_token``) encoding this session."""
        return security.create_jwt_token({
            "gid": self.gid,
            "guid": self.guid,
            "fprint": self.hashed,
            "public_key": self.public_key,
            "challenge": self.challenge,
            "configs": self.configs.model_dump()
        })

    def create_cookie(self: Self, response: Response):
        """Set the session JWT on *response* as the ``token`` cookie for 24 hours."""
        jwt = self.create_cookie_token()
        security.set_cookie(response, "token", jwt, {"hours": 24})

    def update_usage(self: Self, message: Message):
        # NOTE(review): User.update_usage uses its second argument as a Mongo
        # ``$inc`` amount; confirm callers pass a token count, not a Message.
        User.update_usage(self.gid, message)

    def add_challenge(self: Self):
        """Rotate ``challenge`` and persist it for ``check_challenge``.

        NOTE: currently disabled. It returns True immediately, so the
        rotation below is unreachable.
        """
        return True
        self.challenge = str(uuid.uuid4())
        DB.sess.insert_one(self.model_dump(include={"fprint", "challenge"}))

    @staticmethod
    def check_challenge(fprint: str, challenge: str):
        """Verify and consume the stored challenge for *fprint*.

        NOTE: currently disabled. It always returns True, so challenge
        verification is not enforced and the lookup below is unreachable.
        """
        return True
        found = DB.sess.find_one_and_delete({"fprint": fprint})
        if not found or found["challenge"] != challenge:
            security.raise_401("Check challenge failed")
class User(BaseModel):
    """A ChatDB user document, optionally bound to a Session and request data.

    Documents are read from and written to ``DB.user``, keyed by ``gid``.
    """

    name: str
    tokens: dict = {}
    # default_factory: a plain ``datetime.now(tz)`` default is evaluated once
    # at import, so every new user would get the process start time.
    created: datetime = Field(default_factory=lambda: datetime.now(tz))
    approved: datetime | None = None
    description: str = ""
    email: str
    gid: str
    role: str = "on hold"
    # Built per instance so the personality prompt (and its {date}) is
    # resolved when the user is created rather than once at import.
    configs: Configs = Field(default_factory=Configs)
    _session: Session | None = None
    _data: dict | None = None

    @classmethod
    def find_or_create(cls, data: dict, loginData: dict) -> Self:
        """Load the user with ``data["gid"]``, inserting *data* as a new user if absent.

        A fresh Session is attached using ``loginData["fp"]`` as fingerprint
        and ``loginData["pk"]`` as public key.
        """
        found = DB.user.find_one({"gid": data["gid"]})
        user: Self = cls(**found) if found else cls(**data)
        if not found:
            DB.user.insert_one(user.model_dump())
        user._session = Session(gid=user.gid, fprint=loginData["fp"], public_key=loginData["pk"])
        log_module.logger(user.gid).info(f"User {'logged' if found else'created'} | fp: {user._session.fprint}")
        return user

    @classmethod
    def find_from_cookie(cls, request: Request) -> Self:
        """Load the user named by the request's token cookie.

        Calls ``security.raise_307`` when the cookie lacks ``gid``/``guid`` or
        the user is not in the database.
        """
        cookie_data: dict = security.token_from_cookie(request)
        if "gid" not in cookie_data or "guid" not in cookie_data:
            log_module.logger().error("Cookie without needed data")
            security.raise_307("gid or guid not in cookie")
        found: dict = DB.user.find_one({"gid": cookie_data["gid"]})
        if not found:
            log_module.logger(cookie_data["gid"]).error("User not found on DB")
            security.raise_307("User not found on DB")
        user: Self = cls(**found)
        # NOTE(review): cookie "fprint" already holds the hashed fingerprint,
        # so this Session's ``hashed`` is a hash of that hash. Confirm that
        # create_cookie is not called on users loaded this way.
        user._session = Session(
            gid=cookie_data["gid"],
            guid=cookie_data["guid"],
            fprint=cookie_data["fprint"],
            configs=user.configs,
        )
        return user

    @classmethod
    def find_from_data(cls, request: Request, data: dict) -> Self:
        """Verify the session from cookie plus *data*, then load its user.

        Calls ``security.raise_307`` when the user is not in the database.
        """
        session: Session = Session.find_from_data(request, data)
        found: dict = DB.user.find_one({"gid": session.gid})
        if not found:
            log_module.logger(session.gid).error("User not found on DB")
            security.raise_307("User not found on DB")
        user: Self = cls(**found)
        user._session = session
        user._data = data
        return user

    def update_description(self: Self, message: str) -> None:
        """Persist *message* as this user's description and update the instance."""
        log_module.logger(self.gid).info("Description Updated")
        DB.user.update_one(
            {"gid": self.gid},
            {"$set": {"description": message}}
        )
        self.description = message

    def can_use(self: Self, activity: str):
        """Return ``security.can_use`` for this user's role and *activity*."""
        return security.can_use(self.role, activity)

    def update_user(self: Self) -> str:
        """Write every field of this user back to the DB.

        Returns the current ``configs.assistantPrompt``. The old ``-> None``
        annotation contradicted this return value.
        """
        log_module.logger(self.gid).info("User Updated")
        DB.user.update_one({"gid": self.gid}, {"$set": self.model_dump()})
        return self.configs.assistantPrompt

    @staticmethod
    def update_usage(gid: str, tokens: int):
        """Increment the per-day usage counter ``tokens.YY.MM.DD`` by *tokens*."""
        # NOTE(review): uses the server-local clock, whereas ``created`` uses
        # the module's fixed ``tz`` offset; confirm which day boundary is wanted.
        inc_field = datetime.now().strftime("tokens.%y.%m.%d")
        DB.user.update_one({"gid": gid}, {"$inc": {inc_field: tokens}})

    def create_cookie(self: Self):
        """Return a JWT encoding the attached session and this user's configs."""
        return security.create_jwt_token({
            "gid": self._session.gid,
            "guid": self._session.guid,
            "fprint": self._session.hashed,
            "public_key": self._session.public_key,
            "challenge": self._session.challenge,
            "configs": self.configs.model_dump()
        })