Sidoineko commited on
Commit
2fa5206
·
verified ·
1 Parent(s): f75865a

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +88 -145
src/streamlit_app.py CHANGED
@@ -8,28 +8,54 @@ from huggingface_hub import InferenceClient
8
  # -----------------------------------------------------------------------------
9
  load_dotenv()
10
  HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
11
-
12
- if not HUGGINGFACEHUB_API_TOKEN:
13
- st.error("Le token HUGGINGFACEHUB_API_TOKEN est introuvable. Vérifiez votre fichier .env.")
14
- st.stop()
15
-
16
  model_id = "mistralai/Mistral-7B-Instruct-v0.3"
17
 
18
  # -----------------------------------------------------------------------------
19
- # LLM helper - Modifié pour utiliser la tâche 'conversational'
20
- # La fonction get_llm_hf_inference n'est plus nécessaire dans sa forme originale.
21
- # On instanciera le client directement dans get_response.
22
  # -----------------------------------------------------------------------------
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  # -----------------------------------------------------------------------------
25
  # Streamlit page configuration
26
  # -----------------------------------------------------------------------------
27
  st.set_page_config(page_title="KolaChatBot", page_icon="🤗")
28
  st.title("KolaChatBot")
29
- st.markdown(f"*KolaChatBot utilise l'API Inference de Hugging Face avec le modèle **{model_id}**.*")
 
 
30
 
31
  # -----------------------------------------------------------------------------
32
- # Sessionstate initialisation
33
  # -----------------------------------------------------------------------------
34
  if "avatars" not in st.session_state:
35
  st.session_state.avatars = {"user": "👤", "assistant": "🤗"}
@@ -53,12 +79,8 @@ with st.sidebar:
53
  st.header("Paramètres du système")
54
 
55
  # AI Settings
56
- # NOTE: Le message système pourrait ne pas être pris en charge directement par l'API 'conversational'.
57
- # Si le modèle en a besoin, il faudrait potentiellement l'inclure dans le premier message utilisateur
58
- # ou configurer l'endpoint différemment. Pour l'instant, on le garde comme réglage mais il n'est pas
59
- # passé directement dans l'appel API 'conversational'.
60
  st.session_state.system_message = st.text_area(
61
- "System Message (Non supporté directement par l'API 'conversational')", value=st.session_state.system_message
62
  )
63
  st.session_state.starter_message = st.text_area(
64
  "First AI Message", value=st.session_state.starter_message
@@ -66,7 +88,7 @@ with st.sidebar:
66
 
67
  # Model Settings
68
  st.session_state.max_response_length = st.number_input(
69
- "Max Response Length", value=st.session_state.max_response_length, min_value=1
70
  )
71
 
72
  # Avatar Selection
@@ -93,145 +115,66 @@ if "chat_history" not in st.session_state or reset_history:
93
  ]
94
 
95
  # -----------------------------------------------------------------------------
96
- # Core inference helper (Modifié pour la tâche 'conversational')
97
  # -----------------------------------------------------------------------------
98
 
99
- # La fonction build_prompt n'est plus utilisée avec la tâche 'conversational'
100
- # def build_prompt(...): pass
 
 
 
 
 
 
101
 
102
- def get_response(chat_history: list[dict], max_new_tokens: int = 256, temperature: float = 0.1):
103
- """
104
- Génère une réponse en utilisant la tâche 'conversational' de l'API Inference.
105
- Construit les inputs attendus par cette tâche à partir de l'historique.
106
- """
107
- # Instancier le client InferenceClient pour cet appel
108
- client = InferenceClient(model=model_id, token=HUGGINGFACEHUB_API_TOKEN)
109
 
