AmirFARES committed on
Commit
4ae9767
·
1 Parent(s): 1b078bf

Added multi-model feature

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +61 -25
src/streamlit_app.py CHANGED
@@ -7,11 +7,30 @@ os.environ["STREAMLIT_HOME"] = "/tmp"
7
  os.environ["STREAMLIT_DISABLE_LOGGING"] = "1"
8
  os.environ["STREAMLIT_TELEMETRY_ENABLED"] = "0"
9
 
10
-
11
  # --- Streamlit page config ---
12
  st.set_page_config(page_title="Chat with Datamir Hub Assistant", page_icon="💬")
13
  st.title("💬 Chat with Datamir Hub Assistant")
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # --- System prompt ---
16
  SYSTEM_PROMPT = """
17
  You are Datamir Hub Assistant, a friendly and knowledgeable AI assistant created to support data professionals, engineers, analysts, and businesses with their data and AI needs.
@@ -42,7 +61,8 @@ The standard consulting/freelance rate is $25/hour.
42
  Your goal is to ensure every user feels supported, empowered, and confident in their data and AI journey.
43
  """
44
 
45
-
 
46
 
47
  # --- Initialize session state ---
48
  if "messages" not in st.session_state:
@@ -71,10 +91,9 @@ if not st.session_state.chat_started:
71
  for i, suggestion in enumerate(suggestions):
72
  if cols[i].button(suggestion):
73
  st.session_state.pending_prompt = suggestion
74
- st.session_state.chat_started = True # <-- Fix: hide suggestions immediately
75
  st.rerun()
76
 
77
-
78
  # --- Input box ---
79
  prompt = st.chat_input("Type your message here...")
80
 
@@ -86,34 +105,51 @@ if st.session_state.pending_prompt:
86
  elif prompt:
87
  st.session_state.chat_started = True
88
 
89
- # --- Process prompt (either from suggestion or typed input) ---
90
  if prompt:
91
  st.session_state.messages.append({"role": "user", "content": prompt})
92
  with st.chat_message("user"):
93
  st.markdown(prompt)
94
 
95
- # --- Build chat prompt for model ---
96
- chat_prompt = SYSTEM_PROMPT.strip() + "\n\n"
97
- for msg in st.session_state.messages:
98
- role = "User" if msg["role"] == "user" else "Assistant"
99
- chat_prompt += f"{role}: {msg['content']}\n"
100
- chat_prompt += "Assistant:"
101
 
102
- # --- Call Hugging Face model ---
103
  try:
104
- # HF_API_TOKEN = os.getenv("hf_general_token")
105
- HF_API_TOKEN = os.getenv("bk_token")
106
- client = InferenceClient(api_key=HF_API_TOKEN, provider="hf-inference")
107
- response = client.text_generation(
108
- model="HuggingFaceH4/zephyr-7b-beta",
109
- prompt=chat_prompt,
110
- max_new_tokens=512,
111
- temperature=0.7,
112
- stop_sequences=["User:"],
113
- )
114
- model_reply = response.strip()
115
- if model_reply.endswith("User:"):
116
- model_reply = model_reply.rsplit("User:", 1)[0].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  except Exception as e:
118
  model_reply = f"❌ Failed to connect to the model: {e}"
119
  print(e)
 
7
  os.environ["STREAMLIT_DISABLE_LOGGING"] = "1"
8
  os.environ["STREAMLIT_TELEMETRY_ENABLED"] = "0"
9
 
 
10
  # --- Streamlit page config ---
11
  st.set_page_config(page_title="Chat with Datamir Hub Assistant", page_icon="💬")
12
  st.title("💬 Chat with Datamir Hub Assistant")
13
 
