Amironox commited on
Commit
fb2cbf4
·
verified ·
1 Parent(s): 5a0371a

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +9 -14
src/streamlit_app.py CHANGED
@@ -11,16 +11,10 @@ endpoint_data = json.load(open(f"{working_dir}/model_info.json"))
11
  def clear_chat():
12
  st.session_state.messages = []
13
 
14
- def get_api_key():
15
- # Retrieve API key from environment variable or prompt user
16
- return os.getenv("OPENAI_API_KEY") or st.text_input("Enter your API Key", type="password")
17
-
18
- st.title("AIaaS on Intel® Gaudi® Demo")
19
-
20
- # Extract endpoint and model names from JSON data
21
- endpoint = endpoint_data['endpoint']
22
- model_names = endpoint_data['models']
23
 
 
 
24
 
25
  with st.sidebar:
26
  modelname = st.selectbox("Select a LLM model (Running on Intel® Gaudi®) ", model_names)
@@ -28,7 +22,7 @@ with st.sidebar:
28
  st.button("Start New Chat", on_click=clear_chat)
29
 
30
  # Add a text input for the API key
31
- api_key = get_api_key()
32
  if api_key:
33
  st.session_state.api_key = api_key
34
 
@@ -37,12 +31,15 @@ if "api_key" not in st.session_state or not st.session_state.api_key:
37
  st.error("Please enter your API Key in the sidebar.")
38
  else:
39
  try:
 
 
40
  api_key = st.session_state.api_key
41
  base_url = endpoint
42
  client = OpenAI(api_key=api_key, base_url=base_url)
43
 
44
- print(f"Selected Model --> {modelname}")
45
- st.write(f"**Model Info:** `{modelname}`")
 
46
 
47
  if "messages" not in st.session_state:
48
  st.session_state.messages = []
@@ -65,10 +62,8 @@ else:
65
  for m in st.session_state.messages
66
  ],
67
  max_tokens=1024,
68
- temperature=0,
69
  stream=True,
70
  )
71
-
72
  response = st.write_stream(stream)
73
  except Exception as e:
74
  st.error(f"An error occurred while generating the response: {e}")
 
11
  def clear_chat():
12
  st.session_state.messages = []
13
 
14
+ st.title("Inference as a Service Playground")
 
 
 
 
 
 
 
 
15
 
16
+ # Extract the keys (model names) from the JSON data
17
+ model_names = list(endpoint_data.keys())
18
 
19
  with st.sidebar:
20
  modelname = st.selectbox("Select a LLM model (Running on Intel® Gaudi®) ", model_names)
 
22
  st.button("Start New Chat", on_click=clear_chat)
23
 
24
  # Add a text input for the API key
25
+ api_key = st.text_input("Enter your API Key", type="password")
26
  if api_key:
27
  st.session_state.api_key = api_key
28
 
 
31
  st.error("Please enter your API Key in the sidebar.")
32
  else:
33
  try:
34
+ endpoint = endpoint_data[modelname]
35
+
36
  api_key = st.session_state.api_key
37
  base_url = endpoint
38
  client = OpenAI(api_key=api_key, base_url=base_url)
39
 
40
+ # Extract the model name
41
+ models = client.models.list()
42
+ modelname = models.data[0].id
43
 
44
  if "messages" not in st.session_state:
45
  st.session_state.messages = []
 
62
  for m in st.session_state.messages
63
  ],
64
  max_tokens=1024,
 
65
  stream=True,
66
  )
 
67
  response = st.write_stream(stream)
68
  except Exception as e:
69
  st.error(f"An error occurred while generating the response: {e}")