110
- # Préparer l'historique pour le format de l'API 'conversational'
111
- # L'API attend un dictionnaire inputs avec :
112
- # {"text": "message utilisateur courant",
113
- # "past_user_inputs": ["ancien message user 1", "ancien message user 2", ...],
114
- # "generated_responses": ["ancienne réponse IA 1", "ancienne réponse IA 2", ...]}
115
-
116
- api_past_user_inputs = []
117
- api_generated_responses = []
118
-
119
- # Parcourir l'historique de chat_history, en excluant le dernier message
120
- # (qui est le message utilisateur courant) et les messages système.
121
- # On suppose que chat_history est dans l'ordre chronologique
122
- # [msg1, msg2, ..., dernier_message_utilisateur].
123
- # Donc l'historique "passé" pour l'API est tout sauf le dernier élément.
124
- history_for_api = [msg for msg in chat_history[:-1] if msg["role"] != "system"]
125
-
126
- # Construire les listes appariées pour l'API.
127
- # On suppose que history_for_api contient des messages alternés "user" et "assistant".
128
- temp_user_inputs = []
129
- temp_generated_responses = []
130
-
131
- for msg in history_for_api:
132
- if msg["role"] == "user":
133
- temp_user_inputs.append(msg["content"])
134
- elif msg["role"] == "assistant":
135
- temp_generated_responses.append(msg["content"])
136
-
137
- # L'API 'conversational' exige que past_user_inputs et generated_responses
138
- # aient la même longueur, représentant des tours de conversation complétés
139
- # (utilisateur -> assistant). On tronque si nécessaire (ne devrait pas l'être
140
- # si l'historique est bien géré).
141
- min_len = min(len(temp_user_inputs), len(temp_generated_responses))
142
- api_past_user_inputs = temp_user_inputs[:min_len]
143
- api_generated_responses = temp_generated_responses[:min_len]
144
-
145
-
146
- # Le message utilisateur courant est le contenu du dernier message dans chat_history
147
- current_user_input_api = chat_history[-1]["content"]
148
-
149
- # Appeler la tâche 'conversational' en utilisant l'instance client
150
- try:
151
- response = client.conversational( # Utiliser client.conversational comme indiqué par le ValueError
152
- inputs={
153
- "text": current_user_input_api, # Message utilisateur courant
154
- "past_user_inputs": api_past_user_inputs, # Liste des anciens messages utilisateur
155
- "generated_responses": api_generated_responses, # Liste des anciennes réponses de l'IA
156
- },
157
- parameters={
158
- "max_new_tokens": max_new_tokens,
159
- "temperature": temperature,
160
- # Ajouter d'autres paramètres supportés par l'API pour la tâche conversational si besoin
161
- # (ex: repetition_penalty, do_sample, top_k, top_p etc.)
162
- # Consulter la documentation de l'API Inference Hugging Face pour le modèle/tâche spécifique.
163
- },
164
- )
165
- # La réponse de la tâche 'conversational' est un objet avec l'attribut generated_text
166
- response_text = response.generated_text
167
- except Exception as e:
168
- # Afficher l'erreur et retourner un message d'erreur dans le chat
169
- print(f"Error during conversational API call: {e}")
170
- response_text = f"Une erreur est survenue lors de la génération de la réponse : {e}"
171
- # Retourne un message d'erreur pour que le tour de conversation soit visuellement complété.
172
 
173
- # Retourner le texte généré (ou le message d'erreur)
174
- # La boucle principale Streamlit gérera l'ajout de cette réponse à st.session_state.chat_history
175
- return response_text
176
 
 
 
 
 
 
177
 
178
  # -----------------------------------------------------------------------------
179
- # Streamlit chat interface (Boucle principale)
180
  # -----------------------------------------------------------------------------
181
  chat_interface = st.container(border=True)
182
  with chat_interface:
183
- output_container = st.container() # Conteneur où les messages sont affichés
184
-
185
- # Affichage des messages de l'historique
186
- # Cette boucle s'exécute à chaque redémarrage du script Streamlit.
187
- with output_container:
188
- for message in st.session_state.chat_history:
189
- if message["role"] == "system":
190
- continue # Ne pas afficher les messages de rôle 'system'
191
- # Utiliser .get pour éviter une erreur si l'avatar n'était pas trouvé (sécurité)
192
- with st.chat_message(
193
- message["role"], avatar=st.session_state.avatars.get(message["role"], "❓")
194
- ):
195
- st.markdown(message["content"])
196
-
197
- # Champ de saisie pour l'utilisateur
198
- # Lorsque l'utilisateur entre du texte, st.session_state.user_text est mis à jour
199
- # et le script Streamlit redémarre depuis le début.
200
  st.session_state.user_text = st.chat_input(placeholder="Entrez votre message ici…")
