Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import os | |
| import time | |
| import json | |
| import re | |
| from typing import List, Literal, TypedDict | |
| from transformers import AutoTokenizer | |
| from tools.tools import toolsInfo | |
| from gradio_client import Client | |
| import constants as C | |
| import utils as U | |
| from openai import OpenAI | |
| import anthropic | |
| from groq import Groq | |
| from dotenv import load_dotenv | |
load_dotenv()

ModelType = Literal["GPT4", "CLAUDE", "LLAMA"]


class _RequiredModelConfig(TypedDict):
    client: OpenAI | Groq | anthropic.Anthropic
    model: str
    max_context: int
    tokenizer: AutoTokenizer


class ModelConfig(_RequiredModelConfig, total=False):
    # Optional: a dedicated model used when the main model emits raw tool-call text.
    tools_model: str


# Config builders are lazy: only the selected provider's client is constructed and only
# its tokenizer is downloaded. The OpenAI and Groq SDK clients raise at construction when
# no API key is available, so building every entry eagerly made startup fail whenever any
# single provider's key was missing (even for a provider that is not in use).
_MODEL_CONFIG_BUILDERS = {
    "GPT4": lambda: {
        "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
        "model": "gpt-4o-mini",
        "max_context": 128000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/gpt-4o"),
    },
    "CLAUDE": lambda: {
        "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
        "model": "claude-3-5-sonnet-20240620",
        "max_context": 128000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer"),
    },
    "LLAMA": lambda: {
        "client": Groq(api_key=os.environ.get("GROQ_API_KEY")),
        "model": "llama-3.1-70b-versatile",
        "tools_model": "llama3-groq-70b-8192-tool-use-preview",
        "max_context": 128000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/Meta-Llama-3.1-Tokenizer"),
    },
}

modelType: ModelType = os.environ.get("MODEL_TYPE") or "LLAMA"
if modelType not in _MODEL_CONFIG_BUILDERS:
    # Fail fast with an actionable message instead of a bare KeyError further down.
    raise ValueError(
        f"Unsupported MODEL_TYPE {modelType!r}; "
        f"expected one of: {', '.join(_MODEL_CONFIG_BUILDERS)}"
    )

# Only the active model's entry is materialised.
MODEL_CONFIG: dict[ModelType, ModelConfig] = {
    modelType: _MODEL_CONFIG_BUILDERS[modelType](),
}

client = MODEL_CONFIG[modelType]["client"]
MODEL = MODEL_CONFIG[modelType]["model"]
# Falls back to the main model when the provider has no dedicated tool-use model.
TOOLS_MODEL = MODEL_CONFIG[modelType].get("tools_model") or MODEL
MAX_CONTEXT = MODEL_CONFIG[modelType]["max_context"]
tokenizer = MODEL_CONFIG[modelType]["tokenizer"]
isClaudeModel = modelType == "CLAUDE"
def __countTokens(text):
    """Return the number of tokenizer tokens in ``str(text)``.

    Special tokens are excluded. Non-string inputs (e.g. a whole message list) are
    stringified first, so the count is an approximation of what the API will see.
    """
    encoded = tokenizer.encode(str(text), add_special_tokens=False)
    return len(encoded)
# Browser tab title and icon for the app.
st.set_page_config(page_title="Mini Perplexity", page_icon=C.AI_ICON)
# Heuristics for degenerate LLM output, compiled once at import.
_NEWLINE_THEN_LOWERCASE_RE = re.compile(r'\n((?!http)[a-z])')
_REPEATED_WORD_RE = re.compile(r'\b(\w+)(\s+\1){2,}\b')
_PARAGRAPH_BREAK_RE = re.compile(r'\n\n')


def __isInvalidResponse(response: str):
    """Return True if ``response`` looks like degenerate or failed LLM output.

    A response is flagged when it has:
      * more than 3 newlines directly followed by a lowercase letter (unless "http"),
      * more than one run of a word repeated 3+ times in a row,
      * more than 20 blank-line paragraph breaks, or
      * the ``C.EXCEPTION_KEYWORD`` marker that ``predict`` returns on API errors.

    Returns:
        True if any heuristic fires, otherwise False (explicitly, rather than None).
    """
    if len(_NEWLINE_THEN_LOWERCASE_RE.findall(response)) > 3:
        U.pprint("new line followed by small case char")
        return True
    if len(_REPEATED_WORD_RE.findall(response)) > 1:
        U.pprint("lot of repeating words")
        return True
    if len(_PARAGRAPH_BREAK_RE.findall(response)) > 20:
        U.pprint("lots of paragraphs")
        return True
    if C.EXCEPTION_KEYWORD in response:
        U.pprint("LLM API threw exception")
        return True
    return False