14
+ # --- Sidebar: Model selection ---
15
+ st.sidebar.title("⚙️ Settings")
16
+ MODEL_OPTIONS = {
17
+ "Zephyr 7B (HF)": {"model": "HuggingFaceH4/zephyr-7b-beta", "provider": "hf-inference"},
18
+ "Gemma 2B (Nebius)": {"model": "google/gemma-2-2b-it", "provider": "nebius"},
19
+ "Mistral Nemo Instruct (Nebius)": {"model": "mistralai/Mistral-Nemo-Instruct-2407", "provider": "nebius"},
20
+ "Mixtral 8x7B Instruct (Nebius)": {"model": "mistralai/Mixtral-8x7B-Instruct-v0.1", "provider": "nebius"},
21
+ "Command R+ (Cohere)": {"model": "CohereLabs/c4ai-command-r-plus", "provider": "cohere"},
22
+ "LLaMA 3 8B Instruct (Novita)": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "provider": "novita"},
23
+ "DeepSeek R1 (Nebius)": {"model": "deepseek-ai/DeepSeek-R1", "provider": "nebius"},
24
+ }
25
+
26
+ model_choice = st.sidebar.selectbox("Choose a model", list(MODEL_OPTIONS.keys()))
27
+
28
+ # --- Save model selection in session ---
29
+ if "selected_model" not in st.session_state:
30
+ st.session_state.selected_model = model_choice
31
+ else:
32
+ st.session_state.selected_model = model_choice
33
+
34
  # --- System prompt ---
35
  SYSTEM_PROMPT = """
36
  You are Datamir Hub Assistant, a friendly and knowledgeable AI assistant created to support data professionals, engineers, analysts, and businesses with their data and AI needs.
 
61
  Your goal is to ensure every user feels supported, empowered, and confident in their data and AI journey.
62
  """
63
 
64
+ # --- Display current model at top of chat ---
65
+ st.caption(f"🤖 Current Model: {st.session_state.selected_model}")
66
 
67
  # --- Initialize session state ---
68
  if "messages" not in st.session_state:
 
91
  for i, suggestion in enumerate(suggestions):
92
  if cols[i].button(suggestion):
93
  st.session_state.pending_prompt = suggestion
94
+ st.session_state.chat_started = True
95
  st.rerun()
96
 
 
97
  # --- Input box ---
98
  prompt = st.chat_input("Type your message here...")
99
 
 
105
  elif prompt:
106
  st.session_state.chat_started = True
107
 
108
+ # --- Process prompt ---
109
  if prompt:
110
  st.session_state.messages.append({"role": "user", "content": prompt})
111
  with st.chat_message("user"):
112
  st.markdown(prompt)
113
 
114
+ model_info = MODEL_OPTIONS[st.session_state.selected_model]
115
+ HF_API_TOKEN = os.getenv("bk_token")
 
 
 
 
116
 
 
117
  try:
118
+ client = InferenceClient(api_key=HF_API_TOKEN, provider=model_info["provider"])
119
+
120
+ if model_info["provider"] == "hf-inference":
121
+ # Format prompt as plain text conversation
122
+ chat_prompt = SYSTEM_PROMPT.strip() + "\n\n"
123
+ for msg in st.session_state.messages:
124
+ role = "User" if msg["role"] == "user" else "Assistant"
125
+ chat_prompt += f"{role}: {msg['content']}\n"
126
+ chat_prompt += "Assistant:"
127
+
128
+ response = client.text_generation(
129
+ model=model_info["model"],
130
+ prompt=chat_prompt,
131
+ max_new_tokens=512,
132
+ temperature=0.7,
133
+ stop_sequences=["User:"],
134
+ )
135
+ model_reply = response.strip()
136
+ if model_reply.endswith("User:"):
137
+ model_reply = model_reply.rsplit("User:", 1)[0].strip()
138
+
139
+ elif model_info["provider"] in ["nebius", "cohere", "novita"]:
140
+ # Use chat format for Nebius
141
+ response = client.chat.completions.create(
142
+ model=model_info["model"],
143
+ messages=[
144
+ {"role": "system", "content": SYSTEM_PROMPT.strip()},
145
+ *[
146
+ {"role": msg["role"], "content": msg["content"]}
147
+ for msg in st.session_state.messages
148
+ ],
149
+ ],
150
+ )
151
+ model_reply = response.choices[0].message.content.strip()
152
+
153
  except Exception as e:
154
  model_reply = f"❌ Failed to connect to the model: {e}"
155
  print(e)