201
 
202
- # Ce bloc s'exécute si le script a redémarré parce que st.session_state.user_text a été mis à jour.
203
- if st.session_state.user_text:
204
- # 1. Ajouter le nouveau message utilisateur à l'historique de l'état de session.
205
- # Cet ajout rend le message utilisateur visible lors du prochain redémarrage du script
206
- # (qui arrivera après que la réponse de l'IA soit générée et ajoutée).
207
- st.session_state.chat_history.append({"role": "user", "content": st.session_state.user_text})
208
-
209
- # 2. Afficher l'indicateur de chargement pour la réponse de l'IA.
210
- # Le message utilisateur vient d'être ajouté à st.session_state.chat_history,
211
- # il est donc inclus lorsque get_response analyse l'historique.
212
- with st.chat_message(
213
- "assistant", avatar=st.session_state.avatars["assistant"]
214
- ):
215
- with st.spinner("KolaChatBot réfléchit…"):
216
- # Appeler get_response pour générer le texte de l'IA.
217
- # get_response utilise maintenant correctement la tâche 'conversational'
218
- # et construit le dictionnaire d'entrée pour l'API à partir de l'historique mis à jour.
219
- response_text = get_response(
220
- chat_history=st.session_state.chat_history, # Passer l'historique *complet* (incluant le dernier message user)
221
- max_new_tokens=st.session_state.max_response_length,
222
- temperature=0.1, # Ou récupérer depuis la sidebar si réglable
223
- )
224
- # 3. Ajouter la réponse générée par l'IA à l'historique de l'état de session.
225
- # Cet ajout rend le message de l'IA visible lors du prochain et dernier redémarrage.
226
- st.session_state.chat_history.append({"role": "assistant", "content": response_text})
227
-
228
- # 4. Afficher la réponse de l'IA immédiatement pour une meilleure expérience utilisateur.
229
- st.markdown(response_text)
230
-
231
- # 5. Nettoyer l'état de la zone de saisie après que le message ait été traité.
232
- # Cela empêche le message d'être traité à nouveau lors des redémarrages ultérieurs
233
- # (ex: interactions dans la sidebar) et vide le champ de saisie.
234
- st.session_state.user_text = None
235
-
236
- # Streamlit redémarrera automatiquement le script lorsque st.session_state.chat_history change
237
- # ou lorsque l'utilisateur entre du texte, mettant à jour l'affichage dans output_container.
 
8
  # -----------------------------------------------------------------------------
9
  load_dotenv()
10
  HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
 
 
 
 
 
11
  model_id = "mistralai/Mistral-7B-Instruct-v0.3"
12
 
13
  # -----------------------------------------------------------------------------
14
+ # LLM helper
 
 
15
  # -----------------------------------------------------------------------------
16
 
17
+ def get_llm_hf_inference(model_id=model_id, max_new_tokens: int = 128, temperature: float = 0.1):
18
+ """Return an InferenceClient wrapper for Hugging Face inference."""
19
+ client = InferenceClient(model=model_id, token=HUGGINGFACEHUB_API_TOKEN)
20
+
21
+ def run(prompt: str) -> str:
22
+ try:
23
+ # For future versions with .conversational method
24
+ response = client.conversational(
25
+ inputs=prompt,
26
+ parameters={
27
+ "max_new_tokens": max_new_tokens,
28
+ "temperature": temperature,
29
+ },
30
+ )
31
+ return response.generated_text
32
+ except AttributeError:
33
+ # Fallback for older huggingface_hub clients
34
+ response = client.post(
35
+ json={
36
+ "inputs": prompt,
37
+ "parameters": {
38
+ "max_new_tokens": max_new_tokens,
39
+ "temperature": temperature,
40
+ },
41
+ },
42
+ task="conversational"
43
+ )
44
+ return response["generated_text"]
45
+
46
+ return run
47
+
48
  # -----------------------------------------------------------------------------