def __matchingKeywordsCount(keywords: List[str], text: str):
    """Return how many entries of ``keywords`` occur as substrings of ``text``.

    Each list entry is counted independently, so a duplicated keyword counts twice.
    """
    return sum(keyword in text for keyword in keywords)
def __getMessages():
    """Trim ``st.session_state.messages`` in place to fit ``MAX_CONTEXT`` and return it.

    The token estimate covers the system prompt plus ``str()`` of the whole message
    list, plus 100 tokens of headroom. The oldest messages are dropped first.

    If anything was trimmed, leading messages that are not plain user prompts are also
    removed. These are assistant turns, tool calls, and tool results whose matching call
    was dropped. Without this, the history could start with an orphaned tool message,
    which the chat APIs reject.

    Returns:
        The (same, possibly shortened) ``st.session_state.messages`` list.
    """
    messages = st.session_state.messages

    def getContextSize():
        currContextSize = __countTokens(C.SYSTEM_MSG) + __countTokens(messages) + 100
        U.pprint(f"{currContextSize=}")
        return currContextSize

    def isPlainUserMessage(message):
        # Claude tool results are user messages with list content; real prompts are strings.
        return message.get("role") == "user" and isinstance(message.get("content"), str)

    trimmed = False
    # Stop on an empty list: pop(0) would raise IndexError if the system prompt alone
    # exceeds the budget.
    while messages and getContextSize() > MAX_CONTEXT:
        U.pprint("Context size exceeded, removing first message")
        messages.pop(0)
        trimmed = True

    if trimmed:
        while messages and not isPlainUserMessage(messages[0]):
            U.pprint("Removing orphaned leading message")
            messages.pop(0)

    return messages
def __logLlmRequest(messagesFormatted: list):
    """Log the approximate token size of an outgoing request.

    The default ``MODEL`` name is logged, which is not necessarily the model actually
    used when ``predict`` switches to ``TOOLS_MODEL``.
    """
    U.pprint(f"contextSize={__countTokens(messagesFormatted)} | {MODEL}")
# Tool schemas offered to the LLM, in OpenAI function-schema form
# (converted to Anthropic's format by __getClaudeTools for Claude).
tools = [toolsInfo["getGoogleSearchResults"]["schema"]]
def __showToolResponse(toolResponseDisplay: dict):
    """Render a tool's display payload as an icon column next to a short message.

    Args:
        toolResponseDisplay: May contain ``"text"`` and ``"icon"``. ``C.TOOL_ICON`` is
            used when no icon is given.

    Text without backticks is shown as inline code. Text that already contains
    backticks is rendered as Markdown unchanged. A missing or empty ``"text"``
    renders only the icon.
    """
    msg = toolResponseDisplay.get("text")
    icon = toolResponseDisplay.get("icon")
    col1, col2 = st.columns([1, 20])
    with col1:
        st.image(icon or C.TOOL_ICON, width=30)
    with col2:
        # Guard: `"`" in None` would raise TypeError when the payload has no text.
        if msg:
            st.markdown(msg if "`" in msg else f"`{msg}`")
def __addToolCallToMsgs(toolCall: dict):
    """Append the assistant's tool-call turn to ``st.session_state.messages``.

    For Claude, ``toolCall`` is already a complete message dict and is stored as-is.
    Otherwise it is an OpenAI-style tool-call object (with ``id``, ``type`` and
    ``function.name``/``function.arguments``). It is flattened into an assistant
    message carrying a single ``tool_calls`` entry.
    """
    if not isClaudeModel:
        fn = toolCall.function
        toolCall = {
            "role": "assistant",
            "tool_calls": [{
                "id": toolCall.id,
                "function": {"name": fn.name, "arguments": fn.arguments},
                "type": toolCall.type,
            }],
        }
    st.session_state.messages.append(toolCall)
