"""Gradio front end for the StyleMatch Assistant chat application."""

import asyncio
import logging
import os
import shutil
import threading

import gradio as gr
import numpy as np

from config import session_keys
from mcp_client_wrapper import MCPClientWrapper
from utils import get_or_create_session, reset_session, decode_base64_image, cleanup_old_sessions
from custom_html_render import render_face_data_html
from llm import LLM_Client

logging.basicConfig(level=logging.INFO)
# The emoji prefix is deliberate: it makes this module's log lines easier to
# spot visually.
logger = logging.getLogger(f"😎 {__name__}")

_APP_TITLE = "StyleMatch Assistant – Find your colors. Find your style."
_CUSTOM_CSS = (
    ".custom-image { height: 200px !important; width: 300px !important; "
    "object-fit: contain; } "
    ".custom-image2 { height: 250px !important; object-fit: contain; }"
)
# Number of product-image slots rendered under the chat.
_PRODUCT_IMAGE_SLOTS = 5
# Prefix that show_image() puts in front of string image values.
_REMOTE_FILE_URL_PREFIX = "https://ysharma-sanasprint.hf.space/gradio_api/file="
# Normalized provider name -> (display label, key name inside session_keys[...]).
_PROVIDERS = {
    "openai": ("OpenAI", "OPENAI_API_KEY"),
    "nebius": ("Nebius", "NEBIUS_API_KEY"),
}
_OPENAI_MODEL_CHOICES = ["gpt-4o-mini", "gpt-4.1-mini"]
_DEFAULT_OPENAI_MODEL = "gpt-4o-mini"
_NEBIUS_RESPONSE_MODEL = "mistralai/Mistral-Nemo-Instruct-2407"
_NEBIUS_VLLM_MODEL = "Qwen/Qwen2.5-VL-72B-Instruct"


def _llm_client_error(session_id, source_ai, api_key):
    """Try to build an ``LLM_Client``; return the raised exception or ``None``.

    The callers treat a failing constructor as "this key is not usable".
    NOTE(review): presumably the constructor validates the key -- confirm in llm.py.
    """
    try:
        LLM_Client(session_id, sourceAI=source_ai, api_key=api_key)
    except Exception as exc:
        return exc
    return None


def _input_updates(interactive, status):
    """Build the 4-tuple consumed by ``[msg, send_btn, image_input, key_status]``."""
    return (
        gr.update(interactive=interactive),
        gr.update(interactive=interactive),
        gr.update(interactive=interactive),
        status,
    )


def clear_data(sessionId):
    """Reset the stored settings of a session and delete its ``tmp/<id>`` folder.

    Both steps are best-effort: a failure in one is logged and reported in the
    returned text, and the other step still runs.

    Args:
        sessionId: Key into ``session_keys`` and name of the folder under ``tmp/``.

    Returns:
        A status message describing what was cleared or what failed.
    """
    folder_path = f"tmp/{sessionId}"
    parts = []

    try:
        session_keys[sessionId] = {}
        parts.append(
            "API keys and model selection cleared! "
            "If you chat without adding new, there will be error."
        )
    except Exception:
        logger.exception(f"Failed to clear keys for session {sessionId}")
        parts.append("Could not del API keys and model selection")

    try:
        if os.path.exists(folder_path):
            shutil.rmtree(folder_path)
        parts.append("Image Data and Color Analysis Data are cleared.")
    except Exception:
        logger.exception(f"Failed to remove folder {folder_path}")
        parts.append("Failed to remove Image Data and Color Analysis Data")

    # Join with a space so the two sentences do not run into each other.
    return " ".join(parts)


def show_images(*image_urls):
    """Return one component update per product-image slot.

    Slots that have a corresponding value are delegated to ``show_image``;
    slots without a value, or whose value fails to convert, are hidden and
    cleared.
    """
    updates = []
    for slot in range(_PRODUCT_IMAGE_SLOTS):
        try:
            if slot < len(image_urls):
                updates.append(show_image(image_urls[slot]))
            else:
                updates.append(gr.update(visible=False, value=None))
        except Exception as exc:
            logger.warning(f"Failed to show image at index {slot}: {exc}")
            updates.append(gr.update(visible=False, value=None))
    return updates


def show_image(image_url):
    """Translate one image value into a Gradio update.

    * ``numpy.ndarray``: shown as-is.
    * non-blank ``str``: prefixed with the hosted file URL and shown.
    * anything else (blank string, ``None``, ...): the component is only made
      visible; its value is not touched.
    """
    if isinstance(image_url, np.ndarray):
        return gr.update(value=image_url, visible=True)
    if isinstance(image_url, str) and image_url.strip():
        return gr.update(value=f"{_REMOTE_FILE_URL_PREFIX}{image_url}", visible=True)
    return gr.update(visible=True)