49
  # Streamlit page configuration
50
  # -----------------------------------------------------------------------------
51
  st.set_page_config(page_title="KolaChatBot", page_icon="🤗")
52
  st.title("KolaChatBot")
53
+ st.markdown(
54
+ f"*KolaChatBot utilise l'API Inference de Hugging Face avec le modèle **{model_id}**.*"
55
+ )
56
 
57
  # -----------------------------------------------------------------------------
58
+ # Session ‐state initialisation
59
  # -----------------------------------------------------------------------------
60
  if "avatars" not in st.session_state:
61
  st.session_state.avatars = {"user": "👤", "assistant": "🤗"}
 
79
  st.header("Paramètres du système")
80
 
81
  # AI Settings
 
 
 
 
82
  st.session_state.system_message = st.text_area(
83
+ "System Message", value=st.session_state.system_message
84
  )
85
  st.session_state.starter_message = st.text_area(
86
  "First AI Message", value=st.session_state.starter_message
 
88
 
89
  # Model Settings
90
  st.session_state.max_response_length = st.number_input(
91
+ "Max Response Length", value=st.session_state.max_response_length
92
  )
93
 
94
  # Avatar Selection
 
115
  ]
116
 
117
  # -----------------------------------------------------------------------------
118
+ # Core inference helper
119
  # -----------------------------------------------------------------------------
120
 
121
+ def build_prompt(system_message: str, chat_history: list[dict], user_text: str) -> str:
122
+ """Format the conversation as a prompt for the LLM."""
123
+ prompt = f"### SYSTEM:\n{system_message}\n\n"
124
+ for msg in chat_history:
125
+ role_tag = "USER" if msg["role"] == "user" else "ASSISTANT"
126
+ prompt += f"### {role_tag}:\n{msg['content']}\n\n"
127
+ prompt += f"### USER:\n{user_text}\n\n### ASSISTANT:\n"
128
+ return prompt
129
 
 
 
 
 
 
 
 
130
 
131
+ def get_response(system_message: str, chat_history: list[dict], user_text: str, max_new_tokens: int = 256):
132
+ """Generate a response and update chat history."""
133
+
134
+ prompt = build_prompt(system_message, chat_history, user_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
+ llm = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.1)
137
+ response_text = llm(prompt)
 
138
 
139
+ # Update history
140
+ chat_history.append({"role": "user", "content": user_text})
141
+ chat_history.append({"role": "assistant", "content": response_text})
142
+
143
+ return response_text, chat_history
144
 
145
  # -----------------------------------------------------------------------------
146
+ # Streamlit chat interface
147
  # -----------------------------------------------------------------------------
148
  chat_interface = st.container(border=True)
149
  with chat_interface:
150
+ output_container = st.container()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  st.session_state.user_text = st.chat_input(placeholder="Entrez votre message ici…")
152
 
153
+ # Display chat messages
154
+ with output_container:
155
+ for message in st.session_state.chat_history:
156
+ if message["role"] == "system":
157
+ continue # Skip system messages
158
+ with st.chat_message(
159
+ message["role"], avatar=st.session_state.avatars[message["role"]]
160
+ ):
161
+ st.markdown(message["content"])
162
+
163
+ # Handle new user message
164
+ if st.session_state.user_text:
165
+ # Show the user message immediately
166
+ with st.chat_message("user", avatar=st.session_state.avatars["user"]):
167
+ st.markdown(st.session_state.user_text)
168
+
169
+ # Generate and display assistant response
170
+ with st.chat_message(
171
+ "assistant", avatar=st.session_state.avatars["assistant"]
172
+ ):
173
+ with st.spinner("KolaChatBot réfléchit…"):
174
+ response_text, st.session_state.chat_history = get_response(
175
+ system_message=st.session_state.system_message,
176
+ user_text=st.session_state.user_text,
177
+ chat_history=st.session_state.chat_history,
178
+ max_new_tokens=st.session_state.max_response_length,
179
+ )
180
+ st.markdown(response_text)