Spaces:
Running
Running
| ```text | |
| File: app.py | |
| ``````python | |
| import gradio as gr | |
| from huggingface_hub import InferenceClient as HubInferenceClient # Renamed to avoid conflict | |
| import os | |
| import json | |
| import base64 | |
| from PIL import Image | |
| import io | |
| # Smolagents imports | |
| from smolagents import CodeAgent, Tool, LiteLLMModel, OpenAIServerModel, TransformersModel, InferenceClientModel as SmolInferenceClientModel | |
| from smolagents.gradio_ui import stream_to_gradio | |
# Default Hugging Face token, read once at import time. It is used whenever the
# user does not supply their own key in the UI.
ACCESS_TOKEN = os.getenv("HF_TOKEN")
if ACCESS_TOKEN:
    print("Access token loaded.")
else:
    # Report the missing variable instead of claiming the token was loaded.
    print("Warning: HF_TOKEN is not set; no default access token is available.")
| # Function to encode image to base64 | |
def encode_image(image_path):
    """Return an image as a base64-encoded JPEG string.

    Args:
        image_path: A filesystem path (anything ``Image.open`` accepts) or an
            already-loaded ``PIL.Image.Image``.

    Returns:
        The base64 text of the JPEG bytes, or ``None`` when no input was given
        or when loading/encoding failed (the error is printed, not raised).
    """
    if not image_path:
        print("No image path provided")
        return None

    # Modes assumed to be directly writable as JPEG by Pillow; every other mode
    # (RGBA, P, LA, I, F, ...) is converted to RGB first so ``save`` does not
    # fail. NOTE(review): list taken from Pillow's JPEG plugin - confirm.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    try:
        print(f"Encoding image from path: {image_path}")
        opened_here = not isinstance(image_path, Image.Image)
        source = Image.open(image_path) if opened_here else image_path
        try:
            if source.mode in jpeg_modes:
                image = source
            else:
                image = source.convert("RGB")
            buffered = io.BytesIO()
            image.save(buffered, format="JPEG")
        finally:
            # Only close handles this function opened; a caller-supplied
            # image stays usable.
            if opened_here:
                source.close()
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        print("Image encoded successfully")
        return img_str
    except Exception as e:
        # Best-effort helper: callers treat ``None`` as "no image".
        print(f"Error encoding image: {e}")
        return None
| # --- Smolagents Tool Definition --- | |
# Tools handed to every agent. The list stays empty when the Space-backed
# image generator cannot be loaded, so the app still starts without it.
SMOLAGENTS_TOOLS = []
try:
    image_generation_tool = Tool.from_space(
        "black-forest-labs/FLUX.1-schnell",
        name="image_generator",
        description="Generates an image from a textual prompt. Use this tool if the user asks to generate, create, or draw an image.",
        token=ACCESS_TOKEN,  # forwarded in case the Space is private or rate limited
    )
    print("Image generation tool loaded successfully.")
    SMOLAGENTS_TOOLS = [image_generation_tool]
except Exception as load_error:
    print(f"Error loading image generation tool: {load_error}. Proceeding without it.")
    SMOLAGENTS_TOOLS = []
def respond(
    message,
    image_files,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    frequency_penalty,
    seed,
    provider,
    custom_api_key,
    custom_model,
    model_search_term,
    selected_model
):
    """Run a smolagents ``CodeAgent`` on the user's task and stream its reply.

    This is a generator. Each yielded string is the text the chatbot should
    currently display for the bot turn: either the reply accumulated so far,
    a full replacement message, or a markdown reference to a produced file.

    Args:
        message: The user's text for this turn.
        image_files: Paths of images uploaded with the message (may be empty
            or ``None``).
        history: Previous chat turns. Not forwarded to the agent; every call
            starts the agent with a fresh memory.
        system_message: Optional instructions prepended to the task text.
        max_tokens, temperature, top_p, frequency_penalty: Sampling settings
            forwarded to the model wrapper.
        seed: Sampling seed; ``-1`` means "do not send a seed".
        provider: Inference provider name. ``"hf-inference"`` or an empty
            value uses ``SmolInferenceClientModel``; anything else uses
            ``LiteLLMModel``.
        custom_api_key: User supplied key; falls back to ``ACCESS_TOKEN``
            when blank.
        custom_model: Model id typed by the user; overrides ``selected_model``
            when non-blank.
        model_search_term: Only logged.
        selected_model: Model id chosen from the featured list.

    Yields:
        str: Progressive response text. Errors raised while the agent runs
        are appended to the text instead of being raised.
    """
    # A missing key/model (None) is treated like an empty string so the
    # ``.strip()`` calls below cannot fail.
    api_key = (custom_api_key or "").strip()
    custom_model_name = (custom_model or "").strip()

    print(f"Received message: {message}")
    print(f"Received {len(image_files) if image_files else 0} images")
    print(f"System message: {system_message}")
    print(f"Max tokens: {max_tokens}, Temperature: {temperature}, Top-P: {top_p}")
    print(f"Frequency Penalty: {frequency_penalty}, Seed: {seed}")
    print(f"Selected provider: {provider}")
    print(f"Custom API Key provided: {bool(api_key)}")
    print(f"Selected model (custom_model): {custom_model}")
    print(f"Model search term: {model_search_term}")
    print(f"Selected model from radio: {selected_model}")

    # Pick the credential: the user's own key wins over the environment token.
    if api_key:
        token_to_use = api_key
        print("USING CUSTOM API KEY: BYOK token provided by user is being used for authentication")
    else:
        token_to_use = ACCESS_TOKEN
        print("USING DEFAULT API KEY: Environment variable HF_TOKEN is being used for authentication")

    # A typed-in model id takes priority over the featured selection.
    model_to_use = custom_model_name if custom_model_name else selected_model
    print(f"Model selected for LLM: {model_to_use}")

    # Only ``max_tokens`` is sent. Both wrappers used below forward these
    # values to chat-completion style APIs; ``max_new_tokens`` is not a
    # chat-completion parameter and is therefore no longer included.
    # NOTE(review): confirm against the installed smolagents/huggingface_hub.
    llm_parameters = {
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "frequency_penalty": frequency_penalty,
    }
    if seed != -1:
        llm_parameters["seed"] = seed

    if provider == "hf-inference" or provider is None or provider == "":
        smol_model = SmolInferenceClientModel(
            model_id=model_to_use,
            token=token_to_use,
            provider=provider if provider else None,
            **llm_parameters
        )
        print(f"Using SmolInferenceClientModel for LLM with provider: {provider or 'default'}")
    else:
        # Other providers are assumed to be reachable through LiteLLM using a
        # "provider/model" id - presumably not true for every entry; verify.
        smol_model = LiteLLMModel(
            model_id=f"{provider}/{model_to_use}" if provider else model_to_use,
            api_key=token_to_use,
            **llm_parameters
        )
        print(f"Using LiteLLMModel for LLM with provider: {provider}")

    # The UI system prompt is folded into the task text given to the agent.
    agent_task = message
    if system_message and system_message.strip():
        agent_task = f"System Instructions: {system_message}\n\nUser Task: {message}"

    print(f"Initializing CodeAgent with model: {model_to_use}")
    agent = CodeAgent(
        tools=SMOLAGENTS_TOOLS,
        model=smol_model,
        stream_outputs=True
    )
    print("CodeAgent initialized.")

    # Load uploaded images as PIL objects; unreadable files are skipped.
    agent_images = []
    for img_path in image_files or []:
        if not img_path:
            continue
        try:
            agent_images.append(Image.open(img_path))
        except Exception as e:
            print(f"Error opening image for agent: {e}")
    print(f"Prepared {len(agent_images)} images for the agent.")

    response_text = ""
    print(f"Running agent with task: {agent_task}")
    try:
        print("Streaming response from agent...")
        for content_chunk in stream_to_gradio(
            agent,
            task=agent_task,
            task_images=agent_images if agent_images else None,
            reset_agent_memory=True  # each call is a fresh conversation for the agent
        ):
            if isinstance(content_chunk, str):
                # Text delta: extend the running reply.
                response_text += content_chunk
                yield response_text
            elif hasattr(content_chunk, 'content'):
                content = content_chunk.content
                if isinstance(content, dict) and 'path' in content:
                    # File output (e.g. a generated image). The original
                    # yielded an empty string here, which hid the result.
                    # NOTE(review): a markdown image link is an assumed
                    # reconstruction - confirm the chatbot renders this path.
                    yield f"![Generated output]({content['path']})"
                elif isinstance(content, str):
                    # A complete message replaces what was streamed so far.
                    response_text = content
                    yield response_text
            else:
                print(f"Unexpected chunk type from stream_to_gradio: {type(content_chunk)}")
                yield str(content_chunk)
        print("\nCompleted response generation from agent.")
    except Exception as e:
        # Show the failure in the chat rather than crashing the UI handler.
        print(f"Error during agent execution: {e}")
        response_text += f"\nError: {str(e)}"
        yield response_text
| # Function to validate provider selection based on BYOK | |
def validate_provider(api_key, provider):
    """Force the provider back to ``"hf-inference"`` when no custom key is set.

    Returns a Gradio update carrying either the requested provider (a key was
    entered, or the provider is already ``"hf-inference"``) or
    ``"hf-inference"``.
    """
    has_custom_key = bool(api_key.strip())
    if has_custom_key or provider == "hf-inference":
        allowed_provider = provider
    else:
        allowed_provider = "hf-inference"
    return gr.update(value=allowed_provider)
| # GRADIO UI | |
# NOTE(review): the original indentation was lost; the nesting below (what sits
# inside the accordion, row and columns) is a best reconstruction - confirm.
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    # Conversation display.
    chatbot = gr.Chatbot(
        height=600,
        show_copy_button=True,
        placeholder="Select a model and begin chatting. Now supports multiple inference providers, multimodal inputs, and image generation tool.",
        layout="panel",
        show_share_button=True,
    )
    print("Chatbot interface created.")

    # Single input widget carrying both the text and any uploaded images.
    msg = gr.MultimodalTextbox(
        placeholder="Type a message or upload images... (e.g., 'generate an image of a cat playing chess')",
        show_label=False,
        container=False,
        scale=12,
        file_types=["image"],
        file_count="multiple",
        sources=["upload"],
    )

    with gr.Accordion("Settings", open=False):
        # Instructions that are prepended to every task.
        system_message_box = gr.Textbox(
            value="You are a helpful AI assistant that can understand images and text. If asked to generate an image, use the available image_generator tool.",
            placeholder="You are a helpful assistant.",
            label="System Prompt",
        )

        # Sampling controls, laid out in two columns.
        with gr.Row():
            with gr.Column():
                max_tokens_slider = gr.Slider(
                    minimum=1,
                    maximum=4096,
                    value=1024,
                    step=1,
                    label="Max tokens",
                )
                temperature_slider = gr.Slider(
                    minimum=0.1,
                    maximum=4.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                )
                top_p_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-P",
                )
            with gr.Column():
                frequency_penalty_slider = gr.Slider(
                    minimum=-2.0,
                    maximum=2.0,
                    value=0.0,
                    step=0.1,
                    label="Frequency Penalty",
                )
                seed_slider = gr.Slider(
                    minimum=-1,
                    maximum=65535,
                    value=-1,
                    step=1,
                    label="Seed (-1 for random)",
                )

        # Inference back ends the user may choose from.
        providers_list = [
            "hf-inference",
            "cerebras",
            "together",
            "sambanova",
            "novita",
            "cohere",
            "fireworks-ai",
            "hyperbolic",
            "nebius",
        ]
        provider_radio = gr.Radio(
            choices=providers_list,
            value="hf-inference",
            label="Inference Provider",
        )

        # Optional user-supplied key; masked in the UI.
        byok_textbox = gr.Textbox(
            value="",
            label="BYOK (Bring Your Own Key)",
            info="Enter a custom Hugging Face API key here. When empty, only 'hf-inference' provider can be used. For other providers, this key will be used as their respective API key.",
            placeholder="Enter your API token",
            type="password",
        )

        # Free-form model id that overrides the featured selection.
        custom_model_box = gr.Textbox(
            value="",
            label="Custom Model",
            info="(Optional) Provide a custom Hugging Face model path (e.g., 'meta-llama/Llama-3.3-70B-Instruct') or a model name compatible with the selected provider. Overrides any selected featured model.",
            placeholder="meta-llama/Llama-3.3-70B-Instruct",
        )

        # Text filter applied to the featured model list.
        model_search_box = gr.Textbox(
            label="Filter Models",
            placeholder="Search for a featured model...",
            lines=1,
        )

        # Featured model ids shown as radio choices.
        models_list = [
            "meta-llama/Llama-3.2-11B-Vision-Instruct",
            "meta-llama/Llama-3.3-70B-Instruct",
            "meta-llama/Llama-3.1-70B-Instruct",
            "meta-llama/Llama-3.0-70B-Instruct",
            "meta-llama/Llama-3.2-3B-Instruct",
            "meta-llama/Llama-3.2-1B-Instruct",
            "meta-llama/Llama-3.1-8B-Instruct",
            "NousResearch/Hermes-3-Llama-3.1-8B",
            "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
            "mistralai/Mistral-Nemo-Instruct-2407",
            "mistralai/Mixtral-8x7B-Instruct-v0.1",
            "mistralai/Mistral-7B-Instruct-v0.3",
            "mistralai/Mistral-7B-Instruct-v0.2",
            "Qwen/Qwen3-235B-A22B",
            "Qwen/Qwen3-32B",
            "Qwen/Qwen2.5-72B-Instruct",
            "Qwen/Qwen2.5-3B-Instruct",
            "Qwen/Qwen2.5-0.5B-Instruct",
            "Qwen/QwQ-32B",
            "Qwen/Qwen2.5-Coder-32B-Instruct",
            "microsoft/Phi-3.5-mini-instruct",
            "microsoft/Phi-3-mini-128k-instruct",
            "microsoft/Phi-3-mini-4k-instruct",
        ]
        featured_model_radio = gr.Radio(
            label="Select a model below (or specify a custom one above)",
            choices=models_list,
            value="meta-llama/Llama-3.2-11B-Vision-Instruct",
            interactive=True,
        )
        gr.Markdown("[View all Text-to-Text models](https://huggingface.co/models?inference_provider=all&pipeline_tag=text-generation&sort=trending) | [View all multimodal models](https://huggingface.co/models?inference_provider=all&pipeline_tag=image-text-to-text&sort=trending)")

    # Per-session state holder; not referenced by the handlers visible here.
    chat_history = gr.State([])
| # Function to filter models | |
def filter_models(search_term):
    """Return a Gradio update limiting the choices to models matching the term.

    Matching is a case-insensitive substring test against ``models_list``.
    """
    print(f"Filtering models with search term: {search_term}")
    filtered = []
    for model_name in models_list:
        if search_term.lower() in model_name.lower():
            filtered.append(model_name)
    print(f"Filtered models: {filtered}")
    return gr.update(choices=filtered)
| # Function to set custom model from radio (actually, sets the selected_model which is then overridden by custom_model_box if filled) | |
def set_selected_model_from_radio(selected):
    """Log the featured model that was picked and hand it back unchanged."""
    print(f"Featured model selected: {selected}")
    return selected
| # Function for the chat interface | |
def user(user_message_input, history):
    """Append the user's turn to the chat history for display.

    Args:
        user_message_input: Mapping with a ``"text"`` string and a ``"files"``
            list, as read below. Missing or ``None`` entries are treated as
            empty.
        history: List of ``[user_message, bot_message]`` pairs. It is modified
            in place.

    Returns:
        The same ``history`` list. When the input has neither text nor files
        it is returned untouched; otherwise a new ``[display_text, None]``
        pair is appended, where the display text is the stripped message plus
        a note on how many images were uploaded.
    """
    print(f"User input received: {user_message_input}")
    text_content = (user_message_input.get("text") or "").strip()
    files = user_message_input.get("files") or []

    if not text_content and not files:
        print("Empty message, skipping history update.")
        return history

    # Only a textual note about the uploads is stored for display; the file
    # paths themselves are not recorded in the history.
    display_message = text_content
    if files:
        display_message += f" ({len(files)} image(s) uploaded)"

    history.append([display_message, None])
    return history
| # Define bot response function | |
def bot(history, system_msg, max_tokens_val, temperature_val, top_p_val, freq_penalty_val, seed_val, provider_val, api_key_val, custom_model_val, search_term_val, selected_model_val, request: gr.Request, msg_val=None):
    """Stream the bot's reply into the last turn of ``history``.

    This is a generator that yields the updated ``history`` after every chunk
    produced by ``respond``.

    Args:
        history: List of ``[user_message, bot_message]`` pairs whose last
            entry is the turn to answer. It is modified in place.
        system_msg, max_tokens_val, temperature_val, top_p_val,
        freq_penalty_val, seed_val, provider_val, api_key_val,
        custom_model_val, search_term_val, selected_model_val: Forwarded to
            ``respond`` under the matching parameter names.
        request: Unused here.
        msg_val: Optional raw input mapping (``"text"`` / ``"files"``). When
            given, it is the source of the message and image paths. When
            omitted, the text is recovered from the displayed user message
            and no images are forwarded, because the history stores only a
            count of uploads, not their paths.

    Yields:
        The ``history`` list with the last bot message updated.
    """
    # ``re`` is not among the imports visible at the top of this file.
    import re

    if not history or not history[-1][0]:
        # Nothing to answer.
        yield history
        return

    if msg_val:
        raw_text_input = msg_val.get("text") or ""
        raw_file_inputs = msg_val.get("files") or []
    else:
        # Previously this read ``history.pop('_msg_val_temp_')``, which fails
        # on a list. Fall back to the display string and strip the
        # " (N image(s) uploaded)" note that ``user`` appends.
        last_user_turn_display = history[-1][0]
        raw_text_input = ""
        if isinstance(last_user_turn_display, str):
            img_count_match = re.search(r" \((\d+) image\(s\) uploaded\)$", last_user_turn_display)
            if img_count_match:
                raw_text_input = last_user_turn_display[:img_count_match.start()]
            else:
                raw_text_input = last_user_turn_display
        raw_file_inputs = []

    # Turns before the current one.
    history_for_api = list(history[:-1])

    history[-1][1] = ""  # clear the placeholder for the bot response
    for chunk in respond(
        message=raw_text_input,
        image_files=raw_file_inputs,
        history=history_for_api,
        system_message=system_msg,
        max_tokens=max_tokens_val,
        temperature=temperature_val,
        top_p=top_p_val,
        frequency_penalty=freq_penalty_val,
        seed=seed_val,
        provider=provider_val,
        custom_api_key=api_key_val,
        custom_model=custom_model_val,
        selected_model=selected_model_val,
        model_search_term=search_term_val
    ):
        history[-1][1] = chunk
        yield history
| # Event handlers | |
| # We need to pass the raw `msg` value to the `bot` function. | |
| # We can temporarily store it in the `history` state object if Gradio allows modifying state objects directly. | |
| # A cleaner way is to have a single handler function. | |
def combined_user_and_bot(msg_val, chatbot_history, system_msg, max_tokens_val, temperature_val, top_p_val, freq_penalty_val, seed_val, provider_val, api_key_val, custom_model_val, search_term_val, selected_model_val):
    """Handle one submission: show the user's turn, then stream the reply.

    This is a generator that yields the chatbot history after each update.

    Args:
        msg_val: Raw input mapping with ``"text"`` and ``"files"`` entries;
            ``None`` or missing entries are treated as empty.
        chatbot_history: Current list of ``[user_message, bot_message]``
            pairs (``None`` is treated as an empty list). It is modified in
            place.
        system_msg, max_tokens_val, temperature_val, top_p_val,
        freq_penalty_val, seed_val, provider_val, api_key_val,
        custom_model_val, search_term_val, selected_model_val: Forwarded to
            ``respond`` under the matching parameter names.

    Yields:
        The history list: once with the new user turn, then once per chunk
        streamed by ``respond``. For an empty submission the history is
        yielded once, unchanged.
    """
    msg_val = msg_val or {}
    raw_text = msg_val.get("text") or ""
    raw_files = msg_val.get("files") or []
    history = chatbot_history if chatbot_history is not None else []

    # An empty submission adds no turn. Stop here so the previous turn's
    # answer is not wiped and the agent is not run on an empty task.
    if not raw_text.strip() and not raw_files:
        yield history
        return

    # 1. Record the user's turn and show it immediately.
    updated_chatbot_history = user({"text": raw_text, "files": raw_files}, history)
    yield updated_chatbot_history

    # 2. Stream the reply into the bot slot of the turn just added.
    updated_chatbot_history[-1][1] = ""
    history_for_api = updated_chatbot_history[:-1]
    for chunk in respond(
        message=raw_text,
        image_files=raw_files,
        history=history_for_api,
        system_message=system_msg,
        max_tokens=max_tokens_val,
        temperature=temperature_val,
        top_p=top_p_val,
        frequency_penalty=freq_penalty_val,
        seed=seed_val,
        provider=provider_val,
        custom_api_key=api_key_val,
        custom_model=custom_model_val,
        selected_model=selected_model_val,
        model_search_term=search_term_val
    ):
        updated_chatbot_history[-1][1] = chunk
        yield updated_chatbot_history
# Submitting the textbox runs the combined handler, then empties the input.
msg.submit(
    fn=combined_user_and_bot,
    inputs=[
        msg,
        chatbot,
        system_message_box,
        max_tokens_slider,
        temperature_slider,
        top_p_slider,
        frequency_penalty_slider,
        seed_slider,
        provider_radio,
        byok_textbox,
        custom_model_box,
        model_search_box,
        featured_model_radio,
    ],
    outputs=[chatbot],
).then(
    fn=lambda: {"text": "", "files": []},
    inputs=None,
    outputs=[msg],
)

# Typing in the filter box narrows the featured model choices.
model_search_box.change(filter_models, model_search_box, featured_model_radio)
print("Model search box change event linked.")

# Picking a featured model copies its id into the custom model box, where the
# user can still edit it.
featured_model_radio.change(
    lambda chosen_model: chosen_model,
    featured_model_radio,
    custom_model_box,
)
print("Featured model radio button change event linked.")

# Re-check the provider whenever the key changes ...
byok_textbox.change(validate_provider, [byok_textbox, provider_radio], provider_radio)
print("BYOK textbox change event linked.")

# ... and whenever the provider itself changes.
provider_radio.change(validate_provider, [byok_textbox, provider_radio], provider_radio)
print("Provider radio button change event linked.")

print("Gradio interface initialized.")
# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    print("Launching the demo application.")
    launch_options = {"show_api": True, "share": True}
    demo.launch(**launch_options)