Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,40 +1,95 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
-
#
|
| 5 |
-
|
| 6 |
-
headers = {
|
| 7 |
-
"accept": "application/json",
|
| 8 |
-
"content-type": "application/json",
|
| 9 |
-
"Authorization": "Bearer cad84f39be62a8e36fdd846152dbb18abddef0aefcd921e82e287b4c228ac3e1"
|
| 10 |
-
}
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
def main():
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
#
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
"
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
if __name__ == "__main__":
|
| 40 |
main()
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
import threading
|
| 3 |
+
import os
|
| 4 |
+
import litellm
|
| 5 |
+
from litellm import completion
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
|
| 8 |
+
# load .env, so litellm reads from .env
|
| 9 |
+
load_dotenv()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
+
litellm.token = "5fdb5efa-9307-40ed-b824-1c73a1613030"
|
| 12 |
+
|
| 13 |
+
models = []
|
| 14 |
+
provider_models_map = litellm.models_by_provider
|
| 15 |
+
for provider in provider_models_map:
|
| 16 |
+
print(provider)
|
| 17 |
+
for model in provider_models_map[provider]:
|
| 18 |
+
print(provider_models_map[provider])
|
| 19 |
+
models.append(provider+"/" + model)
|
| 20 |
+
|
| 21 |
+
# Function to get model outputs
|
| 22 |
+
def get_model_output(prompt, model_name):
|
| 23 |
+
try:
|
| 24 |
+
messages = [
|
| 25 |
+
{"role": "user", "content": prompt},
|
| 26 |
+
]
|
| 27 |
+
response = completion(messages=messages, model=model_name)
|
| 28 |
+
|
| 29 |
+
return response['choices'][0]['message']['content']
|
| 30 |
+
except Exception as e:
|
| 31 |
+
return f"got error calling LLM API {e}"
|
| 32 |
+
|
| 33 |
+
# Function to get model outputs
|
| 34 |
+
def get_model_output_thread(prompt, model_name, outputs, idx):
|
| 35 |
+
output = get_model_output(prompt, model_name)
|
| 36 |
+
outputs[idx] = output
|
| 37 |
+
|
| 38 |
+
# Streamlit app
|
| 39 |
def main():
|
| 40 |
+
keys = {}
|
| 41 |
+
st.title("LiteLLM Playground")
|
| 42 |
+
st.markdown("[LiteLLM - one package for CodeLlama, Llama2 Anthropic, Cohere, OpenAI, Replicate](https://github.com/BerriAI/litellm/)")
|
| 43 |
+
st.markdown("View Request Logs + Manage keys (Optional) [here:](https://admin.litellm.ai/5fdb5efa-9307-40ed-b824-1c73a1613030)")
|
| 44 |
+
|
| 45 |
+
# Sidebar for user input
|
| 46 |
+
with st.sidebar:
|
| 47 |
+
st.header("User Settings")
|
| 48 |
+
# List of models to test
|
| 49 |
+
model_names = models # Add your model names here
|
| 50 |
+
|
| 51 |
+
# Dropdowns for model selection
|
| 52 |
+
selected_models = []
|
| 53 |
+
for i in range(1):
|
| 54 |
+
selected_model = st.selectbox(f"Select Model {i+1}", model_names, index=i)
|
| 55 |
+
selected_models.append(selected_model)
|
| 56 |
+
|
| 57 |
+
provider = selected_model.split("/")[0]
|
| 58 |
+
key_name = f"{provider.upper()}_API_KEY"
|
| 59 |
+
api_key = st.text_input(f"Enter your {key_name}", type="password", key=i)
|
| 60 |
+
keys[key_name] = api_key
|
| 61 |
+
set_keys_button = st.button("Set API Keys")
|
| 62 |
+
|
| 63 |
+
if set_keys_button:
|
| 64 |
+
for key in keys:
|
| 65 |
+
if os.environ.get(key) != None: # if key not set in .env
|
| 66 |
+
os.environ[key] = keys[key]
|
| 67 |
+
st.success("API keys have been set.")
|
| 68 |
+
|
| 69 |
+
st.header("User Input")
|
| 70 |
+
prompt = st.text_area("Enter your prompt here:")
|
| 71 |
+
submit_button = st.button("Submit")
|
| 72 |
+
|
| 73 |
+
# Main content area to display model outputs
|
| 74 |
+
st.header("Model Outputs")
|
| 75 |
+
|
| 76 |
+
cols = st.columns(len(selected_models)) # Create columns
|
| 77 |
+
outputs = [""] * len(selected_models) # Initialize outputs list with empty strings
|
| 78 |
+
|
| 79 |
+
threads = []
|
| 80 |
+
if submit_button and prompt:
|
| 81 |
+
for idx, model_name in enumerate(selected_models):
|
| 82 |
+
thread = threading.Thread(target=get_model_output_thread, args=(prompt, model_name, outputs, idx))
|
| 83 |
+
threads.append(thread)
|
| 84 |
+
thread.start()
|
| 85 |
+
|
| 86 |
+
for thread in threads:
|
| 87 |
+
thread.join()
|
| 88 |
+
|
| 89 |
+
# Display text areas and fill with outputs if available
|
| 90 |
+
for idx, model_name in enumerate(selected_models):
|
| 91 |
+
with cols[idx]:
|
| 92 |
+
st.text_area(label=f"{model_name}", value=outputs[idx], height=300, key=f"output_{model_name}_{idx}") # Use a unique key
|
| 93 |
|
| 94 |
if __name__ == "__main__":
|
| 95 |
main()
|