def check_keys_and_toggle_inputs(session_id):
    """Enable or disable the chat inputs depending on the saved API key.

    The saved provider is matched case-insensitively: the save handlers store
    ``provider.capitalize()``, which turns "OpenAI" into "Openai", so an exact
    comparison against "OpenAI" would silently skip the key check.

    Args:
        session_id: Key into ``session_keys``.

    Returns:
        A 4-tuple: three ``gr.update(interactive=...)`` values followed by a
        status string.
    """
    settings = session_keys.get(session_id, {})
    provider = settings.get("provider", "OpenAI")
    normalized = str(provider).strip().lower()
    known = _PROVIDERS.get(normalized)

    if known is None:
        # Unknown provider: the original code performs no key check here.
        return _input_updates(True, f"Using {provider} Provider")

    label, key_name = known
    warning = f"⚠️ Please go to **Settings** and add your {label} API key to start."
    api_key = settings.get(key_name)
    if not api_key:
        return _input_updates(False, warning)

    error = _llm_client_error(session_id, normalized, api_key)
    if error is not None:
        logger.warning(f"Saved {label} key rejected for session {session_id[:5]}: {error}")
        return _input_updates(False, warning)

    return _input_updates(True, f"Using {label} Provider")


def _start_session_cleanup(threshold_seconds=600):
    """Run ``cleanup_old_sessions`` on its own event loop in a daemon thread.

    Scheduling the coroutine with ``get_event_loop().create_task`` from the
    main thread never executes it, because nothing ever runs that loop:
    ``launch()`` serves from its own loop and blocks the main thread.
    """

    def _run():
        try:
            asyncio.run(cleanup_old_sessions(threshold_seconds=threshold_seconds))
        except Exception:
            logger.exception("Session cleanup task stopped unexpectedly")

    worker = threading.Thread(target=_run, name="session-cleanup", daemon=True)
    worker.start()
    return worker


