diff --git "a/src/streamlit_app.py" "b/src/streamlit_app.py" --- "a/src/streamlit_app.py" +++ "b/src/streamlit_app.py" @@ -1,11 +1,12 @@ import os import streamlit as st import requests from dotenv import load_dotenv -import json -from datetime import datetime -from PIL import Image -from io import BytesIO +import json # Pour l'export JSON +from datetime import datetime # Pour nommer les fichiers d'export +import time # Pour les messages d'attente si besoin (non strictement utilisé ici mais utile) +from PIL import Image # Pour potentiellement ouvrir/manipuler l'image si nécessaire (basé sur pillow) +from io import BytesIO # Pour lire les données binaires de l'image # Import Google Generative AI library import google.generativeai as genai @@ -15,106 +16,91 @@ import google.generativeai as genai # ----------------------------------------------------------------------------- load_dotenv() HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN") -GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") +GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") # Add Google API Key # API URL base for Hugging Face Inference API HF_BASE_API_URL = "https://api-inference.huggingface.co/models/" -# --- Define Constants for Clarity and Maintainability --- -TASK_TEXT = 'text' -TASK_IMAGE = 'image' # Renamed from 't2i' for consistency with UI radio button - -PROVIDER_HF = 'huggingface' -PROVIDER_GOOGLE = 'google' - -ROLE_USER = 'user' -ROLE_ASSISTANT = 'assistant' -ROLE_SYSTEM = 'system' # Explicitly define system role - -TYPE_TEXT = 'text' -TYPE_T2I = 't2i' -TYPE_T2I_PROMPT = 't2i_prompt' # User input for image generation - -# Keys for message dictionary -ROLE_KEY = 'role' -CONTENT_KEY = 'content' -TYPE_KEY = 'type' -PROMPT_KEY = 'prompt' -IMAGE_DATA_KEY = 'image_data' - # ----------------------------------------------------------------------------- # Define Available Models and Default Parameters per Task 
Type # ----------------------------------------------------------------------------- +# Structure: {'id': 'model_id', 'name': 'Display Name', 'provider': 'huggingface'/'google', 'type': 'text'/'t2i', 'params': {...} } AVAILABLE_MODELS = [ - { - "id": "mistralai/Mistral-7B-Instruct-v0.3", - "name": "Mistral 7B Instruct v0.3 (Hugging Face)", - "provider": PROVIDER_HF, - "type": TASK_TEXT, - "params": { # Default parameters for text models - "max_new_tokens": 200, - "temperature": 0.6, - "top_p": 0.9, - }, + { + "id": "mistralai/Mistral-7B-Instruct-v0.3", + "name": "Mistral 7B Instruct v0.3 (Hugging Face)", + "provider": "huggingface", + "type": "text", + "params": { # Default parameters for text models + "max_new_tokens": 200, + "temperature": 0.6, + "top_p": 0.9, }, - # Google Text Models - { - "id": "gemini-1.5-flash-latest", - "name": "Gemini 1.5 Flash (Google)", - "provider": PROVIDER_GOOGLE, - "type": TASK_TEXT, - "params": { - "max_new_tokens": 200, - "temperature": 0.6, - "top_p": 0.9, - }, + }, + # Google Text Models + { + "id": "gemini-1.5-flash-latest", + "name": "Gemini 1.5 Flash (Google)", + "provider": "google", + "type": "text", + "params": { + "max_new_tokens": 200, + "temperature": 0.6, + "top_p": 0.9, }, - { # Using -latest aliases is generally better, but user asked for explicit preview - "id": "gemini-2.5-flash-preview-05-20", - "name": "Gemini 2.5 Flash Preview (Google)", - "provider": PROVIDER_GOOGLE, - "type": TASK_TEXT, - "params": { - "max_new_tokens": 200, - "temperature": 0.6, - "top_p": 0.9, - }, + }, + { # Using -latest aliases is generally better, but user asked for explicit preview + "id": "gemini-2.5-flash-preview-05-20", + "name": "Gemini 2.5 Flash Preview (Google)", + "provider": "google", + "type": "text", + "params": { + "max_new_tokens": 200, + "temperature": 0.6, + "top_p": 0.9, }, - { - "id": "gemini-2.5-pro-preview-05-06", - "name": "Gemini 2.5 Pro Preview (Google)", - "provider": PROVIDER_GOOGLE, - "type": TASK_TEXT, - 
"params": { - "max_new_tokens": 200, - "temperature": 0.6, - "top_p": 0.9, - }, + }, + { + "id": "gemini-2.5-pro-preview-05-06", + "name": "Gemini 2.5 Pro Preview (Google)", + "provider": "google", + "type": "text", + "params": { + "max_new_tokens": 200, + "temperature": 0.6, + "top_p": 0.9, }, - # Hugging Face Text-to-Image Models - { - "id": "black-forest-labs/FLUX.1-dev", - "name": "FLUX.1-dev (Text-to-Image HF)", - "provider": PROVIDER_HF, - "type": TASK_IMAGE, - "params": { # Default parameters for T2I models (example params) - "num_inference_steps": 50, # Common diffusion param (e.g., steps) - "guidance_scale": 7.5, # Common diffusion param (e.g., CFG scale) - "height": 512, # Image height - "width": 512, # Image width - "negative_prompt": "", # T2I specific (what NOT to include) - }, + }, + # Hugging Face Text-to-Image Models + { + "id": "black-forest-labs/FLUX.1-dev", + "name": "FLUX.1-dev (Text-to-Image HF)", + "provider": "huggingface", + "type": "t2i", + "params": { # Default parameters for T2I models (example params) + "num_inference_steps": 50, # Common diffusion param (e.g., steps) + "guidance_scale": 7.5, # Common diffusion param (e.g., CFG scale) + "height": 512, # Image height + "width": 512, # Image width + "negative_prompt": "", # T2I specific (what NOT to include) }, + }, + # You could add other T2I models here following the same structure + # { + # "id": "runwayml/stable-diffusion-v1-5", + # "name": "Stable Diffusion 1.5 (Hugging Face)", + # "provider": "huggingface", + # "type": "t2i", + # "params": { ... 
}, + # }, ] # Separate models lists by type for easier filtering MODELS_BY_TYPE = {m_type: [m for m in AVAILABLE_MODELS if m['type'] == m_type] for m_type in set(m['type'] for m in AVAILABLE_MODELS)} # Default task and model (select the first text model by default) -DEFAULT_TASK = TASK_TEXT -DEFAULT_MODEL_INFO = MODELS_BY_TYPE.get(DEFAULT_TASK, [None])[0] -DEFAULT_MODEL_ID = DEFAULT_MODEL_INFO['id'] if DEFAULT_MODEL_INFO else None - +DEFAULT_TASK = 'text' +DEFAULT_MODEL_ID = MODELS_BY_TYPE.get(DEFAULT_TASK, [None])[0]['id'] if MODELS_BY_TYPE.get(DEFAULT_TASK) and MODELS_BY_TYPE.get(DEFAULT_TASK)[0] else None # ----------------------------------------------------------------------------- # Session-state initialisation (MOVED TO THE TOP AND EXPANDED) @@ -125,836 +113,944 @@ DEFAULT_STARTER_MESSAGE = "Bonjour ! Je suis KolaChatBot. Choisissez la tâche ( # Ensure initial task is set if 'selected_task' not in st.session_state: - st.session_state.selected_task = DEFAULT_TASK + st.session_state.selected_task = DEFAULT_TASK # Ensure initial model ID is set based on default task, fallback if needed if 'selected_model_id' not in st.session_state: - # Attempt to get the default model for the current (or default) task - default_model_for_current_task = MODELS_BY_TYPE.get(st.session_state.selected_task, [None])[0] - st.session_state.selected_model_id = default_model_for_current_task['id'] if default_model_for_current_task else None + # Attempt to get the default model for the current (or default) task + default_model_for_current_task = MODELS_BY_TYPE.get(st.session_state.selected_task, [None])[0] + st.session_state.selected_model_id = default_model_for_current_task['id'] if default_model_for_current_task else None # Ensure parameters specific to the selected task type are initialized with defaults from the model selected_model_info_initial = next((m for m in AVAILABLE_MODELS if m['id'] == st.session_state.selected_model_id), None) if selected_model_info_initial: - 
default_params_for_model = selected_model_info_initial.get('params', {}) - # Initialize individual parameters in session state based on model defaults - # This makes it easier to link UI widgets directly - if selected_model_info_initial['type'] == TASK_TEXT: - if 'max_response_length' not in st.session_state: st.session_state.max_response_length = default_params_for_model.get("max_new_tokens", 200) - if 'temperature' not in st.session_state: st.session_state.temperature = default_params_for_model.get("temperature", 0.6) - if 'top_p' not in st.session_state: st.session_state.top_p = default_params_for_model.get("top_p", 0.9) - # Initialize T2I params to None or default for T2I (they won't be displayed/used for text task) - if 'num_inference_steps' not in st.session_state: st.session_state.num_inference_steps = None - if 'guidance_scale' not in st.session_state: st.session_state.guidance_scale = None - if 'image_height' not in st.session_state: st.session_state.image_height = None - if 'image_width' not in st.session_state: st.session_state.image_width = None - if 'negative_prompt' not in st.session_state: st.session_state.negative_prompt = "" - - elif selected_model_info_initial['type'] == TASK_IMAGE: - # Initialize T2I specific parameters - if 'num_inference_steps' not in st.session_state: st.session_state.num_inference_steps = default_params_for_model.get("num_inference_steps", 50) - if 'guidance_scale' not in st.session_state: st.session_state.guidance_scale = default_params_for_model.get("guidance_scale", 7.5) - if 'image_height' not in st.session_state: st.session_state.image_height = default_params_for_model.get("height", 512) - if 'image_width' not in st.session_state: st.session_state.image_width = default_params_for_model.get("width", 512) - if 'negative_prompt' not in st.session_state: st.session_state.negative_prompt = default_params_for_model.get("negative_prompt", "") - # Initialize Text params to None - if 'max_response_length' not in st.session_state: 
st.session_state.max_response_length = None - if 'temperature' not in st.session_state: st.session_state.temperature = None - if 'top_p' not in st.session_state: st.session_state.top_p = None -else: # Fallback if no model was initially found (shouldn't happen with DEFAULT_MODEL_ID setup) - if 'max_response_length' not in st.session_state: st.session_state.max_response_length = 200 - if 'temperature' not in st.session_state: st.session_state.temperature = 0.6 - if 'top_p' not in st.session_state: st.session_state.top_p = 0.9 - if 'num_inference_steps' not in st.session_state: st.session_state.num_inference_steps = 50 - if 'guidance_scale' not in st.session_state: st.session_state.guidance_scale = 7.5 - if 'image_height' not in st.session_state: st.session_state.image_height = 512 - if 'image_width' not in st.session_state: st.session_state.image_width = 512 + default_params_for_model = selected_model_info_initial.get('params', {}) + # Initialize individual parameters in session state based on model defaults + # This makes it easier to link UI widgets directly + if selected_model_info_initial['type'] == 'text': + if 'max_response_length' not in st.session_state: st.session_state.max_response_length = default_params_for_model.get("max_new_tokens", 200) + if 'temperature' not in st.session_state: st.session_state.temperature = default_params_for_model.get("temperature", 0.6) + if 'top_p' not in st.session_state: st.session_state.top_p = default_params_for_model.get("top_p", 0.9) + # Initialize T2I params to None or default *T2I* defaults if not selected type + if 'num_inference_steps' not in st.session_state: st.session_state.num_inference_steps = None + if 'guidance_scale' not in st.session_state: st.session_state.guidance_scale = None + if 'image_height' not in st.session_state: st.session_state.image_height = None + if 'image_width' not in st.session_state: st.session_state.image_width = None if 'negative_prompt' not in st.session_state: 
st.session_state.negative_prompt = "" + elif selected_model_info_initial['type'] == 't2i': + # Initialize T2I specific parameters + if 'num_inference_steps' not in st.session_state: st.session_state.num_inference_steps = default_params_for_model.get("num_inference_steps", 50) + if 'guidance_scale' not in st.session_state: st.session_state.guidance_scale = default_params_for_model.get("guidance_scale", 7.5) + if 'image_height' not in st.session_state: st.session_state.image_height = default_params_for_model.get("height", 512) + if 'image_width' not in st.session_state: st.session_state.image_width = default_params_for_model.get("width", 512) + if 'negative_prompt' not in st.session_state: st.session_state.negative_prompt = default_params_for_model.get("negative_prompt", "") + # Initialize Text params to None or default *Text* defaults if not selected type + if 'max_response_length' not in st.session_state: st.session_state.max_response_length = None # Or set to a default text param value if needed + if 'temperature' not in st.session_state: st.session_state.temperature = None + if 'top_p' not in st.session_state: st.session_state.top_p = None +else: # Fallback if no model was initially found + if 'max_response_length' not in st.session_state: st.session_state.max_response_length = 200 + if 'temperature' not in st.session_state: st.session_state.temperature = 0.6 + if 'top_p' not in st.session_state: st.session_state.top_p = 0.9 + if 'num_inference_steps' not in st.session_state: st.session_state.num_inference_steps = 50 + if 'guidance_scale' not in st.session_state: st.session_state.guidance_scale = 7.5 + if 'image_height' not in st.session_state: st.session_state.image_height = 512 + if 'image_width' not in st.session_state: st.session_state.image_width = 512 + if 'negative_prompt' not in st.session_state: st.session_state.negative_prompt = "" + # Initialize system and starter messages in session state if not exists if "system_message" not in st.session_state: - 
st.session_state.system_message = DEFAULT_SYSTEM_MESSAGE + st.session_state.system_message = DEFAULT_SYSTEM_MESSAGE if "starter_message" not in st.session_state: - st.session_state.starter_message = DEFAULT_STARTER_MESSAGE + st.session_state.starter_message = DEFAULT_STARTER_MESSAGE # Initialize avatars if "avatars" not in st.session_state: - st.session_state.avatars = {ROLE_USER: "👤", ROLE_ASSISTANT: "🤖"} + st.session_state.avatars = {"user": "👤", "assistant": "🤖"} # Initialize chat history in session state if not exists. if "chat_history" not in st.session_state: - st.session_state.chat_history = [{ROLE_KEY: ROLE_ASSISTANT, CONTENT_KEY: st.session_state.starter_message, TYPE_KEY: TYPE_TEXT}] + st.session_state.chat_history = [{"role": "assistant", "content": st.session_state.starter_message, "type": "text"}] # Add type -# Store previous task and model ID to detect changes for auto-rerun and history reset +# Store previous task and model ID to detect changes if '_prev_task' not in st.session_state: st.session_state._prev_task = st.session_state.selected_task if '_prev_model_id' not in st.session_state: st.session_state._prev_model_id = st.session_state.selected_model_id -if '_reset_triggered' not in st.session_state: st.session_state._reset_triggered = False - # ----------------------------------------------------------------------------- -# Helper for export formatting +# Helper pour formater les exports (Adapté pour images) # ----------------------------------------------------------------------------- def format_history_to_txt(chat_history: list[dict]) -> str: - """Formats chat history into a plain text string.""" - lines = [f"KolaChatBot Conversation (Export Texte Uniquement) - Exporté le {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"] - for message in chat_history: - if message.get(ROLE_KEY) == ROLE_SYSTEM: continue # Skip system messages + lines = [f"KolaChatBot Conversation (Export Texte Uniquement) - Exporté le {datetime.now().strftime('%Y-%m-%d 
%H:%M:%S')}\n"] + for message in chat_history: + if message.get("role") == "system": continue + role = "Utilisateur" if message["role"] == "user" else "KolaChatBot" + content_type = message.get("type", "text") + if content_type == "text": + lines.append(f"{role}:\n{message['content']}\n") + elif content_type == "t2i" and "prompt" in message: + lines.append(f"{role} (Image générée):\nPrompt: {message['prompt']}\n(Image non incluse dans l'export TXT)\n") + # Ignore other types gracefully + + return "".join(lines) - role_label = "Utilisateur" if message[ROLE_KEY] == ROLE_USER else "KolaChatBot" - content_type = message.get(TYPE_KEY, TYPE_TEXT) +def format_history_to_json(chat_history: list[dict]) -> str: + # Export JSON can include image data (base64 or placeholder) or just text. + # Storing raw bytes directly isn't JSON serializable. Base64 is possible but makes JSON huge. + # Let's export text content and image prompts, but not the image data itself. + export_history = [] + for message in chat_history: + if message.get("role") == "system": continue + export_msg = {"role": message["role"], "type": message.get("type", "text")} + if export_msg["type"] == "text": + export_msg["content"] = message.get("content", "") + elif export_msg["type"] == "t2i" and "prompt" in message: + export_msg["prompt"] = message["prompt"] + export_msg["image_placeholder"] = "(Image non incluse dans l'export JSON)" # Indicate image was here + export_history.append(export_msg) + + return json.dumps(export_history, indent=2, ensure_ascii=False) - if content_type == TYPE_TEXT: - lines.append(f"{role_label}:\n{message.get(CONTENT_KEY, '')}\n\n") - elif content_type == TYPE_T2I and PROMPT_KEY in message: - lines.append(f"{role_label} (Image générée):\nPrompt: {message[PROMPT_KEY]}\n(Image non incluse dans l'export TXT)\n\n") - elif content_type == TYPE_T2I_PROMPT: # User's text prompt for image generation - lines.append(f"{role_label} (Prompt Image):\n{message.get(CONTENT_KEY, '')}\n\n") - # Ignore 
other types gracefully +def format_history_to_md(chat_history: list[dict]) -> str: + lines = [f"# KolaChatBot Conversation (Export Texte Uniquement)\n*Exporté le {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}*\n\n"] + for message in chat_history: + if message.get("role") == "system": continue + user_avatar = st.session_state.avatars.get('user', '👤') + assistant_avatar = st.session_state.avatars.get('assistant', '🤖') + role_label = f"{user_avatar} **Utilisateur**" if message["role"] == "user" else f"{assistant_avatar} **KolaChatBot**" - return "".join(lines) + content_type = message.get("type", "text") -def format_history_to_json(chat_history: list[dict]) -> str: - """Formats chat history into a JSON string.""" - export_history = [] - for message in chat_history: - if message.get(ROLE_KEY) == ROLE_SYSTEM: continue - - export_msg = {ROLE_KEY: message[ROLE_KEY], TYPE_KEY: message.get(TYPE_KEY, TYPE_TEXT)} - - if export_msg[TYPE_KEY] == TYPE_TEXT: - export_msg[CONTENT_KEY] = message.get(CONTENT_KEY, "") - elif export_msg[TYPE_KEY] == TYPE_T2I and PROMPT_KEY in message: - export_msg[PROMPT_KEY] = message[PROMPT_KEY] - export_msg["image_placeholder"] = "(Image non incluse dans l'export JSON)" # Indicate image was here - # Add other relevant T2I params if desired - elif export_msg[TYPE_KEY] == TYPE_T2I_PROMPT: - export_msg[CONTENT_KEY] = message.get(CONTENT_KEY, "") # Store the user's prompt text - - export_history.append(export_msg) - - return json.dumps(export_history, indent=2, ensure_ascii=False) + lines.append(f"### {role_label}\n\n") + if content_type == "text": + lines.append(f"{message.get('content', '')}\n\n") + elif content_type == "t2i" and "prompt" in message: + lines.append(f"*Image générée (prompt: {message['prompt']})*\n") + # In actual markdown, you could embed a base64 image or link if saved. + # For this simple export, just note it. 
+ lines.append("![Image non incluse dans l'export Markdown](placeholder.png)\n\n") -def format_history_to_md(chat_history: list[dict]) -> str: - """Formats chat history into a Markdown string.""" - lines = [f"# KolaChatBot Conversation (Export Texte Uniquement)\n*Exporté le {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}*\n\n"] - for message in chat_history: - if message.get(ROLE_KEY) == ROLE_SYSTEM: continue - - user_avatar = st.session_state.avatars.get(ROLE_USER, '👤') - assistant_avatar = st.session_state.avatars.get(ROLE_ASSISTANT, '🤖') - - role_icon = user_avatar if message[ROLE_KEY] == ROLE_USER else assistant_avatar - role_label = f" **{role_icon} Utilisateur**" if message[ROLE_KEY] == ROLE_USER else f" **{role_icon} KolaChatBot**" - - content_type = message.get(TYPE_KEY, TYPE_TEXT) - - lines.append(f"## {role_label}\n\n") - if content_type == TYPE_TEXT: - lines.append(f"{message.get(CONTENT_KEY, '')}\n\n") - elif content_type == TYPE_T2I and PROMPT_KEY in message: - lines.append(f"*Image générée (prompt: {message[PROMPT_KEY]})*\n") - lines.append("![Image non incluse dans l'export Markdown](placeholder.png)\n\n") - elif content_type == TYPE_T2I_PROMPT: - lines.append(f"*Prompt Image :*\n{message.get(CONTENT_KEY, '')}\n\n") - - lines.append("---\n\n") - - return "".join(lines) + lines.append("---\n\n") + + return "".join(lines) # ----------------------------------------------------------------------------- # LLM API helper (Unified call logic) # ----------------------------------------------------------------------------- +# This function is specifically for text prompt building for Mistral-style models def build_mistral_prompt(system_message: str, chat_history_for_prompt: list[dict]) -> str: - """ - Builds a prompt string suitable for Mistral-style instruction models. - The `chat_history_for_prompt` should be a list of text messages, ordered - from oldest to newest, including the latest user message. 
- """ - prompt_parts = [] - - # Check if there are any actual messages to process - if not chat_history_for_prompt: - return "" - - has_system_instruction = bool(system_message) - - for i, message in enumerate(chat_history_for_prompt): - # Only include text messages in the prompt - if message.get(TYPE_KEY, TYPE_TEXT) != TYPE_TEXT: - continue - - role = message[ROLE_KEY] - content = message[CONTENT_KEY] - - if role == ROLE_USER: - # For the very first user message, incorporate system message if present - if i == 0 and has_system_instruction: - prompt_parts.append(f"[INST] <>\n{system_message}\n<>\n\n{content} [/INST]") - has_system_instruction = False # Ensure system message is only added once - else: - # For subsequent user messages (or first if no system message) - prompt_parts.append(f"[INST] {content} [/INST]") - elif role == ROLE_ASSISTANT: - # Assistant responses close the previous instruction turn - prompt_parts.append(f" {content}") - - return "".join(prompt_parts) - - -def call_hf_inference(model_id: str, payload_inputs: any, params: dict, model_type: str) -> any: - """ - Calls the Hugging Face Inference API for either text generation or text-to-image. - payload_inputs: The main input for the model (string for text, string for t2i prompt). - params: Dictionary of model-specific parameters. - model_type: 'text' or 'image' (TASK_TEXT or TASK_IMAGE constant). - Returns: string for text, bytes for image, or an error string. - """ - if not HUGGINGFACEHUB_API_TOKEN: - return "Erreur d'API Hugging Face: Le token HUGGINGFACEHUB_API_TOKEN est introuvable." 
- - headers = {"Authorization": f"Bearer {HUGGINGFACEHUB_API_TOKEN}"} - url = f"{HF_BASE_API_URL}{model_id}" - - payload = {} - response_parser = None - response_is_json = False # Will be set based on model_type - - if model_type == TASK_TEXT: - payload = { - "inputs": payload_inputs, # The prompt string built elsewhere - "parameters": { - "max_new_tokens": params.get("max_new_tokens", 200), - "temperature": params.get("temperature", 0.6), - "top_p": params.get("top_p", 0.9), - "return_full_text": False, - "num_return_sequences": 1, - "do_sample": params.get("temperature", 0.6) > 1e-2, # Set do_sample if temperature is not zero - }, - "options": {"wait_for_model": True, "use_cache": False} - } - def parse_text_response(response_data: dict) -> str: - if isinstance(response_data, list) and response_data and CONTENT_KEY in response_data[0]: - return response_data[0][CONTENT_KEY].strip() - else: - return f"Erreur API Hugging Face (Format): Réponse texte inattendue - {response_data}" - response_parser = parse_text_response - response_is_json = True - - elif model_type == TASK_IMAGE: - payload = { - "inputs": payload_inputs, # The prompt string (user input) - "parameters": { - "negative_prompt": params.get("negative_prompt", ""), - "num_inference_steps": params.get("num_inference_steps", 50), - "guidance_scale": params.get("guidance_scale", 7.5), - "height": params.get("image_height", 512), - "width": params.get("image_width", 512), - }, - "options": {"wait_for_model": True, "use_cache": False} - } - def parse_t2i_response(response: requests.Response) -> bytes | str: - # Response for T2I is binary image data - if response.content: - return response.content # Return bytes - else: - return f"Erreur API Hugging Face (T2I): Réponse image vide ou inattendue." - response_parser = parse_t2i_response - response_is_json = False - + """ + Builds a prompt string suitable for Mistral-style instruction models using the + chat history (text messages only). 
+ Assumes chat_history_for_prompt contains the messages relevant + to the model turns, starting with the first actual user turn (potentially including system) + and ending with the current user message. + """ + formatted_prompt = "" + system_message_handled = False + + for i in range(len(chat_history_for_prompt) - 1): # Iterate up to the second-to-last message + message = chat_history_for_prompt[i] + # Ensure messages are text type before including them in prompt building + if message.get("type", "text") != "text": continue # Skip non-text messages + + if message["role"] == "user": + # Only apply system message with the *first* user turn + if not system_message_handled and system_message: + formatted_prompt += f"[INST] <>\n{system_message}\n<>\n\n{message['content']} [/INST]" + system_message_handled = True + else: + formatted_prompt += f"[INST] {message['content']} [/INST]" + elif message["role"] == "assistant": + # Assistant messages should follow a user message + if i > 0 and chat_history_for_prompt[i-1]["role"] == "user": # Check to ensure previous was user + formatted_prompt += f" {message['content']}" + else: # Fallback for cases where sequence is unexpected, e.g., assistant first + # This case suggests an unusual history (e.g., assistant message not preceded by user) + # For Mistral, this might break the chat turn format. 
+ formatted_prompt += f" {message['content']}" # This might be problematic for turn alignment + + # Handle the *current* user message (the very last one) + if chat_history_for_prompt: + current_user_msg = chat_history_for_prompt[-1] + if current_user_msg["role"] != "user" or current_user_msg.get("type", "text") != "text": + # Safety check: last message must be user & text type for this prompt builder + st.error("Logique d'historique invalide pour le prompt Mistral: La dernière message n'est pas de l'utilisateur ou n'est pas du texte.") + return "" + + if not system_message_handled and system_message: + formatted_prompt += f"[INST] <>\n{system_message}\n<>\n\n{current_user_msg['content']} [/INST]" else: - return f"Erreur interne: Type de modèle Hugging Face '{model_type}' inconnu." + formatted_prompt += f"[INST] {current_user_msg['content']} [/INST]" + else: + st.error("Logique d'historique invalide : Historique du prompt est vide.") + return "" - if response_parser is None: - return f"Erreur interne: Le type de modèle '{model_type}' n'a pas de parseur de réponse défini." + return formatted_prompt +def call_hf_inference(model_id: str, payload_inputs: any, params: dict, model_type: str) -> any: + """ + Calls the Hugging Face Inference API for either text generation or text-to-image. + payload_inputs is the main input (string for text, string for t2i prompt). + params depends on the model_type. + Returns string for text, bytes for image, or error string. + """ + if not HUGGINGFACEHUB_API_TOKEN: + return "Erreur d'API Hugging Face: Le token HUGGINGFACEHUB_API_TOKEN est introuvable." 
+ + headers = {"Authorization": f"Bearer {HUGGINGFACEHUB_API_TOKEN}"} + url = f"{HF_BASE_API_URL}{model_id}" # Corrected URL to include model_id + + payload = {} + response_parser = None + + if model_type == 'text': + payload = { + "inputs": payload_inputs, # The prompt string built elsewhere + "parameters": { + "max_new_tokens": params.get("max_new_tokens", 200), + "temperature": params.get("temperature", 0.6), + "top_p": params.get("top_p", 0.9), + "return_full_text": False, + "num_return_sequences": 1, + "do_sample": params.get("temperature", 0.6) > 1e-2, + }, + "options": {"wait_for_model": True, "use_cache": False} + } + def parse_text_response(response_data): + if isinstance(response_data, list) and response_data and "generated_text" in response_data[0]: + return response_data[0]["generated_text"].strip() + else: + return f"Erreur API Hugging Face (Format): Réponse texte inattendue - {response_data}" + response_parser = parse_text_response + response_is_json = True # Expected response format + + elif model_type == 't2i': + payload = { + "inputs": payload_inputs, # The prompt string (user input) + "parameters": { + "negative_prompt": params.get("negative_prompt", ""), + "num_inference_steps": params.get("num_inference_steps", 50), + "guidance_scale": params.get("guidance_scale", 7.5), + "height": params.get("image_height", 512), + "width": params.get("image_width", 512), + }, + "options": {"wait_for_model": True, "use_cache": False} + } + def parse_t2i_response(response): + # Response for T2I is binary image data + if response.content: + return response.content # Return bytes + else: + return f"Erreur API Hugging Face (T2I): Réponse image vide ou inattendue." + response_parser = parse_t2i_response + response_is_json = False # Expected response format is binary/image + + else: + return f"Erreur interne: Type de modèle Hugging Face '{model_type}' inconnu." 
+ + + if response_parser is None: # Should not happen if block above is complete + return f"Erreur interne: Le type de modèle '{model_type}' n'a pas de parseur de réponse défini." + + try: + # For T2I, `json` parameter in requests post is not used, `data` is used with the raw prompt bytes. + # For text, `json` is used. + if model_type == 'text': + response = requests.post(url, headers=headers, json=payload, timeout=300) + elif model_type == 't2i': + # `inputs` is the prompt string in payload, but for HF Inference API for T2I, + # sometimes the prompt is passed as raw `data` in the request body. + # However, the structure of payload for text-to-image is typically JSON as well. + # The FLUX.1-dev model's API expects JSON payload. + response = requests.post(url, headers=headers, json=payload, timeout=300) + + response.raise_for_status() # Raise an HTTPError for bad responses (4xx or 5xx) + + # Check content type based on expected format + if response_is_json: + # Expecting JSON (e.g., text models) + if 'application/json' in response.headers.get('Content-Type', '').lower(): + response_data = response.json() + return response_parser(response_data) + else: + # Received non-JSON when JSON was expected + return f"Erreur API Hugging Face: Type de contenu inattendu ({response.headers.get('Content-Type', 'N/A')}) - attendu JSON." 
+ else: # Expecting binary/image (e.g., t2i models) + if 'image/' in response.headers.get('Content-Type', '').lower() or 'binary' in response.headers.get('Content-Type', '').lower() or ('application/json' not in response.headers.get('Content-Type', '').lower() and response.content): + return response_parser(response) # Pass the full response object or just content if parser needs it + elif 'application/json' in response.headers.get('Content-Type', '').lower(): + # Sometimes errors come back as JSON even for binary endpoints + error_detail = response.json() + error_message = error_detail.get('error', str(error_detail)) + if isinstance(error_message, list): error_message = ". ".join(error_message) + estimated_time = error_detail.get('estimated_time', None) + if estimated_time: + return f"Erreur API Hugging Face ({response.status_code}): {error_message}. Le modèle est peut-être en chargement, veuillez réessayer dans environ {_format_time(estimated_time)}." + return f"Erreur API Hugging Face ({response.status_code}): {error_message}" + else: + # Received unexpected content type + return f"Erreur API Hugging Face: Type de contenu inattendu ({response.headers.get('Content-Type', 'N/A')}) - attendu image/binary." + + + except requests.exceptions.Timeout: + return "Erreur API Hugging Face: Délai d'attente dépassé pour la requête (300s). Le modèle est peut-être lent à charger. Veuillez réessayer." + except requests.exceptions.ConnectionError as e: + return f"Erreur API Hugging Face: Impossible de se connecter à l'API. 
Détails: {e}" + except requests.exceptions.HTTPError as e: + # Try to parse error details if available, otherwise display status code and text try: - response = requests.post(url, headers=headers, json=payload, timeout=300) - response.raise_for_status() # Raise an HTTPError for bad responses (4xx or 5xx) - - # Check content type based on expected format - if response_is_json: - if 'application/json' in response.headers.get(CONTENT_KEY_TYPE, '').lower(): - response_data = response.json() - return response_parser(response_data) - else: - return f"Erreur API Hugging Face: Type de contenu inattendu ({response.headers.get(CONTENT_KEY_TYPE, 'N/A')}) - attendu JSON. Réponse brute: {response.text[:200]}..." - else: # Expecting binary/image - if 'image/' in response.headers.get(CONTENT_KEY_TYPE, '').lower() or 'binary' in response.headers.get(CONTENT_KEY_TYPE, '').lower() or (response.content and 'application/json' not in response.headers.get(CONTENT_KEY_TYPE, '').lower()): - return response_parser(response) - elif 'application/json' in response.headers.get(CONTENT_KEY_TYPE, '').lower(): - # Sometimes errors come back as JSON even for binary endpoints - error_detail = response.json() - error_message = error_detail.get('error', str(error_detail)) - if isinstance(error_message, list): error_message = ". ".join(error_message) - estimated_time = error_detail.get('estimated_time', None) - if estimated_time: - return f"Erreur API Hugging Face ({response.status_code}): {error_message}. Le modèle est peut-être en chargement, veuillez réessayer dans environ {_format_time(estimated_time)}." - return f"Erreur API Hugging Face ({response.status_code}): {error_message}" - else: - return f"Erreur API Hugging Face: Type de contenu inattendu ({response.headers.get(CONTENT_KEY_TYPE, 'N/A')}) - attendu image/binary. Réponse brute: {response.text[:200]}..." - - except requests.exceptions.Timeout: - return "Erreur API Hugging Face: Délai d'attente dépassé pour la requête (300s). 
Le modèle est peut-être lent à charger. Veuillez réessayer." - except requests.exceptions.ConnectionError as e: - return f"Erreur API Hugging Face: Impossible de se connecter à l'API. Détails: {e}" - except requests.exceptions.HTTPError as e: - try: - if 'application/json' in e.response.headers.get(CONTENT_KEY_TYPE, '').lower(): - error_detail = e.response.json() - error_message = error_detail.get('error', str(error_detail)) - if isinstance(error_message, list): error_message = ". ".join(error_message) - estimated_time = error_detail.get('estimated_time', None) - if estimated_time: - return f"Erreur API Hugging Face ({e.response.status_code}): {error_message}. Le modèle est peut-être en chargement, veuillez réessayer dans environ {_format_time(estimated_time)}." - return f"Erreur API Hugging Face ({e.response.status_code}): {error_message}" - else: - return f"Erreur API Hugging Face ({e.response.status_code}): {e.response.text}" - except Exception: - return f"Erreur API Hugging Face ({e.response.status_code}): Impossible d'obtenir les détails de l'erreur. Réponse brute: {e.response.text[:200]}..." - except Exception as e: - return f"Erreur inconnue lors de l'appel API Hugging Face: {e}" + if 'application/json' in e.response.headers.get('Content-Type', '').lower(): + error_detail = e.response.json() + error_message = error_detail.get('error', str(error_detail)) + if isinstance(error_message, list): error_message = ". ".join(error_message) + estimated_time = error_detail.get('estimated_time', None) + if estimated_time: + return f"Erreur API Hugging Face ({e.response.status_code}): {error_message}. Le modèle est peut-être en chargement, veuillez réessayer dans environ {_format_time(estimated_time)}." 
+ return f"Erreur API Hugging Face ({e.response.status_code}): {error_message}" + else: + # Non-JSON error response + return f"Erreur API Hugging Face ({e.response.status_code}): {e.response.text}" + + except Exception: # Catch any error during JSON decoding or key access + return f"Erreur API Hugging Face ({e.response.status_code}): Impossible d'obtenir les détails de l'erreur. Réponse brute: {e.response.text[:200]}..." # Limit raw text length + except Exception as e: + return f"Erreur inconnue lors de l'appel API Hugging Face: {e}" def call_google_api(model_id: str, system_message: str, chat_history_for_api: list[dict], params: dict) -> str: - """Calls the Google Generative AI API (Text models only).""" - if not GOOGLE_API_KEY: - return "Erreur d'API Google: La clé GOOGLE_API_KEY est introuvable." + """Calls the Google Generative AI API (Text models only).""" + if not GOOGLE_API_KEY: + return "Erreur d'API Google: Le token GOOGLE_API_KEY est introuvable." + + try: + genai.configure(api_key=GOOGLE_API_KEY) + except Exception as e: + return f"Erreur de configuration Google API: {e}. Vérifiez votre GOOGLE_API_KEY." + + try: + # Pass model_id as the first positional argument. + model = genai.GenerativeModel(model_id) + + # Prepare history for the Gemini API + # The standard format is a list of dicts: [{'role': 'user', 'parts': [...]}, {'role': 'model', 'parts': [...]}, ...] + # `chat_history_for_api` here is the 'actual_conversation_history' list (filtered and ordered). 
+ gemini_history_parts = [] + for msg in chat_history_for_api: + # Only include text messages in history sent to Google Text models + if msg.get("type", "text") == "text": + # Map roles: Streamlit 'user' -> Google 'user', Streamlit 'assistant' -> Google 'model' + role = 'user' if msg['role'] == 'user' else 'model' + gemini_history_parts.append({"role": role, "parts": [msg['content']]}) + # Note: Gemini can handle multimodal (`parts` can contain text/image/etc.), + # but this specific `call_google_api` function is currently designed + # for text-only history to ensure compatibility with text models. + # A separate `call_google_multimodal_api` would be needed for full multimodal. + + system_instruction_param = system_message if system_message else None + + generation_config = genai.types.GenerationConfig( + max_output_tokens=params.get("max_new_tokens", 200), + temperature=params.get("temperature", 0.6), + top_p=params.get("top_p", 0.9), + # top_k=params.get("top_k", None), + ) - try: - genai.configure(api_key=GOOGLE_API_KEY) - except Exception as e: - return f"Erreur de configuration Google API: {e}. Vérifiez votre GOOGLE_API_KEY." 
+ # Pass the cleaned history and system instruction to generate_content + response = model.generate_content( + contents=gemini_history_parts, + generation_config=generation_config, + system_instruction=system_instruction_param, + request_options={'timeout': 180} # Increased timeout slightly + ) - try: - model = genai.GenerativeModel(model_id) - - # Prepare history for the Gemini API - gemini_history_parts = [] - for msg in chat_history_for_api: - # Only include text messages in history for Google Text models - if msg.get(TYPE_KEY, TYPE_TEXT) == TYPE_TEXT: - # Map roles: Streamlit 'user' -> Google 'user', Streamlit 'assistant' -> Google 'model' - role = ROLE_USER if msg[ROLE_KEY] == ROLE_USER else 'model' - gemini_history_parts.append({"role": role, "parts": [msg[CONTENT_KEY]]}) - - system_instruction_param = system_message if system_message else None - - generation_config = genai.types.GenerationConfig( - max_output_tokens=params.get("max_new_tokens", 200), - temperature=params.get("temperature", 0.6), - top_p=params.get("top_p", 0.9), - ) - - response = model.generate_content( - contents=gemini_history_parts, - generation_config=generation_config, - system_instruction=system_instruction_param, - request_options={'timeout': 180} - ) - - if response.candidates: - if response.candidates[0].content and response.candidates[0].content.parts: - generated_text = "".join(part.text for part in response.candidates[0].content.parts) - return generated_text.strip() - elif hasattr(response.candidates[0], 'finish_reason') and response.candidates[0].finish_reason: - return f"API Google: La génération s'est terminée prématurément. Raison: {response.candidates[0].finish_reason}. Prompt feedback: {response.prompt_feedback}" - else: - return f"Erreur API Google: Réponse vide ou inattendue après appel réussi. 
Debug info: {response.candidates[0]}" - elif hasattr(response, 'prompt_feedback') and response.prompt_feedback and response.prompt_feedback.block_reason: - return f"Erreur API Google: Votre message a été bloqué ({response.prompt_feedback.block_reason.name}). Raison détaillée: {response.prompt_feedback.safety_ratings}" - else: - return f"Erreur API Google: Aucune réponse générée pour une raison inconnue. Debug info: {response}" - - except Exception as e: - return f"Erreur lors de l'appel API Google: {e}" - -def _format_time(seconds: float) -> str: - """Formats seconds into a human-readable string (e.g., '2 min 30 sec').""" - if not isinstance(seconds, (int, float)) or seconds < 0: return "N/A" - minutes = int(seconds // 60) - remaining_seconds = int(seconds % 60) - if minutes > 0: return f"{minutes} min {remaining_seconds} sec" - return f"{remaining_seconds} sec" # If less than a minute, just show seconds + # Process the response + if response.candidates: + if response.candidates[0].content and response.candidates[0].content.parts: + generated_text = "".join(part.text for part in response.candidates[0].content.parts) + return generated_text.strip() + elif hasattr(response.candidates[0], 'finish_reason') and response.candidates[0].finish_reason: + return f"API Google: La génération s'est terminée prématurément. Raison: {response.candidates[0].finish_reason}. Prompt feedback: {response.prompt_feedback}" + else: + return f"Erreur API Google: Réponse vide ou inattendue après appel réussi. Debug info: {response.candidates[0]}" + elif hasattr(response, 'prompt_feedback') and response.prompt_feedback and response.prompt_feedback.block_reason: + return f"Erreur API Google: Votre message a été bloqué ({response.prompt_feedback.block_reason}). Raison détaillée: {response.prompt_feedback.safety_ratings}" + else: + return f"Erreur API Google: Aucune réponse générée pour une raison inconnue. 
Debug info: {response}" + + except Exception as e: + return f"Erreur lors de l'appel API Google: {e}" + +def _format_time(seconds): + """Formats seconds into a human-readable string.""" + if not isinstance(seconds, (int, float)) or seconds < 0: return "N/A" + minutes = int(seconds // 60) + remaining_seconds = int(seconds % 60) + if minutes > 0: return f"{minutes} min {remaining_seconds} sec" # Corrected string formatting + return f"{remaining_seconds} sec" # ----------------------------------------------------------------------------- # Generation Functions (Separated by Task Type) # ----------------------------------------------------------------------------- -def get_text_response(selected_model_id: str, system_prompt: str, full_chat_history: list[dict]) -> dict: - """ - Handles text generation request using the selected text model. - Returns a dictionary representing the assistant's response. - """ - selected_model_info = next((m for m in MODELS_BY_TYPE.get(TASK_TEXT, []) if m['id'] == selected_model_id), None) - if not selected_model_info: - return {ROLE_KEY: ROLE_ASSISTANT, CONTENT_KEY: f"Erreur interne: Modèle texte '{selected_model_id}' introuvable ou n'est pas un modèle texte.", TYPE_KEY: TYPE_TEXT} - - model_provider = selected_model_info['provider'] - model_type = selected_model_info['type'] # Should be TASK_TEXT - params = { # Use parameters from session state for text generation - "max_new_tokens": st.session_state.max_response_length, - "temperature": st.session_state.temperature, - "top_p": st.session_state.top_p, - } - - # Create a list of actual conversation turns relevant for text models. - # Excludes the initial assistant starter message if it's the only one and there's more history, - # and only includes actual text messages. 
- actual_conversation_history = [ - msg for msg in full_chat_history - if msg.get(TYPE_KEY, TYPE_TEXT) == TYPE_TEXT # Only include text messages for prompt building - and not (msg[ROLE_KEY] == ROLE_ASSISTANT and msg[CONTENT_KEY] == st.session_state.starter_message and len(full_chat_history) > 1 and full_chat_history[0][ROLE_KEY] == ROLE_ASSISTANT) - ] - - # Safety check: the last message should be from the user before generating response - if not actual_conversation_history or actual_conversation_history[-1][ROLE_KEY] != ROLE_USER: - return {ROLE_KEY: ROLE_ASSISTANT, CONTENT_KEY: "Erreur interne: Logique d'historique de conversation invalide pour la génération de texte (dernière message non-utilisateur).", TYPE_KEY: TYPE_TEXT} - - response_content = "" - if model_provider == PROVIDER_HF: - llm_prompt = build_mistral_prompt(system_prompt, actual_conversation_history) - if not llm_prompt: - return {ROLE_KEY: ROLE_ASSISTANT, CONTENT_KEY: "Erreur lors de la construction du prompt pour Hugging Face (texte).", TYPE_KEY: TYPE_TEXT} - response_content = call_hf_inference(selected_model_id, llm_prompt, params, model_type) - - elif model_provider == PROVIDER_GOOGLE: - response_content = call_google_api(selected_model_id, system_prompt, actual_conversation_history, params) +def get_text_response(selected_model_id: str, user_text: str, system_prompt: str, full_chat_history: list[dict]): + """ + Handles text generation request using the selected text model. + """ + selected_model_info = next((m for m in MODELS_BY_TYPE.get('text', []) if m['id'] == selected_model_id), None) + if not selected_model_info: + return f"Erreur interne: Modèle texte '{selected_model_id}' introuvable ou n'est pas un modèle texte." 
+ + model_provider = selected_model_info['provider'] + model_id = selected_model_info['id'] + model_type = selected_model_info['type'] # Should be 'text' + params = { # Use parameters from session state for text generation + "max_new_tokens": st.session_state.max_response_length, + "temperature": st.session_state.temperature, + "top_p": st.session_state.top_p, + } + + # Create a list of actual conversation turns relevant for text models. + # This excludes the initial assistant starter message if it's just the default one + # and there's more history. It should only include text messages. + actual_conversation_history = [ + msg for msg in full_chat_history + if msg.get("type", "text") == "text" # Only include text messages in history for prompt building + and not (msg['role'] == 'assistant' and msg['content'] == st.session_state.starter_message and len(full_chat_history) > 1) + ] + + if not actual_conversation_history or actual_conversation_history[-1]['role'] != 'user': + st.error("Logique d'historique interne invalide pour la génération de texte.") + return {"role": "assistant", "content": "Erreur interne: Logique d'historique de conversation invalide pour la génération de texte.", "type": "text"} + + + if model_provider == 'huggingface': + # HF text requires prompt string built from history + llm_prompt = build_mistral_prompt(system_prompt, actual_conversation_history) + if not llm_prompt: return {"role": "assistant", "content": "Erreur lors de la construction du prompt pour Hugging Face (texte).", "type": "text"} + response_content = call_hf_inference(model_id, llm_prompt, params, model_type) + + elif model_provider == 'google': + # Google text requires list of messages. `call_google_api` handles internal formatting. + response_content = call_google_api(model_id, system_prompt, actual_conversation_history, params) + + else: + response_content = f"Erreur interne: Fournisseur API '{model_provider}' inconnu pour le modèle texte '{model_id}'." 
+ + return {"role": "assistant", "content": response_content, "type": "text"} + +def get_image_response(selected_model_id: str, user_prompt: str): + """ + Handles image generation request using the selected T2I model. + """ + selected_model_info = next((m for m in MODELS_BY_TYPE.get('t2i', []) if m['id'] == selected_model_id), None) + if not selected_model_info: + return {"role": "assistant", "content": f"Erreur interne: Modèle image '{selected_model_id}' introuvable ou n'est pas un modèle image.", "type": "text"} # Return error as text + + model_provider = selected_model_info['provider'] + model_id = selected_model_info['id'] + model_type = selected_model_info['type'] # Should be 't2i' + params = { # Use parameters from session state for image generation + "num_inference_steps": st.session_state.num_inference_steps, + "guidance_scale": st.session_state.guidance_scale, + "image_height": st.session_state.image_height, + "image_width": st.session_state.image_width, + "negative_prompt": st.session_state.negative_prompt, + } + + # For T2I, the input is typically just the user prompt string. History is generally not used. + prompt_for_api = user_prompt + + if model_provider == 'huggingface': + response_content = call_hf_inference(model_id, prompt_for_api, params, model_type) + # call_hf_inference for t2i is expected to return bytes or an error string + if isinstance(response_content, bytes): + return {"role": "assistant", "image_data": response_content, "type": "t2i", "prompt": user_prompt} # Store image bytes and original prompt else: - response_content = f"Erreur interne: Fournisseur API '{model_provider}' inconnu pour le modèle texte '{selected_model_id}'." - - # Return the response encapsulated in a message dict - return {ROLE_KEY: ROLE_ASSISTANT, CONTENT_KEY: response_content, TYPE_KEY: TYPE_TEXT} - -def get_image_response(selected_model_id: str, user_prompt: str) -> dict: - """ - Handles image generation request using the selected Text-to-Image model. 
- Returns a dictionary representing the assistant's response (image data or error text). - """ - selected_model_info = next((m for m in MODELS_BY_TYPE.get(TASK_IMAGE, []) if m['id'] == selected_model_id), None) - if not selected_model_info: - return {ROLE_KEY: ROLE_ASSISTANT, CONTENT_KEY: f"Erreur interne: Modèle image '{selected_model_id}' introuvable ou n'est pas un modèle image.", TYPE_KEY: TYPE_TEXT} - - model_provider = selected_model_info['provider'] - model_type = selected_model_info['type'] # Should be TASK_IMAGE - params = { # Use parameters from session state for image generation - "num_inference_steps": st.session_state.num_inference_steps, - "guidance_scale": st.session_state.guidance_scale, - "image_height": st.session_state.image_height, - "image_width": st.session_state.image_width, - "negative_prompt": st.session_state.negative_prompt, - } - - # For T2I, the input is typically just the user prompt string. - prompt_for_api = user_prompt + # It's an error message string + return {"role": "assistant", "content": response_content, "type": "text", "prompt": user_prompt} # Store error as text message, keep prompt reference - assistant_response_content = None - if model_provider == PROVIDER_HF: - assistant_response_content = call_hf_inference(selected_model_id, prompt_for_api, params, model_type) - if isinstance(assistant_response_content, bytes): - return {ROLE_KEY: ROLE_ASSISTANT, IMAGE_DATA_KEY: assistant_response_content, TYPE_KEY: TYPE_T2I, PROMPT_KEY: user_prompt} - else: # It's an error message string - return {ROLE_KEY: ROLE_ASSISTANT, CONTENT_KEY: assistant_response_content, TYPE_KEY: TYPE_TEXT, PROMPT_KEY: user_prompt} + elif model_provider == 'google': + # Implement Google T2I call if available in the future + return {"role": "assistant", "content": f"Erreur: Les modèles image de Google ne sont pas encore implémentés.", "type": "text", "prompt": user_prompt} - elif model_provider == PROVIDER_GOOGLE: - return {ROLE_KEY: ROLE_ASSISTANT, 
CONTENT_KEY: f"Erreur: Les modèles image de Google ne sont pas encore implémentés.", TYPE_KEY: TYPE_TEXT, PROMPT_KEY: user_prompt} - - else: - return {ROLE_KEY: ROLE_ASSISTANT, CONTENT_KEY: f"Erreur interne: Fournisseur API '{model_provider}' inconnu pour le modèle image '{selected_model_id}'.", TYPE_KEY: TYPE_TEXT, PROMPT_KEY: user_prompt} + else: + return {"role": "assistant", "content": f"Erreur interne: Fournisseur API '{model_provider}' inconnu pour le modèle image '{model_id}'." , "type": "text", "prompt": user_prompt} # ----------------------------------------------------------------------------- # Streamlit page configuration # ----------------------------------------------------------------------------- st.set_page_config(page_title="KolaChatBot - Multi-Modèles IA", page_icon="🤖", layout="wide") st.title("🤖 KolaChatBot - Multi-Modèles IA") - +# Dynamically display the currently selected model and task in the description selected_model_info = next((m for m in AVAILABLE_MODELS if m['id'] == st.session_state.get('selected_model_id', DEFAULT_MODEL_ID)), None) selected_model_name = selected_model_info['name'] if selected_model_info else st.session_state.selected_model_id # Fallback if model not found -current_task_label = "Génération de Texte" if st.session_state.selected_task == TASK_TEXT else "Génération d'Image (T2I)" -st.markdown(f"*Tâche actuelle : **{current_task_label}***\n*Modèle : **{selected_model_name}** *") - -CONTENT_KEY_TYPE = 'Content-Type' # Define constant for content type header key +current_task_label = "Génération de Texte" if st.session_state.selected_task == 'text' else "Génération d'Image (T2I)" +st.markdown(f"*Tâche actuelle : **{current_task_label}***\n*Modèle : **`{selected_model_name}`** *") # ----------------------------------------------------------------------------- # Manuel d'utilisation (Update with task separation) # ----------------------------------------------------------------------------- with st.expander("📖 Manuel 
d'utilisation de KolaChatBot", expanded=False): - st.markdown(""" - Bienvenue sur KolaChatBot - Une application de chat IA multi-modèles ! Voici comment tirer le meilleur parti de notre assistant IA : - - **1. Comment interagir ?** - - **Choisir la tâche :** Utilisez le sélecteur "Choisir la tâche IA" dans la barre latérale pour choisir entre "Texte" (conversation) et "Image" (text-to-image). - - **Entrer votre prompt :** Tapez votre message ou votre description d'image dans la zone de texte en bas et appuyez sur Entrée. - - **2. Paramètres dans la barre latérale (Sidebar) :** - La barre latérale adapte ses options au fur et à mesure que vous sélectionnez une tâche. - - * **Choisir la tâche IA :** Le contrôle principal pour basculer entre Texte et Image. - * **Sélection du Modèle :** Une liste filtrée des modèles disponibles pour la tâche choisie (Hugging Face ou Google). - - **Important :** Assurez-vous que les tokens API (`HUGGINGFACEHUB_API_TOKEN` ou `GOOGLE_API_KEY`) nécessaires au fournisseur du modèle sélectionné sont configurés dans votre fichier `.env` ou comme secrets dans votre espace Hugging Face/Streamlit. - - Changer de tâche ou de modèle **efface** la conversation/historique actuelle pour assurer la compatibilité. - * **Paramètres spécifiques à la tâche :** - * **Pour la tâche "Texte" :** Message Système / Personnalité, Message de Bienvenue IA, Max New Tokens, Temperature, Top-P. - * **Pour la tâche "Image" (T2I) :** Paramètres de génération d'image (Nombre d'étapes, Échelle de guidage, Dimensions, Prompt Négatif). - - **Attention :** Modifier les paramètres spécifiques à la tâche **efface** également l'historique/conversation. - * **Sélection des Avatars.** - * **Gestion de la Conversation :** Boutons pour appliquer les paramètres actuels (et démarrer une nouvelle conversation) ou simplement effacer l'historique actuel. - * **Exporter la Conversation (texte et prompt image uniquement) :** Téléchargez l'historique textuel et les prompts d'image. 
Les images générées ne sont pas incluses dans les formats d'exportation textuels. - - **3. Limitations :** - - Cette application ne stocke pas les images de manière permanente. Elles sont présentes dans l'historique de session tant que l'application tourne. - - Seuls les modèles textuels génèrent du texte, et seuls les modèles T2I génèrent des images (pour le moment, seul un modèle T2I HF est implémenté). - - Amusez-vous bien avec KolaChatBot ! - """) + st.markdown(""" + Bienvenue sur KolaChatBot - Une application de chat IA multi-modèles ! Voici comment tirer le meilleur parti de notre assistant IA : + + **1. Comment interagir ?** + - **Choisir la tâche :** Utilisez le sélecteur "Choisir la tâche IA" dans la barre latérale pour choisir entre "Texte" (conversation) et "Image" (text-to-image). + - **Entrer votre prompt :** Tapez votre message ou votre description d'image dans la zone de texte en bas et appuyez sur Entrée. + + **2. Paramètres dans la barre latérale (Sidebar) :** + La barre latérale adapte ses options au fur et à mesure que vous sélectionnez une tâche. + + * **Choisir la tâche IA :** Le contrôle principal pour basculer entre Texte et Image. + * **Sélection du Modèle :** Une liste filtrée des modèles disponibles pour la tâche choisie (Hugging Face ou Google). + - **Important :** Assurez-vous que les tokens API (`HUGGINGFACEHUB_API_TOKEN` ou `GOOGLE_API_KEY`) nécessaires au fournisseur du modèle sélectionné sont configurés dans votre fichier `.env` ou comme secrets dans votre espace Hugging Face. + - Changer de tâche ou de modèle **efface** la conversation/historique actuelle pour assurer la compatibilité. + * **Paramètres spécifiques à la tâche :** + * **Pour la tâche "Texte" :** Message Système / Personnalité, Message de Bienvenue IA, Max New Tokens, Temperature, Top-P. + * **Pour la tâche "Image" (T2I) :** Paramètres de génération d'image (Nombre d'étapes, Échelle de guidage, Dimensions, Prompt Négatif). 
+ - **Attention :** Modifier les paramètres spécifiques à la tâche **efface** également l'historique/conversation. + * **Sélection des Avatars.** + * **Gestion de la Conversation :** Boutons pour appliquer les paramètres actuels (et démarrer une nouvelle conversation) ou simplement effacer l'historique actuel. + * **Exporter la Conversation (texte uniquement) :** Téléchargez l'historique textuel. Les images générées ne sont pas incluses dans les formats d'exportation textuels. + + **3. Limitations :** + - Cette application ne stocke pas les images de manière permanente. Elles sont présentes dans l'historique de session tant que l'application tourne. + - Seuls les modèles textuels génèrent du texte, et seuls les modèles T2I génèrent des images (pour le moment, seul un modèle T2I HF est implémenté). + + Amusez-vous bien avec KolaChatBot ! + """) # ----------------------------------------------------------------------------- # Sidebar settings (with task, model, and param separation) # ----------------------------------------------------------------------------- with st.sidebar: - st.header("🛠️ Configuration de KolaChatBot") - - st.subheader("🎯 Choisir la tâche IA") - selected_task_from_radio = st.radio( - "Sélectionnez l'action :", - options=['Texte', 'Image'], - index=0 if st.session_state.selected_task == TASK_TEXT else 1, - key="task_radio_button", - help="Choisissez si vous voulez générer du texte (conversation) ou une image (text-to-image)." + st.header("🛠️ Configuration de KolaChatBot") + + st.subheader("🎯 Choisir la tâche IA") + selected_task_from_radio = st.radio( + "Sélectionnez l'action :", + options=['Texte', 'Image'], + index=0 if st.session_state.selected_task == 'text' else 1, + key="task_radio_button", + help="Choisissez si vous voulez générer du texte (conversation) ou une image (text-to-image)." 
+ ) + + # Update session state task if radio button changed + task_changed = False + if selected_task_from_radio.lower() != st.session_state.selected_task: + st.session_state.selected_task = selected_task_from_radio.lower() # Ensure 'text' or 'image' + task_changed = True + # Automatic reset of history when task changes + st.session_state.chat_history = [{"role": "assistant", "content": st.session_state.starter_message, "type": "text"}] + st.info(f"⚠️ Tâche changée en '{st.session_state.selected_task}'. La conversation a été réinitialisée.") + # No need for explicit rerun here, the history change will trigger one. + + st.subheader(f"🧠 Modèle pour la tâche '{st.session_state.selected_task}'") + # Determine models available for the current task + current_task_models = MODELS_BY_TYPE.get(st.session_state.selected_task, []) + + selected_model_id_from_selectbox = None # Initialize to None + if not current_task_models: + st.warning(f"Aucun modèle disponible pour la tâche '{st.session_state.selected_task}'.") + st.session_state.selected_model_id = None # Ensure no model is selected + else: + model_options = {model['id']: model['name'] for model in current_task_models} + # Find the index of the currently selected model within the filtered list + # If the current model ID is not in the filtered list (e.g., switched task), default to the first model + current_model_index = 0 + if st.session_state.selected_model_id in model_options: + current_model_index = list(model_options.keys()).index(st.session_state.selected_model_id) + # If the current model ID is NOT in the new list of options, default to the first available model + elif current_task_models: # if there are models available for this task + st.session_state.selected_model_id = current_task_models[0]['id'] # Update session state to first model for new task + current_model_index = 0 # Reset index + else: # Should already be covered by 'if not current_task_models' + st.session_state.selected_model_id = None + + 
selected_model_id_from_selectbox = st.selectbox( + "Choisir le modèle :", + options=list(model_options.keys()), + format_func=lambda x: model_options[x], + index=current_model_index, + key="model_select_box", + help=f"Sélectionnez un modèle {st.session_state.selected_task.lower()}." ) - # Update session state task if radio button changed - task_changed = False - if selected_task_from_radio.lower() != st.session_state.selected_task: - st.session_state.selected_task = selected_task_from_radio.lower() # Ensure 'text' or 'image' - task_changed = True - # Automatic reset of history when task changes - st.session_state.chat_history = [{ROLE_KEY: ROLE_ASSISTANT, CONTENT_KEY: st.session_state.starter_message, TYPE_KEY: TYPE_TEXT}] - st.info(f"⚠️ Tâche changée en '{st.session_state.selected_task}'. La conversation a été réinitialisée.") - st.session_state._prev_task = st.session_state.selected_task # Update previous task for next rerun - - st.subheader(f"🧠 Modèle pour la tâche '{st.session_state.selected_task}'") - # Determine models available for the current task - current_task_models = MODELS_BY_TYPE.get(st.session_state.selected_task, []) - - selected_model_id_from_selectbox = None - if not current_task_models: - st.warning(f"Aucun modèle disponible pour la tâche '{st.session_state.selected_task}'.") - st.session_state.selected_model_id = None # Ensure no model is selected - else: - model_options_dict = {model['id']: model['name'] for model in current_task_models} - - # Determine the initial selection index for the selectbox - current_model_index = 0 - if st.session_state.selected_model_id in model_options_dict: - current_model_index = list(model_options_dict.keys()).index(st.session_state.selected_model_id) - else: # If the previously selected model is not in the list for the new task - st.session_state.selected_model_id = current_task_models[0]['id'] # Default to the first model for the new task - current_model_index = 0 - - selected_model_id_from_selectbox = 
st.selectbox( - "Choisir le modèle :", - options=list(model_options_dict.keys()), - format_func=lambda x: model_options_dict[x], - index=current_model_index, - key="model_select_box", - help=f"Sélectionnez un modèle {st.session_state.selected_task.lower()}." - ) - - # Handle model selection change - model_changed = False - if selected_model_id_from_selectbox is not None and selected_model_id_from_selectbox != st.session_state.selected_model_id: - st.session_state.selected_model_id = selected_model_id_from_selectbox - model_changed = True - st.session_state.chat_history = [{ROLE_KEY: ROLE_ASSISTANT, CONTENT_KEY: st.session_state.starter_message, TYPE_KEY: TYPE_TEXT}] - st.info(f"⚠️ Modèle changé en '{selected_model_id_from_selectbox}'. La conversation a été réinitialisée.") - st.session_state._prev_model_id = st.session_state.selected_model_id - - # If model (or task) changed, update parameters in session state to reflect the *new* model's defaults. - if model_changed or task_changed: - new_model_info = next((m for m in AVAILABLE_MODELS if m['id'] == st.session_state.selected_model_id), None) - if new_model_info: - new_model_params = new_model_info['params'] - if new_model_info['type'] == TASK_TEXT: - st.session_state.max_response_length = new_model_params.get("max_new_tokens", 200) - st.session_state.temperature = new_model_params.get("temperature", 0.6) - st.session_state.top_p = new_model_params.get("top_p", 0.9) - # Reset T2I specific params - st.session_state.num_inference_steps = None - st.session_state.guidance_scale = None - st.session_state.image_height = None - st.session_state.image_width = None - st.session_state.negative_prompt = "" - elif new_model_info['type'] == TASK_IMAGE: - st.session_state.num_inference_steps = new_model_params.get("num_inference_steps", 50) - st.session_state.guidance_scale = new_model_params.get("guidance_scale", 7.5) - st.session_state.image_height = new_model_params.get("height", 512) - st.session_state.image_width = 
new_model_params.get("width", 512) - st.session_state.negative_prompt = new_model_params.get("negative_prompt", "") - # Reset Text params - st.session_state.max_response_length = None - st.session_state.temperature = None - st.session_state.top_p = None - st.rerun() # Rerun immediately if task or model changed to update UI and parameters - - # Check required API Key for the selected model - if current_task_models and selected_model_id_from_selectbox: # Only check if a model is actually selected - current_model_info_check = next((m for m in AVAILABLE_MODELS if m['id'] == selected_model_id_from_selectbox), None) - if current_model_info_check: - current_model_provider = current_model_info_check['provider'] - if current_model_provider == PROVIDER_HF: - if not HUGGINGFACEHUB_API_TOKEN: - st.warning("❌ Le token Hugging Face est manquant (`HUGGINGFACEHUB_API_TOKEN`). Les modèles Hugging Face ne fonctionneront pas.") - elif current_model_provider == PROVIDER_GOOGLE: - if not GOOGLE_API_KEY: - st.warning("❌ La clé API Google est manquante (`GOOGLE_API_KEY`). Les modèles Google ne fonctionneront pas.") - else: - st.warning(f"Modèle sélectionné '{selected_model_id_from_selectbox}' introuvable dans la liste disponible. 
Veuillez sélectionner un modèle valide.") - - - # --- Dynamic Parameter Settings based on Task --- - st.subheader("⚙️ Paramètres") - if st.session_state.selected_task == TASK_TEXT: - with st.expander("Ajuster les Paramètres Texte", expanded=True): - st.session_state.max_response_length = st.number_input( - "Max New Tokens (longueur max réponse) :", - min_value=20, max_value=8192, - value=st.session_state.max_response_length if st.session_state.max_response_length is not None else 200, step=10, - key="max_new_tokens_input", - help="Longueur maximale de la réponse de l'IA (en jetons ou tokens).", - ) - st.session_state.temperature = st.slider( - "Temperature (créativité) :", - min_value=0.0, max_value=2.0, - value=st.session_state.temperature if st.session_state.temperature is not None else 0.6, step=0.01, - key="temperature_input", - help="Contrôle le caractère aléatoire des réponses. Plus élevé = plus créatif/imprévisible.", - ) - st.session_state.top_p = st.slider( - "Top-P (sampling) :", - min_value=0.01, max_value=1.0, - value=st.session_state.top_p if st.session_state.top_p is not None else 0.9, step=0.01, - key="top_p_input", - help="Contrôle la diversité en limitant les options de tokens. Plus bas = moins diversifié. 1.0 = désactivé.", - ) - - # Handle System and Starter messages here for text task - st.session_state.system_message = st.text_area( - "Message Système / Personnalité :", - value=st.session_state.system_message, - height=100, - key="system_message_input", - help="Décrivez le rôle ou le style que l'IA de texte doit adopter. 
Sa capacité à suivre cette consigne dépend du modèle.", - ) - st.session_state.starter_message = st.text_area( - "Message de Bienvenue de l'IA :", - value=st.session_state.starter_message, - height=100, - key="starter_message_input", - help="Le premier message que l'IA affichera au début d'une nouvelle conversation textuelle ou après un reset.", - ) - - elif st.session_state.selected_task == TASK_IMAGE: - with st.expander("Ajuster les Paramètres Image (T2I)", expanded=True): - st.session_state.num_inference_steps = st.number_input( - "Nombre d'étapes d'inférence :", - min_value=1, max_value=200, - value=st.session_state.num_inference_steps if st.session_state.num_inference_steps is not None else 50, step=5, - key="num_inference_steps_input", - help="Nombre d'étapes dans le processus de diffusion. Plus élevé = potentiellement meilleure qualité mais plus lent." - ) - st.session_state.guidance_scale = st.number_input( - "Échelle de guidage (CFG) :", - min_value=0.0, max_value=20.0, - value=st.session_state.guidance_scale if st.session_state.guidance_scale is not None else 7.5, step=0.5, - key="guidance_scale_input", - help="Dans quelle mesure l'IA doit suivre le prompt. Plus élevé = plus fidèle au prompt mais potentiellement moins créatif." - ) - st.session_state.image_height = st.number_input( - "Hauteur de l'image :", - min_value=128, max_value=1024, - value=st.session_state.image_height if st.session_state.image_height is not None else 512, step=64, - key="image_height_input", - help="Hauteur en pixels de l'image générée." - ) - st.session_state.image_width = st.number_input( - "Largeur de l'image :", - min_value=128, max_value=1024, - value=st.session_state.image_width if st.session_state.image_width is not None else 512, step=64, - key="image_width_input", - help="Largeur en pixels de l'image générée." 
- ) - st.session_state.negative_prompt = st.text_area( - "Prompt Négatif :", - value=st.session_state.negative_prompt, - height=100, - key="negative_prompt_input", - help="Ce que vous NE voulez PAS voir dans l'image." - ) - # UI for Avatars - st.subheader("👤 Interface Utilisateur") - with st.expander("Choisir les Avatars Cyberpunk", expanded=False): - st.markdown("*Sélection des avatars personnalisés :*") - col1_avatar, col2_avatar = st.columns(2) - with col1_avatar: - st.session_state.avatars[ROLE_ASSISTANT] = st.selectbox("Avatar IA", options=["🤖", "🎨", "✨", "💡", "🌟", "👩🏽‍⚕️", "👨🏿‍🎓", "⚙️"], index=0, key="avatar_ia_select") - with col2_avatar: - st.session_state.avatars[ROLE_USER] = st.selectbox("Avatar Utilisateur", options=["👤", "👩‍💻", "👨‍🎓", "❓", "💡", "🧑‍🔧", "👩🏽‍🔬", "🕵🏽", "🧑‍🚀"], index=0, key="avatar_user_select") - - # --- Conversation Management --- - st.subheader("🔄 Gestion de la Conversation") - # Reset button that applies *all* parameters (system/starter/model) and starts new conversation - if st.button("♻️ Appliquer Paramètres & Nouvelle Conversation", type="primary", help=f"Applique les paramètres actuels pour la tâche '{st.session_state.selected_task}' et démarre une nouvelle conversation ({'histoire textuelle' if st.session_state.selected_task == TASK_TEXT else 'génération d\'images'}) en effaçant l'historique."): - st.session_state.chat_history = [{ROLE_KEY: ROLE_ASSISTANT, CONTENT_KEY: st.session_state.starter_message, TYPE_KEY: TYPE_TEXT}] - st.info("Paramètres appliqués et conversation réinitialisée.") - st.session_state._reset_triggered = True # Set flag to prevent immediate re-processing - st.rerun() # Rerun immediately to refresh UI - - # Clear history button - simpler, just history - if st.button("🗑️ Effacer la Conversation Actuelle", help=f"Efface l'historique de conversation pour la tâche '{st.session_state.selected_task}' mais conserve les paramètres actuels."): - st.session_state.chat_history = [{ROLE_KEY: ROLE_ASSISTANT, CONTENT_KEY: 
st.session_state.starter_message, TYPE_KEY: TYPE_TEXT}] - st.info("Conversation actuelle effacée.") - st.session_state._reset_triggered = True # Set flag to prevent immediate re-processing - st.rerun() # Rerun immediately to show empty chat with starter message - - st.subheader("📥 Exporter la Conversation") - # Check if there's actual history (more than just the initial assistant message) - export_possible = len([msg for msg in st.session_state.chat_history if msg.get(ROLE_KEY) != ROLE_SYSTEM]) > 1 - - if "chat_history" in st.session_state and export_possible : - # timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") # Not used directly in file names here as Streamlit handles it - - col_export1, col_export2, col_export3 = st.columns(3) - with col_export1: - st.download_button( - label="TXT", - data=format_history_to_txt(st.session_state.chat_history), - file_name=f"kolachat_{st.session_state.selected_task}_conversation.txt", - mime="text/plain", - use_container_width=True - ) - with col_export2: - st.download_button( - label="JSON", - data=format_history_to_json(st.session_state.chat_history), - file_name=f"kolachat_{st.session_state.selected_task}_conversation.json", - mime="application/json", - use_container_width=True - ) - with col_export3: - st.download_button( - label="Markdown", - data=format_history_to_md(st.session_state.chat_history), - file_name=f"kolachat_{st.session_state.selected_task}_conversation.md", - mime="text/markdown", - use_container_width=True - ) + # Handle model selection change + model_changed = False + if selected_model_id_from_selectbox is not None and selected_model_id_from_selectbox != st.session_state.selected_model_id: + st.session_state.selected_model_id = selected_model_id_from_selectbox + model_changed = True + # Automatic reset of history when model changes + st.session_state.chat_history = [{"role": "assistant", "content": st.session_state.starter_message, "type": "text"}] + st.info(f"⚠️ Modèle changé en 
'{st.session_state.selected_model_id}'. La conversation a été réinitialisée.") + st.session_state._prev_model_id = st.session_state.selected_model_id # Store new model ID + + # If model changed (either directly or because task changed), update parameters in session state + # to reflect the *new* model's defaults. + if model_changed or task_changed: + # Find info for the newly selected model + new_model_info = next((m for m in AVAILABLE_MODELS if m['id'] == st.session_state.selected_model_id), None) + if new_model_info: + # Update parameters in session state with the new model's defaults + # This ensures the input widgets are populated correctly on the next rerun + # and the API call uses the right defaults before user interaction. + if new_model_info['type'] == 'text': + st.session_state.max_response_length = new_model_info['params'].get("max_new_tokens", 200) + st.session_state.temperature = new_model_info['params'].get("temperature", 0.6) + st.session_state.top_p = new_model_info['params'].get("top_p", 0.9) + # Reset T2I specific params (visually these inputs are hidden anyway) + st.session_state.num_inference_steps = None + st.session_state.guidance_scale = None + st.session_state.image_height = None + st.session_state.image_width = None + st.session_state.negative_prompt = "" + + elif new_model_info['type'] == 't2i': + st.session_state.num_inference_steps = new_model_info['params'].get("num_inference_steps", 50) + st.session_state.guidance_scale = new_model_info['params'].get("guidance_scale", 7.5) + st.session_state.image_height = new_model_info['params'].get("height", 512) + st.session_state.image_width = new_model_info['params'].get("width", 512) + st.session_state.negative_prompt = new_model_info['params'].get("negative_prompt", "") + # Reset Text specific params + st.session_state.max_response_length = None + st.session_state.temperature = None + st.session_state.top_p = None + + st.session_state._prev_model_id = st.session_state.selected_model_id # Update 
previous model ID + if task_changed or model_changed: + st.rerun() # Rerun immediately if task or model changed to update UI + + # Check required API Key for the selected model + if current_task_models and selected_model_id_from_selectbox: # Only check if a model is actually selected + current_model_info_check = next((m for m in AVAILABLE_MODELS if m['id'] == selected_model_id_from_selectbox), None) + if current_model_info_check: + current_model_provider = current_model_info_check['provider'] + if current_model_provider == 'huggingface': + if not HUGGINGFACEHUB_API_TOKEN: + st.warning("❌ Le token Hugging Face est manquant (`HUGGINGFACEHUB_API_TOKEN`). Les modèles Hugging Face ne fonctionneront pas.") + elif current_model_provider == 'google': + if not GOOGLE_API_KEY: + st.warning("❌ La clé API Google est manquante (`GOOGLE_API_KEY`). Les modèles Google ne fonctionneront pas.") else: - st.caption("Aucune conversation substantielle ou exportable pour le moment.") - - st.divider() - st.subheader("Crédits") - st.markdown(""" - **Auteur :** Sidoine K. YEBADOKPO - *Expert en Analyse de Données* + st.warning(f"Modèle sélectionné '{selected_model_id_from_selectbox}' introuvable dans la liste disponible. Veuillez sélectionner un modèle valide.") + + # --- Dynamic Parameter Settings based on Task --- + st.subheader("⚙️ Paramètres") + if st.session_state.selected_task == 'text': + with st.expander("Ajuster les Paramètres Texte", expanded=True): + # Link input widgets directly to session state parameters + st.session_state.max_response_length = st.number_input( + "Max New Tokens (longueur max réponse) :", + min_value=20, max_value=8192, + value=st.session_state.max_response_length if st.session_state.max_response_length is not None else 200, step=10, + key="max_new_tokens_input", + help="Longueur maximale de la réponse de l'IA (en jetons ou tokens).", + # on_change=lambda: setattr(st.session_state, '_params_changed', True) # Set flag on change? More complex. 
+ ) + st.session_state.temperature = st.slider( + "Temperature (créativité) :", + min_value=0.0, max_value=2.0, + value=st.session_state.temperature if st.session_state.temperature is not None else 0.6, step=0.01, + key="temperature_input", + help="Contrôle le caractère aléatoire des réponses. Plus élevé = plus créatif/imprévisible.", + # on_change=lambda: setattr(st.session_state, '_params_changed', True) + ) + st.session_state.top_p = st.slider( + "Top-P (sampling) :", + min_value=0.01, max_value=1.0, + value=st.session_state.top_p if st.session_state.top_p is not None else 0.9, step=0.01, + key="top_p_input", + help="Contrôle la diversité en limitant les options de tokens. Plus bas = moins diversifié. 1.0 = désactivé.", + # on_change=lambda: setattr(st.session_state, '_params_changed', True) + ) + + # Handle System and Starter messages here for text task + st.session_state.system_message = st.text_area( + "Message Système / Personnalité :", + value=st.session_state.system_message, + height=100, + key="system_message_input", + help="Décrivez le rôle ou le style que l'IA de texte doit adopter. Sa capacité à suivre cette consigne dépend du modèle.", + ) + st.session_state.starter_message = st.text_area( + "Message de Bienvenue de l'IA :", + value=st.session_state.starter_message, + height=100, + key="starter_message_input", + help="Le premier message que l'IA affichera au début d'une nouvelle conversation textuelle ou après un reset.", + ) + + elif st.session_state.selected_task == 'image': + with st.expander("Ajuster les Paramètres Image (T2I)", expanded=True): + # Link input widgets directly to session state T2I parameters + st.session_state.num_inference_steps = st.number_input( + "Nombre d'étapes d'inférence :", + min_value=1, max_value=200, + value=st.session_state.num_inference_steps if st.session_state.num_inference_steps is not None else 50, step=5, + key="num_inference_steps_input", + help="Nombre d'étapes dans le processus de diffusion. 
Plus élevé = potentiellement meilleure qualité mais plus lent." + ) + st.session_state.guidance_scale = st.number_input( + "Échelle de guidage (CFG) :", + min_value=0.0, max_value=20.0, + value=st.session_state.guidance_scale if st.session_state.guidance_scale is not None else 7.5, step=0.5, + key="guidance_scale_input", + help="Dans quelle mesure l'IA doit suivre le prompt. Plus élevé = plus fidèle au prompt mais potentiellement moins créatif." + ) + st.session_state.image_height = st.number_input( + "Hauteur de l'image :", + min_value=128, max_value=1024, + value=st.session_state.image_height if st.session_state.image_height is not None else 512, step=64, + key="image_height_input", + help="Hauteur en pixels de l'image générée." + ) + st.session_state.image_width = st.number_input( + "Largeur de l'image :", + min_value=128, max_value=1024, + value=st.session_state.image_width if st.session_state.image_width is not None else 512, step=64, + key="image_width_input", + help="Largeur en pixels de l'image générée." + ) + st.session_state.negative_prompt = st.text_area( + "Prompt Négatif :", + value=st.session_state.negative_prompt, + height=100, + key="negative_prompt_input", + help="Ce que vous NE voulez PAS voir dans l'image." + ) + # System and Starter messages are less relevant for simple T2I, hide them + # Or repurpose system prompt for image styles? Keep it hidden for now. + + # --- Automatic history reset on parameter change --- + # This ensures consistent parameters throughout a single generation task session. 
+ # Get current parameter values from session state + current_params = {} + if st.session_state.selected_task == 'text': + current_params = { + "max_new_tokens": st.session_state.max_response_length, + "temperature": st.session_state.temperature, + "top_p": st.session_state.top_p, + } + elif st.session_state.selected_task == 'image': + current_params = { + "num_inference_steps": st.session_state.num_inference_steps, + "guidance_scale": st.session_state.guidance_scale, + "height": st.session_state.image_height, # Note: keys match model params dict structure + "width": st.session_state.image_width, + "negative_prompt": st.session_state.negative_prompt, + } - 📧 [syebadokpo@gmail.com](mailto:syebadokpo@gmail.com) - 📞 +229 96 91 13 46 - """) + # Find the parameters stored with the *currently selected model definition* + current_selected_model_info = next((m for m in AVAILABLE_MODELS if m['id'] == st.session_state.selected_model_id), None) + default_params_for_current_model = current_selected_model_info.get('params', {}) if current_selected_model_info else {} + + # Check if current session state parameters deviate from the *initial default* parameters of the *currently selected model* + # This is a proxy for "have parameters been changed by the user since the last reset/model switch?" + # More accurate would be to compare to a _previous_ params state, but linking input widgets to state simplifies things. + # The simplest approach is: if any parameter widget value is different from what's in session state + # (which shouldn't happen due to direct linking) OR if a parameter value is detected to be different + # from the model's default *during a conversation*, trigger a reset. + # Let's simplify: rely on the "Appliquer" button for explicit reset with parameter changes. + # The logic `model_changed or task_changed` already resets history. 
+ # A manual parameter change via slider/input *without* hitting "Appliquer" or changing model/task + does *not* auto-reset history to avoid annoyance, but the *new* parameters from session state + will be used for the next generation call. The "Appliquer" button is there if the user wants to guarantee a clean start with new params. + + # Using a simple flag linked to the explicit reset buttons now. + reset_just_triggered = st.session_state.get('_reset_triggered', False) + if reset_just_triggered: + st.session_state._reset_triggered = False # Clear the flag + + st.subheader("👤 Interface Utilisateur") + with st.expander("Choisir les Avatars Cyberpunk", expanded=False): + st.markdown("*Sélection des avatars personnalisés :*") + col1_avatar, col2_avatar = st.columns(2) + with col1_avatar: + # Options include those suitable for both text output and image output if needed + st.session_state.avatars["assistant"] = st.selectbox("Avatar IA", options=["🤖", "🎨", "✨", "💡", "🌟", "👩🏽‍⚕️", "👨🏿‍🎓", "⚙️"], index=0, key="avatar_ia_select") + with col2_avatar: + st.session_state.avatars["user"] = st.selectbox("Avatar Utilisateur", options=["👤", "👩‍💻", "👨‍🎓", "❓", "💡", "🧑‍🔧", "👩🏽‍🔬", "🕵🏽", "🧑‍🚀"], index=0, key="avatar_user_select") + + st.subheader("🔄 Gestion de la Conversation") + # Reset button that applies *all* parameters (system/starter/model) and starts new conversation + # This should happen after task/model selection logic + # FIX: no backslash inside the f-string expression (backslashes there are a SyntaxError before Python 3.12); a typographic apostrophe avoids any escaping. + if st.button("♻️ Appliquer Paramètres & Nouvelle Conversation", type="primary", help=f"Applique les paramètres actuels pour la tâche '{st.session_state.selected_task}' et démarre une nouvelle conversation ({'histoire textuelle' if st.session_state.selected_task == 'text' else 'génération d’images'}) en effaçant l'historique."): + # Parameters are already updated in session state via key linking + # Reset history. Set type to "text" for the initial message content type.
+ st.session_state.chat_history = [{"role": "assistant", "content": st.session_state.starter_message, "type": "text"}] + st.info("Paramètres appliqués et conversation réinitialisée.") + st.session_state._reset_triggered = True + st.rerun() # Rerun immediately to refresh UI + + # Clear history button - simpler, just history + if st.button("🗑️ Effacer la Conversation Actuelle", help=f"Efface l'historique de conversation pour la tâche '{st.session_state.selected_task}' mais conserve les paramètres actuels."): + # Reset history. Set type to "text" for the initial message content type. + st.session_state.chat_history = [{"role": "assistant", "content": st.session_state.starter_message, "type": "text"}] + st.info("Conversation actuelle effacée.") + st.session_state._reset_triggered = True + st.rerun() # Rerun immediately to show empty chat with starter message + + st.subheader("📥 Exporter la Conversation") + # Check if there's actual history (more than just the initial assistant message) + export_possible = len([msg for msg in st.session_state.chat_history if msg.get("role") != "system"]) > 1 + + if "chat_history" in st.session_state and export_possible : + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + col_export1, col_export2, col_export3 = st.columns(3) + with col_export1: + st.download_button( + label="TXT", + data=format_history_to_txt(st.session_state.chat_history), + file_name=f"kolachat_{st.session_state.selected_task}_conversation_{timestamp}.txt", + mime="text/plain", + use_container_width=True + ) + with col_export2: + st.download_button( + label="JSON", + data=format_history_to_json(st.session_state.chat_history), + file_name=f"kolachat_{st.session_state.selected_task}_conversation_{timestamp}.json", + mime="application/json", + use_container_width=True + ) + with col_export3: + st.download_button( + label="Markdown", + data=format_history_to_md(st.session_state.chat_history), + 
file_name=f"kolachat_{st.session_state.selected_task}_conversation_{timestamp}.md", + mime="text/markdown", + use_container_width=True + ) + else: + st.caption("Aucune conversation substantielle ou exportable pour le moment.") + + st.divider() + st.subheader("Crédits") + st.markdown(""" + **Auteur :** Sidoine K. YEBADOKPO + *Expert en Analyse de Données* + + 📧 [syebadokpo@gmail.com](mailto:syebadokpo@gmail.com) + 📞 +229 96 91 13 46 + """) # ----------------------------------------------------------------------------- # Main Chat / Input / Generation Logic # ----------------------------------------------------------------------------- +# Use a container for the chat history display area with a fixed height and border chat_interface_container = st.container(height=600, border=True) with chat_interface_container: - # Display existing messages from chat_history. - for message in st.session_state.chat_history: - if message[ROLE_KEY] == ROLE_SYSTEM: - continue # Skip system messages - - avatar_type = st.session_state.avatars.get(message[ROLE_KEY], "❓") - - with st.chat_message(message[ROLE_KEY], avatar=avatar_type): - message_type = message.get(TYPE_KEY, TYPE_TEXT) # Default to text type - if message_type == TYPE_TEXT: - st.markdown(message.get(CONTENT_KEY, "")) - elif message_type == TYPE_T2I and IMAGE_DATA_KEY in message: - try: - image = Image.open(BytesIO(message[IMAGE_DATA_KEY])) - st.image(image, caption=f"Prompt: {message.get(PROMPT_KEY, 'Pas de prompt enregistré')}") - except Exception as e: - st.error(f"Erreur lors de l'affichage de l'image : {e}") - st.markdown(f"Prompt: {message.get(PROMPT_KEY, 'Pas de prompt enregistré')}") - elif message_type == TYPE_T2I_PROMPT: # User's original prompt for image in history - st.markdown(f"_Prompt Image :_ {message.get(CONTENT_KEY, '')}") - - + # Display existing messages from chat_history. + # System messages are never displayed. Messages can be text or image. 
+ for message in st.session_state.chat_history: + if message["role"] == "system": + continue # Skip system messages + + avatar_type = st.session_state.avatars.get(message["role"], "❓") # Default avatar + + with st.chat_message(message["role"], avatar=avatar_type): + message_type = message.get("type", "text") # Default to text type + if message_type == "text": + st.markdown(message.get("content", "")) # Display text content + elif message_type == "t2i" and "image_data" in message: + # Display generated image + try: + image = Image.open(BytesIO(message["image_data"])) + st.image(image, caption=f"Prompt: {message.get('prompt', 'Pas de prompt enregistré')}") + except Exception as e: + st.error(f"Erreur lors de l'affichage de l'image : {e}") + st.markdown(f"Prompt: {message.get('prompt', 'Pas de prompt enregistré')}") # Show prompt even if image fails + # Add other message types if needed in the future (e.g., multimodal) + +# Area for the chat input box input_container = st.container() with input_container: - user_input = st.chat_input(placeholder=f"Entrez votre prompt pour la tâche '{st.session_state.selected_task}'...") + user_input = st.chat_input(placeholder=f"Entrez votre prompt pour la tâche '{st.session_state.selected_task}'...") # --- Logic to handle user input and trigger generation --- # This part runs only when user_input is NOT None (i.e., after the user hits Enter) -# AND if a reset wasn't just triggered (to prevent immediate re-processing after clear/apply buttons) -if user_input and not st.session_state._reset_triggered: - # 1. 
Append the user's message to the chat history - user_message_entry = {ROLE_KEY: ROLE_USER, CONTENT_KEY: user_input} - if st.session_state.selected_task == TASK_IMAGE: - user_message_entry[TYPE_KEY] = TYPE_T2I_PROMPT # User input for image generation task - else: - user_message_entry[TYPE_KEY] = TYPE_TEXT - - st.session_state.chat_history.append(user_message_entry) - - # Trigger a rerun to immediately show the user message and then process - st.rerun() +if user_input: + # 1. Append the user's message to the chat history + # Include the type based on the current task + user_message_entry = {"role": "user", "content": user_input} + # For T2I task, the user input is the prompt, so record it specifically + if st.session_state.selected_task == 'image': + user_message_entry["type"] = "t2i_prompt" # Indicate user input for T2I task type + else: # Default to text type + user_message_entry["type"] = "text" + + st.session_state.chat_history.append(user_message_entry) + + # --- Trigger a rerun to immediately show the user message --- + # This restarts the script, redrawing the chat (now including the user's message). + # The code below the chat_input block will then execute on this rerun. + st.rerun() # --- This block executes on reruns, check if processing is needed --- -# Only proceed if the last message was from the user AND a reset wasn't just triggered. -if st.session_state.chat_history and st.session_state.chat_history[-1][ROLE_KEY] == ROLE_USER and not st.session_state._reset_triggered: - st.session_state._reset_triggered = False # Reset the flag after processing a new user turn if it was set by a button - - selected_model_info = next((m for m in AVAILABLE_MODELS if m['id'] == st.session_state.selected_model_id), None) - if not selected_model_info: - error_message = {ROLE_KEY: ROLE_ASSISTANT, CONTENT_KEY: f"Erreur: Modèle '{st.session_state.selected_model_id}' introuvable dans la liste. 
Veuillez sélectionner un modèle valide.", TYPE_KEY: TYPE_TEXT} - st.session_state.chat_history.append(error_message) - st.rerun() - else: - current_model_display_name = selected_model_info['name'] - current_model_type = selected_model_info['type'] - - assistant_response_entry = None - - with chat_interface_container: - assistant_avatar = st.session_state.avatars.get(ROLE_ASSISTANT, "❓") - with st.chat_message(ROLE_ASSISTANT, avatar=assistant_avatar): - with st.spinner(f"KolaChatBot utilise **{current_model_display_name}** pour générer... 🤔"): - if st.session_state.selected_task == TASK_TEXT: - assistant_response_entry = get_text_response( - selected_model_id=st.session_state.selected_model_id, - system_prompt=st.session_state.system_message, - full_chat_history=st.session_state.chat_history # Pass full history for context - ) - if assistant_response_entry.get(TYPE_KEY) == TYPE_TEXT: - st.markdown(assistant_response_entry.get(CONTENT_KEY, "Erreur: Réponse texte vide.")) - else: - st.error(f"Erreur inattendue lors de la génération de texte: {assistant_response_entry.get(CONTENT_KEY, 'Format de réponse incorrect.')}") - - elif st.session_state.selected_task == TASK_IMAGE: - image_prompt = st.session_state.chat_history[-1][CONTENT_KEY] - assistant_response_entry = get_image_response( - selected_model_id=st.session_state.selected_model_id, - user_prompt=image_prompt - ) - if assistant_response_entry.get(TYPE_KEY) == TYPE_T2I and IMAGE_DATA_KEY in assistant_response_entry: - try: - image = Image.open(BytesIO(assistant_response_entry[IMAGE_DATA_KEY])) - st.image(image, caption=f"Prompt: {assistant_response_entry.get(PROMPT_KEY, 'N/A')}") - except Exception as e: - st.error(f"Erreur lors de l'affichage de l'image générée: {e}") - st.markdown(f"Prompt: {assistant_response_entry.get(PROMPT_KEY, 'N/A')} (Affichage image échoué)") - elif assistant_response_entry.get(TYPE_KEY) == TYPE_TEXT: - st.error(f"Échec de la génération d'image: {assistant_response_entry.get(CONTENT_KEY, 
'Raison inconnue.')}") - st.markdown(f"Prompt: {assistant_response_entry.get(PROMPT_KEY, 'N/A')}") - - if assistant_response_entry is not None: - st.session_state.chat_history.append(assistant_response_entry) \ No newline at end of file +# Check if the last message was from the user AND we haven't added an assistant response yet. +# This avoids regenerating if the last message is already from the assistant. +if st.session_state.chat_history and st.session_state.chat_history[-1]["role"] == "user": + + # Find info about the currently selected model + selected_model_info = next((m for m in AVAILABLE_MODELS if m['id'] == st.session_state.selected_model_id), None) + if not selected_model_info: + # Display error in chat if model not found (should be caught earlier, but safety) + error_message = {"role": "assistant", "content": f"Erreur: Modèle '{st.session_state.selected_model_id}' introuvable dans la liste. Veuillez sélectionner un modèle valide.", "type": "text"} + st.session_state.chat_history.append(error_message) + st.rerun() # Rerun to display error + else: + current_model_display_name = selected_model_info['name'] + current_model_type = selected_model_info['type'] + + # --- Generate the response based on the task --- + assistant_response_entry = None + + with chat_interface_container: # Place spinner within the chat container where response will appear + assistant_avatar = st.session_state.avatars.get("assistant", "❓") # Get assistant avatar + with st.chat_message("assistant", avatar=assistant_avatar): + with st.spinner(f"KolaChatBot utilise {current_model_display_name} pour générer... 
🤔"): + + if st.session_state.selected_task == 'text': + # Call text generation function + assistant_response_entry = get_text_response( + selected_model_id=st.session_state.selected_model_id, + user_text=st.session_state.chat_history[-1]["content"], + system_prompt=st.session_state.system_message, + full_chat_history=st.session_state.chat_history # Pass full history for context + ) + # Display text directly in the spinner block + if assistant_response_entry.get("type") == "text": + st.markdown(assistant_response_entry.get("content", "Erreur: Réponse texte vide.")) + else: # Handle case where get_text_response might return error as non-text (e.g., internal error) + st.error(f"Erreur inattendue lors de la génération de texte: {assistant_response_entry.get('content', 'Format de réponse incorrect.')}") + + elif st.session_state.selected_task == 'image': + # Call image generation function + # User input for T2I is simply the prompt string + image_prompt = st.session_state.chat_history[-1]["content"] + assistant_response_entry = get_image_response( + selected_model_id=st.session_state.selected_model_id, + user_prompt=image_prompt + ) + # Display image or error message based on response entry + if assistant_response_entry.get("type") == "t2i" and "image_data" in assistant_response_entry: + try: + image = Image.open(BytesIO(assistant_response_entry["image_data"])) + st.image(image, caption=f"Prompt: {assistant_response_entry.get('prompt', 'N/A')}") + except Exception as e: + st.error(f"Erreur lors de l'affichage de l'image générée: {e}") + st.markdown(f"Prompt: {assistant_response_entry.get('prompt', 'N/A')} (Affichage image échoué)") + elif assistant_response_entry.get("type") == "text": # Error message returned as text + st.error(f"Échec de la génération d'image: {assistant_response_entry.get('content', 'Raison inconnue.')}") + st.markdown(f"Prompt: {assistant_response_entry.get('prompt', 'N/A')}") # Display prompt if error + + # --- Append the generated response entry 
to chat history --- + # Ensure we have received a valid response entry object + if assistant_response_entry is not None: + st.session_state.chat_history.append(assistant_response_entry) + # No need for another st.rerun() after appending and displaying, + # the UI is already updated. \ No newline at end of file