Sidoineko committed on
Commit
a5a74ec
·
verified ·
1 Parent(s): 29fa204

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +178 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,180 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
 
 
2
  import streamlit as st
3
+ from dotenv import load_dotenv
4
+ from huggingface_hub import InferenceClient
5
 
6
# -----------------------------------------------------------------------------
# Environment & constants
# -----------------------------------------------------------------------------
# Load HF credentials from a local .env file (no-op when the file is absent).
load_dotenv()
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
model_id = "mistralai/Mistral-7B-Instruct-v0.3"
12
+
13
# -----------------------------------------------------------------------------
# LLM helper
# -----------------------------------------------------------------------------

def get_llm_hf_inference(model_id=model_id, max_new_tokens: int = 128, temperature: float = 0.1):
    """Return a callable that sends a prompt to the HF Inference API.

    Parameters
    ----------
    model_id : str
        Hub repository id of the model to query.
    max_new_tokens : int
        Generation-length cap forwarded to the endpoint.
    temperature : float
        Sampling temperature forwarded to the endpoint.

    Returns
    -------
    Callable[[str], str]
        ``run(prompt)`` -> generated text.
    """
    client = InferenceClient(model=model_id, token=HUGGINGFACEHUB_API_TOKEN)

    def run(prompt: str) -> str:
        # Fix: `InferenceClient.conversational` does not exist, so the old
        # try-branch always raised AttributeError, and the `client.post`
        # fallback was removed from recent huggingface_hub releases (and its
        # raw response was never a dict indexable by "generated_text").
        # `text_generation` is the supported raw-prompt completion call and
        # returns the generated string directly.
        return client.text_generation(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )

    return run
47
+
48
# -----------------------------------------------------------------------------
# Streamlit page configuration
# -----------------------------------------------------------------------------
st.set_page_config(page_title="KolaChatBot", page_icon="🤗")
st.title("KolaChatBot")
# Subtitle advertising the backing model (italic markdown).
_subtitle = f"*KolaChatBot utilise l'API Inference de Hugging Face avec le modèle **{model_id}**.*"
st.markdown(_subtitle)
56
+
57
# -----------------------------------------------------------------------------
# Session-state initialisation
# -----------------------------------------------------------------------------
# setdefault mirrors the per-key "if key not in st.session_state" guards:
# each default is written only on the first script run, never overwritten.
_defaults = {
    "avatars": {"user": "👤", "assistant": "🤗"},
    "user_text": None,
    "max_response_length": 256,
    "system_message": "You are a friendly AI conversing with a human user.",
    "starter_message": "Hello, there! How can I help you today?",
}
for _key, _value in _defaults.items():
    st.session_state.setdefault(_key, _value)
74
+
75
# -----------------------------------------------------------------------------
# Sidebar settings
# -----------------------------------------------------------------------------
with st.sidebar:
    st.header("Paramètres du système")

    # AI behaviour: system prompt and the assistant's opening line.
    st.session_state.system_message = st.text_area(
        "System Message", value=st.session_state.system_message
    )
    st.session_state.starter_message = st.text_area(
        "First AI Message", value=st.session_state.starter_message
    )

    # Generation cap forwarded to the model on each turn.
    st.session_state.max_response_length = st.number_input(
        "Max Response Length", value=st.session_state.max_response_length
    )

    # Avatar pickers, rendered side by side.
    st.markdown("*Sélection des avatars :*")
    left_col, right_col = st.columns(2)
    with left_col:
        st.session_state.avatars["assistant"] = st.selectbox(
            "Avatar IA", options=["🤗", "💬", "🤖"], index=0
        )
    with right_col:
        st.session_state.avatars["user"] = st.selectbox(
            "Avatar Utilisateur", options=["👤", "👱‍♂️", "👨🏾", "👩", "👧🏾"], index=0
        )

    # True for exactly one rerun after the user clicks the button.
    reset_history = st.button("Réinitialiser l'historique")
108
+
109
# -----------------------------------------------------------------------------
# Chat history initialisation / reset
# -----------------------------------------------------------------------------
# Seed the conversation on first run, or re-seed it when the sidebar
# reset button was pressed on this rerun.
if reset_history or "chat_history" not in st.session_state:
    opener = {"role": "assistant", "content": st.session_state.starter_message}
    st.session_state.chat_history = [opener]
116
+
117
# -----------------------------------------------------------------------------
# Core inference helper
# -----------------------------------------------------------------------------

def build_prompt(system_message: str, chat_history: list[dict], user_text: str) -> str:
    """Assemble the LLM prompt: system message, prior turns, then the new user input.

    Any history role other than "user" is rendered under the ASSISTANT tag,
    and the prompt ends with an open ASSISTANT section for the model to fill.
    """
    sections = [f"### SYSTEM:\n{system_message}\n\n"]
    for turn in chat_history:
        tag = "USER" if turn["role"] == "user" else "ASSISTANT"
        sections.append(f"### {tag}:\n{turn['content']}\n\n")
    sections.append(f"### USER:\n{user_text}\n\n### ASSISTANT:\n")
    return "".join(sections)
129
+
130
+
131
def get_response(system_message: str, chat_history: list[dict], user_text: str, max_new_tokens: int = 256):
    """Run one conversational turn against the LLM.

    Builds the prompt from *chat_history*, queries the model, and records
    both the user message and the reply in the history.

    Returns
    -------
    tuple[str, list[dict]]
        The assistant's reply and the updated history. Note the history
        list is mutated in place as well as returned.
    """
    llm = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.1)
    reply = llm(build_prompt(system_message, chat_history, user_text))

    # += on a list extends in place — same effect as two appends.
    chat_history += [
        {"role": "user", "content": user_text},
        {"role": "assistant", "content": reply},
    ]
    return reply, chat_history
144
+
145
# -----------------------------------------------------------------------------
# Streamlit chat interface
# -----------------------------------------------------------------------------
chat_interface = st.container(border=True)
with chat_interface:
    output_container = st.container()
    st.session_state.user_text = st.chat_input(placeholder="Entrez votre message ici…")

# Replay the stored conversation, skipping any system entries.
with output_container:
    for entry in st.session_state.chat_history:
        role = entry["role"]
        if role == "system":
            continue
        with st.chat_message(role, avatar=st.session_state.avatars[role]):
            st.markdown(entry["content"])

    # A non-empty chat_input value means the user just submitted a message.
    if st.session_state.user_text:
        # Echo the user's message immediately, before the model answers.
        with st.chat_message("user", avatar=st.session_state.avatars["user"]):
            st.markdown(st.session_state.user_text)

        # Generate the reply inside its own assistant bubble.
        with st.chat_message(
            "assistant", avatar=st.session_state.avatars["assistant"]
        ):
            with st.spinner("KolaChatBot réfléchit…"):
                response_text, st.session_state.chat_history = get_response(
                    system_message=st.session_state.system_message,
                    user_text=st.session_state.user_text,
                    chat_history=st.session_state.chat_history,
                    max_new_tokens=st.session_state.max_response_length,
                )
                st.markdown(response_text)