def __processToolCalls(toolCalls):
    """Execute OpenAI-style tool calls and record each call plus its result.

    For every call:
      * the function registered in ``toolsInfo`` is invoked with the JSON-decoded
        arguments;
      * its ``display`` payload, if any, is rendered and kept in
        ``st.session_state.toolResponseDisplay``;
      * the call and a ``tool`` message holding its ``response`` are appended to
        ``st.session_state.messages``.

    Exceptions from lookup, argument decoding or the tool itself propagate to the caller.
    """
    for call in toolCalls:
        name = call.function.name
        toolFunc = toolsInfo[name]["func"]
        argsJson = call.function.arguments
        U.pprint(f"functionName={name!r} | functionArgsStr={argsJson!r}")

        result = toolFunc(**json.loads(argsJson))
        toolOutput = result.get("response")
        display = result.get("display")
        U.pprint(f"functionResponse={toolOutput!r}")

        if display:
            __showToolResponse(display)
            st.session_state.toolResponseDisplay = display

        __addToolCallToMsgs(call)
        st.session_state.messages.append({
            "role": "tool",
            "tool_call_id": call.id,
            "name": name,
            "content": toolOutput,
        })
def __processClaudeToolCalls(toolResponse):
    """Execute the tool requested in a Claude response and record the exchange.

    Args:
        toolResponse: The response's ``content`` block list. The first ``tool_use``
            block in it is executed.

    The assistant turn (the full block list) is appended to
    ``st.session_state.messages``, followed by a user turn carrying the matching
    ``tool_result``. If the tool returns a ``display`` payload, it is rendered and
    stored in ``st.session_state.toolResponseDisplay``.

    Raises:
        ValueError: If ``toolResponse`` contains no ``tool_use`` block.
    """
    # Locate the tool call by block type rather than assuming index 1. Claude can emit
    # the tool_use block first, with no preceding text block.
    toolCall = next(
        (block for block in toolResponse if getattr(block, "type", None) == "tool_use"),
        None,
    )
    if toolCall is None:
        raise ValueError("Claude response contains no tool_use block")

    functionName = toolCall.name
    functionToCall = toolsInfo[functionName]["func"]
    functionArgs = toolCall.input
    functionResult = functionToCall(**functionArgs)
    functionResponse = functionResult.get("response")
    responseDisplay = functionResult.get("display")
    U.pprint(f"{functionResponse=}")

    if responseDisplay:
        __showToolResponse(responseDisplay)
        st.session_state.toolResponseDisplay = responseDisplay

    __addToolCallToMsgs({
        "role": "assistant",
        "content": toolResponse
    })
    st.session_state.messages.append({
        "role": "user",
        "content": [{
            "type": "tool_result",
            "tool_use_id": toolCall.id,
            "content": functionResponse,
        }],
    })
def __dedupeToolCalls(toolCalls: list):
    """Keep one tool call per function name.

    For repeated names the last call wins, while the list order follows each name's
    first appearance. Claude tool_use blocks expose ``.name``; OpenAI-style calls
    expose ``.function.name``.
    """
    def nameOf(call):
        return call.name if isClaudeModel else call.function.name

    latestByName = {nameOf(call): call for call in toolCalls}
    deduped = [*latestByName.values()]
    if len(deduped) != len(toolCalls):
        U.pprint("Deduped tool calls!")
        U.pprint(f"toolCalls={toolCalls!r} -> dedupedToolCalls={deduped!r}")
    return deduped
def __getClaudeTools():
    """Convert the module's OpenAI-style ``tools`` schemas into Anthropic's tool format.

    Each ``{"function": {name, description, parameters}}`` entry becomes
    ``{name, description, input_schema}``.
    """
    return [
        {
            "name": tool["function"]["name"],
            "description": tool["function"]["description"],
            "input_schema": tool["function"]["parameters"],
        }
        for tool in tools
    ]
