mattoofahad committed
Commit: f79cf2d · Parent: ed24ac5
adding support to run lm-studio

Files changed:
- .gitignore                    +2 -0
- src/app.py                    +1 -1
- src/utils/config.py           +12 -12
- src/utils/constants.py        +6 -5
- src/utils/logs.py             +106 -98
- src/utils/openai_utils.py     +38 -28
- src/utils/streamlit_utils.py  +60 -31
.gitignore
CHANGED
@@ -158,3 +158,5 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+
+/notebooks
src/app.py
CHANGED
@@ -24,7 +24,7 @@ def main():
     if (
         st.session_state.openai_api_key is not None
         and st.session_state.openai_api_key != ""
-    ):
+    ) or st.session_state.provider_select != "OpenAI":
         logger.info("OpenAI key Checking condition passed")
         if OpenAIFunctions.check_openai_api_key():
             logger.info("Inference Started")
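The new gate lets inference proceed either when an OpenAI key is set or when a non-OpenAI provider such as lm-studio is selected. A minimal sketch of the same condition as a standalone helper (the function name is illustrative, not part of the commit):

# Illustrative helper (not in the commit): mirrors the updated condition in src/app.py.
def can_run_inference(openai_api_key, provider_select):
    """True when an OpenAI key is present, or when the provider is not OpenAI."""
    has_key = openai_api_key is not None and openai_api_key != ""
    return has_key or provider_select != "OpenAI"

assert can_run_inference(None, "lm-studio") is True   # local provider, no key needed
assert can_run_inference("sk-xxxx", "OpenAI") is True
assert can_run_inference(None, "OpenAI") is False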
src/utils/config.py
CHANGED
@@ -1,12 +1,12 @@
-"""Module doc string"""
-
-import os
-
-from dotenv import find_dotenv, load_dotenv
-
-load_dotenv(find_dotenv(), override=True)
-
-LOGGER_LEVEL = os.getenv("LOGGER_LEVEL", "INFO")
-DISCORD_HOOK = os.getenv("DISCORD_HOOK", "NO_HOOK")
-ENVIRONMENT = os.getenv("ENVIRONMENT", "NOT_LOCAL")
-OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "NO_KEY")
+"""Module doc string"""
+
+import os
+
+from dotenv import find_dotenv, load_dotenv
+
+load_dotenv(find_dotenv(), override=True)
+
+LOGGER_LEVEL = os.getenv("LOGGER_LEVEL", "INFO")
+DISCORD_HOOK = os.getenv("DISCORD_HOOK", "NO_HOOK")
+ENVIRONMENT = os.getenv("ENVIRONMENT", "NOT_LOCAL")
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "NO_KEY")
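Because load_dotenv(find_dotenv(), override=True) runs at import time, a local .env file takes precedence over variables already set in the shell, and anything missing falls back to the defaults above. A small sketch of how the module is consumed (the .env values are placeholders, and the import path is an assumption based on the src/utils layout):

# Hypothetical .env at the repo root (placeholder values):
#   LOGGER_LEVEL=DEBUG
#   OPENAI_API_KEY=sk-placeholder
#   ENVIRONMENT=LOCAL

# Import path assumed from the repo layout (src/utils/config.py).
from utils.config import ENVIRONMENT, LOGGER_LEVEL, OPENAI_API_KEY

print(LOGGER_LEVEL)                # "DEBUG" if set in .env, otherwise "INFO"
print(ENVIRONMENT)                 # "LOCAL" locally, "NOT_LOCAL" by default
print(OPENAI_API_KEY != "NO_KEY")  # True only when a key was actually provided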
src/utils/constants.py
CHANGED
@@ -12,13 +12,14 @@ class ConstantVariables:
         "gpt-4-turbo",
         "gpt-3.5-turbo",
         "o1-preview",
-        "o1-mini"
+        "o1-mini",
     )
+    provider = ("lm-studio", "OpenAI")
     default_model = "gpt-4o-mini"
-
-    max_tokens =
-    min_token =
-    step =
+    default_provider = "lm-studio"
+    max_tokens = 1024
+    min_token = 32
+    step = 32
     default = round(((max_tokens + min_token) / 2) / step) * step
     default_token = max(min_token, min(max_tokens, default))
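With the new values, the default slider position works out as follows (a quick check of the rounding and clamping logic; note that Python's round() uses round-half-to-even):

max_tokens = 1024
min_token = 32
step = 32

# Midpoint snapped to the nearest step: (1024 + 32) / 2 = 528; 528 / 32 = 16.5;
# round(16.5) == 16 under round-half-to-even, so default = 16 * 32 = 512.
default = round(((max_tokens + min_token) / 2) / step) * step
default_token = max(min_token, min(max_tokens, default))  # clamp into [32, 1024]
print(default, default_token)  # 512 512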
src/utils/logs.py
CHANGED
@@ -1,98 +1,106 @@
+"""Module doc string"""
+
+import asyncio
+import logging
+import sys
+import time
+from functools import wraps
+
+from colorama import Back, Fore, Style, init
+
+from .config import LOGGER_LEVEL
+
+# Initialize colorama
+init(autoreset=True)
+
+logger = logging.getLogger(__name__)
+
+if not logger.hasHandlers():
+    logger.propagate = False
+    logger.setLevel(LOGGER_LEVEL)
+
+    # Define color codes for different log levels
+    log_colors = {
+        logging.DEBUG: Fore.CYAN,
+        logging.INFO: Fore.GREEN,
+        logging.WARNING: Fore.YELLOW,
+        logging.ERROR: Fore.RED,
+        logging.CRITICAL: Fore.RED + Back.WHITE + Style.BRIGHT,
+    }
+
+    class ColoredFormatter(logging.Formatter):
+        """Module doc string"""
+
+        def format(self, record):
+            """Module doc string"""
+
+            levelno = record.levelno
+            color = log_colors.get(levelno, "")
+
+            # Format the message
+            message = record.getMessage()
+
+            # Format the rest of the log details
+            details = self._fmt % {
+                "asctime": self.formatTime(record, self.datefmt),
+                "levelname": record.levelname,
+                "module": record.module,
+                "funcName": record.funcName,
+                "lineno": record.lineno,
+            }
+
+            # Combine details and colored message
+            return (
+                f"{Fore.WHITE}{details} :: {color}{message}{Style.RESET_ALL}"
+            )
+
+    normal_handler = logging.StreamHandler(sys.stdout)
+    normal_handler.setLevel(logging.DEBUG)
+    normal_handler.addFilter(
+        lambda logRecord: logRecord.levelno < logging.WARNING
+    )
+
+    error_handler = logging.StreamHandler(sys.stderr)
+    error_handler.setLevel(logging.WARNING)
+
+    formatter = ColoredFormatter(
+        "%(asctime)s :: %(levelname)s :: %(module)s :: %(funcName)s :: %(lineno)d"
+    )
+
+    normal_handler.setFormatter(formatter)
+    error_handler.setFormatter(formatter)
+
+    logger.addHandler(normal_handler)
+    logger.addHandler(error_handler)
+
+
+def log_execution_time(func):
+    """Module doc string"""
+
+    @wraps(func)
+    def sync_wrapper(*args, **kwargs):
+        start_time = time.time()
+        result = func(*args, **kwargs)
+        end_time = time.time()
+        execution_time = end_time - start_time
+        message_string = (
+            f"{func.__name__} executed in {execution_time:.4f} seconds"
+        )
+        logger.debug(message_string)
+        return result
+
+    @wraps(func)
+    async def async_wrapper(*args, **kwargs):
+        start_time = time.time()
+        result = await func(*args, **kwargs)
+        end_time = time.time()
+        execution_time = end_time - start_time
+        message_string = (
+            f"{func.__name__} executed in {execution_time:.4f} seconds"
+        )
+        logger.debug(message_string)
+        return result
+
+    if asyncio.iscoroutinefunction(func):
+        return async_wrapper
+    return sync_wrapper
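The logger splits output (records below WARNING go to stdout, WARNING and above to stderr), and log_execution_time wraps both sync and async callables. A brief usage sketch, assuming the module is importable as utils.logs in your environment:

import asyncio

# Import path is an assumption; adjust to how src/utils is packaged where you run this.
from utils.logs import log_execution_time, logger

@log_execution_time
def crunch(n):
    return sum(range(n))

@log_execution_time
async def fetch():
    await asyncio.sleep(0.1)
    return "done"

logger.info("goes to stdout, colored green")   # < WARNING -> stdout handler
logger.error("goes to stderr, colored red")    # >= WARNING -> stderr handler
crunch(1_000_000)                              # logs "crunch executed in X.XXXX seconds" at DEBUG
asyncio.run(fetch())                           # coroutines are timed via the async wrapper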
src/utils/openai_utils.py
CHANGED
@@ -15,21 +15,27 @@ class OpenAIFunctions:
     @staticmethod
     def invoke_model():
         """_summary_"""
+
         logger.debug("OpenAI invoked")
         with st.chat_message("assistant"):
             messages = [
                 {"role": m["role"], "content": m["content"]}
                 for m in st.session_state.messages
             ]
+            comp_args = {}
+            if st.session_state.provider_select == "OpenAI":
+                comp_args["api_key"] = st.session_state.openai_api_key
+                comp_args["model"] = st.session_state["openai_model"]
+            elif st.session_state.provider_select == "lm-studio":
+                comp_args["base_url"] = "http://localhost:1234/v1"
+                comp_args["api_key"] = st.session_state.provider_select
+                comp_args["model"] = "gpt-4o-mini"
+            comp_args["messages"] = messages
+            comp_args["max_tokens"] = st.session_state["openai_maxtokens"]
+            comp_args["stream"] = True
+            comp_args["stream_options"] = {"include_usage": True}
 
-            stream = completion(
-                api_key=st.session_state.openai_api_key,
-                model=st.session_state["openai_model"],
-                messages=messages,
-                max_tokens=st.session_state["openai_maxtokens"],
-                stream=True,
-                stream_options={"include_usage": True},
-            )
+            stream = completion(**comp_args)
 
             def stream_data():
                 for chunk in stream:

@@ -50,24 +56,28 @@ class OpenAIFunctions:
     @staticmethod
     def check_openai_api_key():
         """_summary_"""
-
-        client = OpenAI(api_key=st.session_state.openai_api_key)
-        client.models.list()
-        logger.debug("OpenAI key Working")
-        return True
+        if st.session_state.provider_select == "lm-studio":
+            logger.info("Local Provider is Selected")
+            return True
+        else:
+            logger.info("Checking OpenAI Key")
+            try:
+                client = OpenAI(api_key=st.session_state.openai_api_key)
+                client.models.list()
+                logger.debug("OpenAI key Working")
+                return True
+            except openai.AuthenticationError as auth_error:
+                with st.chat_message("assistant"):
+                    st.error(str(auth_error))
+                logger.error("AuthenticationError: %s", auth_error)
+                return False
+            except openai.OpenAIError as openai_error:
+                with st.chat_message("assistant"):
+                    st.error(str(openai_error))
+                logger.error("OpenAIError: %s", openai_error)
+                return False
+            except Exception as general_error:
+                with st.chat_message("assistant"):
+                    st.error(str(general_error))
+                logger.error("Unexpected error: %s", general_error)
+                return False
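Outside of Streamlit, the same routing can be exercised directly: when the provider is lm-studio, completion() is pointed at LM Studio's OpenAI-compatible server on http://localhost:1234/v1 and the api_key is just the provider name. A minimal sketch, assuming litellm is installed and LM Studio is serving a model locally:

from litellm import completion

# Assumes LM Studio's local server is running on http://localhost:1234/v1.
# The model name is passed through; LM Studio answers with whatever model it has loaded.
stream = completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    api_key="lm-studio",                  # placeholder key, as in the commit
    base_url="http://localhost:1234/v1",
    max_tokens=128,
    stream=True,
    stream_options={"include_usage": True},
)

for chunk in stream:
    # With include_usage, the final chunk may carry only usage info and no choices.
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)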
src/utils/streamlit_utils.py
CHANGED
@@ -25,36 +25,53 @@ class StreamlitFunctions:
     def streamlit_side_bar():
         """_summary_"""
         with st.sidebar:
-            st.text_input(
-                label="OpenAI API key",
-                value=ConstantVariables.api_key,
-                help="This will not be saved or stored.",
-                type="password",
-                key="api_key",
-            )
-
             st.selectbox(
-                "Select the GPT model",
-                ConstantVariables.model_list_tuple,
-                key="openai_model",
-            )
-            st.slider(
-                "Max Tokens",
-                min_value=ConstantVariables.min_token,
-                max_value=ConstantVariables.max_tokens,
-                step=ConstantVariables.step,
-                key="openai_maxtokens",
-            )
-            st.button(
-                "Start Chat",
-                on_click=StreamlitFunctions.start_app,
-                use_container_width=True,
-            )
-            st.button(
-                "Reset History",
-                on_click=StreamlitFunctions.reset_history,
-                use_container_width=True,
+                "Select Provider",
+                ConstantVariables.provider,
+                placeholder="Choose an option",
+                key="provider_select",
             )
+            if st.session_state.provider_select is not None:
+                if st.session_state.provider_select == "OpenAI":
+                    st.text_input(
+                        label="OpenAI API key",
+                        value=ConstantVariables.api_key,
+                        help="This will not be saved or stored.",
+                        type="password",
+                        key="api_key",
+                    )
+
+                    st.selectbox(
+                        "Select the GPT model",
+                        ConstantVariables.model_list_tuple,
+                        key="openai_model",
+                    )
+
+                elif st.session_state.provider_select == "lm-studio":
+                    st.header("NOTE")
+                    st.text(
+                        "lm-studio is configured to work on `http://localhost:1234/v1`"
+                    )
+
+                st.slider(
+                    "Max Tokens",
+                    min_value=ConstantVariables.min_token,
+                    max_value=ConstantVariables.max_tokens,
+                    step=ConstantVariables.step,
+                    key="openai_maxtokens",
+                )
+
+                st.button(
+                    "Start Chat",
+                    on_click=StreamlitFunctions.start_app,
+                    use_container_width=True,
+                )
+
+                st.button(
+                    "Reset History",
+                    on_click=StreamlitFunctions.reset_history,
+                    use_container_width=True,
+                )
 
     @staticmethod
     def streamlit_initialize_variables():

@@ -66,15 +83,23 @@ class StreamlitFunctions:
         if "openai_model" not in st.session_state:
             st.session_state["openai_model"] = ConstantVariables.default_model
 
+        if "provider_select" not in st.session_state:
+            st.session_state["provider_select"] = None
+
         if "openai_api_key" not in st.session_state:
             st.session_state["openai_api_key"] = None
 
         if "openai_maxtokens" not in st.session_state:
-            st.session_state["openai_maxtokens"] = ConstantVariables.default_token
+            st.session_state["openai_maxtokens"] = (
+                ConstantVariables.default_token
+            )
 
         if "start_app" not in st.session_state:
             st.session_state["start_app"] = False
 
+        if "api_key" not in st.session_state:
+            st.session_state["api_key"] = None
+
     @staticmethod
     def reset_history():
         """_summary_"""

@@ -102,7 +127,11 @@ class StreamlitFunctions:
         if prompt := st.chat_input("Type your Query"):
             with st.chat_message("user"):
                 st.markdown(prompt)
-            st.session_state.messages.append({"role": "user", "content": prompt})
+            st.session_state.messages.append(
+                {"role": "user", "content": prompt}
+            )
             response = OpenAIFunctions.invoke_model()
             logger.debug(response)
-            st.session_state.messages.append({"role": "assistant", "content": response[0]})
+            st.session_state.messages.append(
+                {"role": "assistant", "content": response[0]}
+            )
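The sidebar relies on Streamlit's widget keys: giving st.selectbox a key makes the chosen value readable from st.session_state anywhere else in the app, which is how invoke_model and check_openai_api_key see the selected provider. A stripped-down sketch of that pattern (a standalone demo, not code from the repo):

import streamlit as st

# Standalone demo of the key/session_state pattern used by streamlit_side_bar().
with st.sidebar:
    st.selectbox(
        "Select Provider",
        ("lm-studio", "OpenAI"),
        placeholder="Choose an option",
        key="provider_select",  # value becomes st.session_state.provider_select
    )
    if st.session_state.provider_select == "OpenAI":
        st.text_input("OpenAI API key", type="password", key="api_key")

st.write("Current provider:", st.session_state.provider_select)

Saved to a file and launched with streamlit run, the selected value survives reruns because it lives in st.session_state rather than in a local variable.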