def gradio_interface():
    """Build the Blocks app (Chat and Settings tabs) and return it unlaunched."""
    with gr.Blocks(title=_APP_TITLE, css=_CUSTOM_CSS) as demo:
        session_id_state = gr.State(str(get_or_create_session()))
        client_state = gr.State()
        # True only when the most recent "Save Keys" click passed validation.
        keys_saved_state = gr.State(False)

        gr.Markdown(f"# {_APP_TITLE}")

        def message_handler(session_id, message, history, image_input, client):
            """Forward a chat message to the session client, if there is one.

            Returns values for: chatbot, msg, the five product images and the
            face-data HTML panel.
            """
            if client is None:
                logger.warning(f"⚠️ No client found in state for session {session_id[:5]}")
                reply = {
                    "role": "assistant",
                    "content": "⚠️ Please save your API key first in Settings tab.",
                }
                return (
                    history + [reply],
                    gr.Textbox(value=""),
                    *[None] * _PRODUCT_IMAGE_SLOTS,
                    render_face_data_html({}),
                )
            return client.process_message(session_id, message, history, image_input)

        with gr.Tab("Chat"):
            key_status = gr.Markdown("⚠️ Please add OpenAI API key to continue.")

            with gr.Row(equal_height=True):
                with gr.Column(scale=1):
                    image_input = gr.Image(
                        label="Face Image",
                        visible=True,
                        interactive=False,
                        elem_classes="custom-image2",
                        type="filepath",
                        scale=2,
                    )
                    with gr.Row(scale=3):
                        face_data_display = gr.HTML(
                            label="Face Analysis Result",
                            value=render_face_data_html({}),
                        )

                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(
                        value=[],
                        height=500,
                        type="messages",
                        show_copy_button=True,
                        avatar_images=("asset/avatar.png", "asset/bot.png"),
                        scale=4,
                    )
                    with gr.Row(equal_height=True):
                        msg = gr.Textbox(
                            label="What would you like to know?",
                            placeholder="What color do I match with?",
                            scale=2,
                            interactive=False,
                        )
                        send_btn = gr.Button(
                            "Send", variant="primary", size="sm", interactive=False
                        )

            image_outputs = []
            with gr.Row(equal_height=True):
                for i in range(_PRODUCT_IMAGE_SLOTS):
                    img = gr.Image(
                        label=f"Product Image {i+1}",
                        visible=True,
                        interactive=False,
                        elem_classes="custom-image",
                    )
                    image_outputs.append(img)

            def bind_message_submission(trigger):
                """Wire ``trigger`` to the chat handler, then refresh the images."""
                return trigger(
                    fn=message_handler,
                    inputs=[session_id_state, msg, chatbot, image_input, client_state],
                    outputs=[chatbot, msg] + image_outputs + [face_data_display],
                ).then(
                    fn=show_images,
                    inputs=image_outputs,
                    outputs=image_outputs,
                )

            bind_message_submission(msg.submit)
            bind_message_submission(send_btn.click)

        with gr.Tab("Settings"):
            clear_data_btn = gr.Button(
                "Clear Data", variant="primary", size="sm", interactive=True
            )
            data_clear_status = gr.Markdown("")

            gr.Markdown("## 🔧 Model Settings")
            provider_selector = gr.Dropdown(
                label="Provider",
                choices=["OpenAI", "Nebius"],
                value="OpenAI",
            )
            tool_call_selector = gr.Dropdown(
                label="Tool Call Model (OpenAI only)",
                choices=["gpt-4o-mini", "gpt-4.1-mini"],
                visible=True,
                value="gpt-4o-mini",
            )
            response_model_selector = gr.Dropdown(
                label="Response Model",
                choices=["gpt-4o-mini", "gpt-4.1-mini"],
                value="gpt-4o-mini",
            )
            vllm_model_selector = gr.Dropdown(
                label="VLLM Model",
                choices=["gpt-4o-mini", "gpt-4.1-mini"],
                value="gpt-4o-mini",
            )
            model_feedback = gr.Markdown("")
            set_model_btn = gr.Button("Set Models")

            gr.Markdown("## 🔐 API Key Settings")
            with gr.Column():
                openai_key_input = gr.Textbox(
                    label="OpenAI API Key (Required)",
                    placeholder="Enter your OpenAI key",
                    type="password",
                )
                nebius_key_input = gr.Textbox(
                    label="Nebius API Key (Optional)",
                    placeholder="Enter your Nebius API Key",
                    type="password",
                )
            set_keys_btn = gr.Button("🔐 Save Keys")
            key_feedback = gr.Markdown("")

            clear_data_btn.click(
                fn=clear_data, inputs=session_id_state, outputs=data_clear_status
            )

            def set_user_keys(session_id, provider, openai_key, nebius_key,
                              tool_model, response_model, vllm_model):
                """Validate the entered keys and, if they pass, store the settings.

                Returns:
                    ``(feedback_markdown, saved)`` where ``saved`` is True only
                    when the settings were written to ``session_keys``.
                """
                openai_key = (openai_key or "").strip()
                nebius_key = (nebius_key or "").strip()
                provider = (provider or "").strip()

                # An OpenAI key is mandatory regardless of the selected provider.
                if not openai_key:
                    return "⚠️ Please provide a valid OpenAI API key.", False
                error = _llm_client_error(session_id, "openai", openai_key)
                if error is not None:
                    logger.warning(f"❌ Invalid OpenAI key: {error}")
                    return "⚠️ OpenAI key is invalid.", False

                # The Nebius key is only checked when Nebius is selected.
                if provider == "Nebius":
                    if not nebius_key:
                        return "⚠️ Please provide a valid Nebius API key.", False
                    error = _llm_client_error(session_id, "nebius", nebius_key)
                    if error is not None:
                        logger.warning(f"❌ Invalid Nebius key: {error}")
                        return "⚠️ Nebius key is invalid.", False

                # Save only if validation passed. The provider is stored via
                # capitalize() exactly as before; other modules may read it.
                session_keys[session_id] = {
                    "OPENAI_API_KEY": openai_key,
                    "NEBIUS_API_KEY": nebius_key,
                    "provider": provider.capitalize(),
                    "tool_call_model": tool_model,
                    "response_model": response_model,
                    "VLLM_model": vllm_model,
                }
                logger.info(f"✅ API keys set for session {session_id[:5]}")
                return f"✅ Keys saved for session `{session_id[:5]}`.", True

            def init_client_after_key_save(session_id, keys_saved, current_client):
                """Create and connect a client once keys were saved successfully.

                When the preceding save was rejected, the existing client and
                the error already shown in ``key_feedback`` are left untouched
                instead of being replaced by a "Client ready" message.
                """
                if not keys_saved:
                    return current_client, gr.update()
                client = MCPClientWrapper(session_id=session_id)
                connect_msg = client.connect()
                logger.info(connect_msg)
                return client, f"✅ Client ready for session {session_id[:5]}"

            set_keys_btn.click(
                fn=set_user_keys,
                inputs=[
                    session_id_state,
                    provider_selector,
                    openai_key_input,
                    nebius_key_input,
                    tool_call_selector,
                    response_model_selector,
                    vllm_model_selector,
                ],
                outputs=[key_feedback, keys_saved_state],
            ).then(
                fn=check_keys_and_toggle_inputs,
                inputs=[session_id_state],
                outputs=[msg, send_btn, image_input, key_status],
            ).then(
                fn=init_client_after_key_save,
                inputs=[session_id_state, keys_saved_state, client_state],
                outputs=[client_state, key_feedback],
            )

            def update_model_options(provider):
                """Return dropdown updates (tool, response, VLLM) for ``provider``."""
                openai_models = gr.update(
                    choices=list(_OPENAI_MODEL_CHOICES), value=_DEFAULT_OPENAI_MODEL
                )
                if provider == "OpenAI":
                    return (
                        openai_models,
                        gr.update(
                            choices=list(_OPENAI_MODEL_CHOICES),
                            value=_DEFAULT_OPENAI_MODEL,
                        ),
                        gr.update(
                            choices=list(_OPENAI_MODEL_CHOICES),
                            value=_DEFAULT_OPENAI_MODEL,
                        ),
                    )
                # Any other provider: the tool-call dropdown keeps the OpenAI models.
                return (
                    openai_models,
                    gr.update(
                        choices=[_NEBIUS_RESPONSE_MODEL], value=_NEBIUS_RESPONSE_MODEL
                    ),
                    gr.update(choices=[_NEBIUS_VLLM_MODEL], value=_NEBIUS_VLLM_MODEL),
                )

            def on_provider_change(provider, session_id):
                """Refresh the model dropdowns and re-evaluate the input lock."""
                return (
                    *update_model_options(provider),
                    *check_keys_and_toggle_inputs(session_id),
                )

            provider_selector.change(
                fn=on_provider_change,
                inputs=[provider_selector, session_id_state],
                outputs=[
                    tool_call_selector,
                    response_model_selector,
                    vllm_model_selector,
                    msg,
                    send_btn,
                    image_input,
                    key_status,
                ],
            )

            def set_user_models(session_id, provider, tool_model, response_model, vllm_model):
                """Re-validate the saved keys and store the model selection.

                Returns:
                    A feedback string for ``model_feedback``.
                """
                saved = session_keys.get(session_id, {})
                openai_key = saved.get("OPENAI_API_KEY")
                nebius_key = saved.get("NEBIUS_API_KEY")
                provider = (provider or "").strip().lower()

                if not openai_key:
                    return "⚠️ Please enter a valid OpenAI API key before setting the model."
                error = _llm_client_error(session_id, "openai", openai_key)
                if error is not None:
                    logger.warning(f"❌ Invalid OpenAI key when setting model: {error}")
                    return "⚠️ OpenAI API key is invalid."

                if provider == "nebius":
                    if not nebius_key:
                        return "⚠️ Please enter a valid Nebius API key before setting the model."
                    error = _llm_client_error(session_id, "nebius", nebius_key)
                    if error is not None:
                        logger.warning(f"❌ Invalid Nebius key when setting model: {error}")
                        return "⚠️ Nebius API key is invalid."

                # setdefault replaces a branch that logged an undefined
                # exception variable when the session entry was missing.
                settings = session_keys.setdefault(session_id, {})
                settings["provider"] = provider.capitalize()
                settings["tool_call_model"] = tool_model
                settings["response_model"] = response_model
                settings["VLLM_model"] = vllm_model

                logger.info(
                    f"✅ Models set for {session_id[:5]} | Provider: {provider}, Tool: {tool_model}, "
                    f"Response: {response_model}, VLLM_model: {vllm_model}"
                )
                return f"✅ Models saved for {session_id[:5]}"

            set_model_btn.click(
                fn=set_user_models,
                inputs=[
                    session_id_state,
                    provider_selector,
                    tool_call_selector,
                    response_model_selector,
                    vllm_model_selector,
                ],
                outputs=[model_feedback],
            ).then(
                fn=check_keys_and_toggle_inputs,
                inputs=[session_id_state],
                outputs=[msg, send_btn, image_input, key_status],
            )

    return demo


if __name__ == "__main__":
    # folder_path = f"tmp"
    # if os.path.exists(folder_path):
    #     shutil.rmtree(folder_path)
    #     logger.info(f"Cleaned folder: {folder_path}")
    _start_session_cleanup(threshold_seconds=600)  # 10 min
    interface = gradio_interface()
    interface.launch(debug=False)