def predict(model: str | None = None, *, _toolRound: int = 0):
    """Send the conversation to the configured LLM and return its final text reply.

    Tool calls requested by the model are executed, their results are appended to
    ``st.session_state.messages``, and the model is queried again. This repeats for at
    most ``maxToolRounds`` rounds per top-level call.

    Args:
        model: Model name to use; defaults to ``MODEL``.
        _toolRound: Internal counter of tool-call rounds already performed.

    Returns:
        The reply text (possibly ``""``), or ``C.EXCEPTION_KEYWORD`` when the API call
        or tool execution failed or the tool-round limit was reached. Never None.
    """
    # Cap on consecutive tool-call rounds, so a model that keeps requesting tools cannot
    # recurse (and spend API calls) indefinitely.
    maxToolRounds = 5
    model = model or MODEL
    messagesFormatted = []
    try:
        if isClaudeModel:
            messagesFormatted.extend(__getMessages())
            __logLlmRequest(messagesFormatted)
            responseMessage = client.messages.create(
                model=model,
                messages=messagesFormatted,
                system=C.SYSTEM_MSG,
                temperature=0.5,
                max_tokens=4000,
                tools=__getClaudeTools()
            )
            contentBlocks = responseMessage.content
            # Claude may reply with text and tool_use blocks in any order, or with a
            # tool_use block alone, so select blocks by type instead of by index.
            responseContent = "".join(
                block.text for block in contentBlocks if block.type == "text"
            )
            toolCalls = [block for block in contentBlocks if block.type == "tool_use"]
        else:
            messagesFormatted = [{"role": "system", "content": C.SYSTEM_MSG}]
            messagesFormatted.extend(__getMessages())
            __logLlmRequest(messagesFormatted)
            response = client.chat.completions.create(
                model=model,
                messages=messagesFormatted,
                temperature=0.7,
                max_tokens=4000,
                stream=False,
                tools=tools
            )
            responseMessage = response.choices[0].message
            responseContent = responseMessage.content
            # A raw "<function=" in the text means the model tried to call a tool in plain
            # text, so retry with the dedicated tool-use model. Skip this when already on
            # that model; otherwise the same output would recurse forever.
            if responseContent and '<function=' in responseContent and model != TOOLS_MODEL:
                U.pprint("Switching to TOOLS_MODEL")
                return predict(TOOLS_MODEL, _toolRound=_toolRound)
            toolCalls = responseMessage.tool_calls

        U.pprint(f"{responseMessage=}")
        U.pprint(f"{responseContent=}")
        U.pprint(f"{toolCalls=}")

        if not toolCalls:
            # Never return None: the caller iterates/concatenates the result.
            return responseContent or ""

        if _toolRound >= maxToolRounds:
            U.pprint(f"Tool call limit reached ({maxToolRounds} rounds)")
            return C.EXCEPTION_KEYWORD

        toolCalls = __dedupeToolCalls(toolCalls)
        U.pprint("Deduping done!")
        try:
            if isClaudeModel:
                __processClaudeToolCalls(responseMessage.content)
            else:
                __processToolCalls(toolCalls)
        except Exception as e:
            U.pprint(e)
            # Report the failure explicitly; falling through used to return None.
            return C.EXCEPTION_KEYWORD
        return predict(_toolRound=_toolRound + 1)
    except Exception as e:
        U.pprint(f"LLM API Error: {e}")
        return C.EXCEPTION_KEYWORD
def __generateImage(prompt: str):
    """Generate an image for ``prompt`` via the FLUX.1-schnell Gradio Space.

    Calls the Space's ``/infer`` endpoint at 1024x768 with 4 inference steps and a
    randomized seed, logs the raw result, and returns it unchanged. The result is
    presumably an ``(imagePath, seed)`` pair; confirm against the Space's API.
    """
    inferenceArgs = dict(
        prompt=prompt,
        seed=0,
        randomize_seed=True,
        width=1024,
        height=768,
        num_inference_steps=4,
    )
    result = Client("black-forest-labs/FLUX.1-schnell").predict(
        **inferenceArgs, api_name="/infer"
    )
    U.pprint(f"imageResult={result}")
    return result
def __resetButtonState():
    """Clear any pending option-button selection so it is not resubmitted as a prompt."""
    st.session_state["buttonValue"] = ""
# Session-scoped state, initialised on a session's first run and kept across reruns.
if "ipAddress" not in st.session_state:
    # Client IP as reported by the x-forwarded-for request header, if present.
    st.session_state.ipAddress = st.context.headers.get("x-forwarded-for")

# chatHistory: what is rendered in the UI; messages: what is sent to the LLM.
for _listKey in ("chatHistory", "messages"):
    if _listKey not in st.session_state:
        st.session_state[_listKey] = []

if "buttonValue" not in st.session_state:
    __resetButtonState()

# Tool output for the current turn only; cleared on every rerun.
st.session_state.toolResponseDisplay = {}

for _ in range(2):
    U.pprint("\n")
U.applyCommonStyles()
st.title("Mini Perplexity 💡")

# Replay every stored turn, since the page is rebuilt from scratch on each rerun.
for pastTurn in st.session_state.chatHistory:
    turnRole = pastTurn["role"]
    turnContent = pastTurn["content"]
    turnImage = pastTurn.get("image")
    turnToolDisplay = pastTurn.get("toolResponseDisplay")
    turnAvatar = C.USER_ICON if turnRole != "assistant" else C.AI_ICON

    with st.chat_message(turnRole, avatar=turnAvatar):
        st.markdown(turnContent)
        if turnToolDisplay:
            __showToolResponse(turnToolDisplay)
        if turnImage:
            st.image(turnImage)
# Maximum LLM attempts per user turn before giving up with a fallback reply.
_MAX_RESPONSE_ATTEMPTS = 3
# Shown (and stored) when every attempt produced empty or invalid output.
_FALLBACK_RESPONSE = "Sorry, I couldn't generate a response. Please try again."


def _renderAndGetResponse(responseContainer):
    """Call the LLM once, render the prose part of its reply, and return the raw reply.

    ``predict()`` is non-streaming and returns the whole reply at once. The validity
    check and rendering therefore run once on the full text, rather than once per
    character as when the string was iterated char by char.

    Returns:
        The raw reply, or None when it is empty or flagged by ``__isInvalidResponse``,
        signalling the caller to retry.
    """
    responseContainer.image(C.TEXT_LOADER)
    # predict() may return None on some failure paths; treat that like an empty reply.
    response = predict() or ""
    if not response:
        return None
    if __isInvalidResponse(response):
        U.pprint(f"InvalidResponse={response}")
        return None
    # Only the text before the JSON separator is prose; the rest encodes option buttons.
    responseContainer.markdown(response.split(C.JSON_SEPARATOR, 1)[0])
    return response


def _getResponseWithRetries(responseContainer):
    """Return a valid LLM reply, retrying at most ``_MAX_RESPONSE_ATTEMPTS`` times.

    When every attempt fails, ``_FALLBACK_RESPONSE`` is rendered and returned. A
    persistently failing API therefore cannot hang the page in an endless retry loop.
    """
    response = _renderAndGetResponse(responseContainer)
    attempt = 1
    while not response and attempt < _MAX_RESPONSE_ATTEMPTS:
        U.pprint("Empty response. Retrying..")
        time.sleep(0.7)
        response = _renderAndGetResponse(responseContainer)
        attempt += 1
    if not response:
        U.pprint("No valid response after retries; using fallback message")
        response = _FALLBACK_RESPONSE
        responseContainer.markdown(response)
    return response


def _selectButton(optionLabel):
    """Store ``optionLabel`` so the next rerun submits it as the prompt (via ``buttonValue``)."""
    st.session_state["buttonValue"] = optionLabel
    U.pprint(f"Selected: {optionLabel}")


def _renderOptionButtons(jsonStr):
    """Render one button per entry of the reply's JSON ``options`` list.

    Malformed JSON or option entries are logged and otherwise ignored, so they never
    break the rendered reply. An ``action`` payload is not handled yet.
    """
    try:
        options = json.loads(jsonStr).get("options")
        for option in options or []:
            st.button(
                option["label"],
                key=option["id"],
                on_click=_selectButton,
                args=(option["label"],),
            )
    except Exception as e:
        U.pprint(e)


# A prompt comes either from the chat box or from a previously clicked option button.
if prompt := (
    st.chat_input("Ask anything")
    or st.session_state["buttonValue"]
):
    __resetButtonState()
    with st.chat_message("user", avatar=C.USER_ICON):
        st.markdown(prompt)
    U.pprint(f"{prompt=}")
    st.session_state.chatHistory.append({"role": "user", "content": prompt})
    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("assistant", avatar=C.AI_ICON):
        responseContainer = st.empty()
        response = _getResponseWithRetries(responseContainer)
        U.pprint(f"{response=}")

        # The full reply (including any JSON payload) goes back to the LLM; only the
        # prose part is kept for display. partition() splits on the first separator,
        # so a reply containing the separator more than once no longer crashes the
        # two-way unpacking.
        rawResponse = response
        response, _, jsonStr = rawResponse.partition(C.JSON_SEPARATOR)

        # Image generation (__generateImage) is currently disabled; no image is attached.
        imagePath = None

        st.session_state.chatHistory.append({
            "role": "assistant",
            "content": response,
            "image": imagePath,
            "toolResponseDisplay": st.session_state.toolResponseDisplay,
        })
        st.session_state.messages.append({
            "role": "assistant",
            "content": rawResponse,
        })

        if jsonStr:
            _renderOptionButtons(jsonStr)
# Per-session counter bumped on every rerun. Embedding it in the HTML makes the markup
# differ between reruns; presumably this forces the probe script to re-execute each
# time (confirm).
st.session_state.counter = st.session_state.get("counter", 1) + 1

import streamlit.components.v1 as components

# Debug probe: logs the parent page's chat-input element to the browser console.
_chatInputProbe = """
<script>
console.log("===== script running =====")
const input = window.parent.document.querySelector('.stChatInput');
console.log({input});
</script>
"""
components.html(f"<p>{st.session_state.counter}</p>" + _chatInputProbe, height=0)