import asyncio
import gradio as gr
import os
from config import session_keys
from mcp_client_wrapper import MCPClientWrapper
import logging
from utils import get_or_create_session, reset_session, decode_base64_image, cleanup_old_sessions
from custom_html_render import render_face_data_html
import asyncio
from llm import LLM_Client
import numpy as np
import shutil
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(f"😎 {__name__}") # Do not wonder why I uses emoj in logger - It is visually easier to track
def clear_data(sessionId):
folder_path = f"tmp/{sessionId}"
message = ""
try:
session_keys[sessionId] = {}
message += "API keys and model selection cleared! If you chat without adding new, there will be error."
except:
message += "Could not del API keys and model selection"
try:
if os.path.exists(folder_path):
shutil.rmtree(folder_path)
message += "Image Data and Color Analysis Data are cleared."
except:
message += "Failed to remove Image Data and Color Analysis Data"
return message
def show_images(*image_urls):
updates = []
for i in range(5):
try:
if i < len(image_urls):
updates.append(show_image(image_urls[i]))
else:
updates.append(gr.update(visible=False, value=None))
except Exception as e:
logger.warning(f"Failed to show image at index {i}: {e}")
updates.append(gr.update(visible=False, value=None))
return updates
def show_image(image_url):
if isinstance(image_url, np.ndarray):
return gr.update(value=image_url, visible=True)
if isinstance(image_url, str):
if not image_url.strip():
return gr.update(visible=True)
try:
img = f"https://ysharma-sanasprint.hf.space/gradio_api/file={image_url}"
return gr.update(value=img, visible=True)
except Exception as e:
logger.warning(f"Failed to decode image: {e}")
return gr.update(visible=True)
return gr.update(visible=True)
def check_keys_and_toggle_inputs(session_id):
settings = session_keys.get(session_id, {})
openai_key = settings.get("OPENAI_API_KEY")
nebius_key = settings.get("NEBIUS_API_KEY")
provider = settings.get("provider", "OpenAI")
if provider == "OpenAI":
if not openai_key:
return (
gr.update(interactive=False),
gr.update(interactive=False),
gr.update(interactive=False),
"⚠️ Please go to **Settings** and add your OpenAI API key to start."
)
try:
_ = LLM_Client(session_id, sourceAI="openai", api_key=openai_key)
except Exception as e:
return (
gr.update(interactive=False),
gr.update(interactive=False),
gr.update(interactive=False),
"⚠️ Please go to **Settings** and add your OpenAI API key to start."
)
if provider == "Nebius":
if not nebius_key:
return (
gr.update(interactive=False),
gr.update(interactive=False),
gr.update(interactive=False),
"⚠️ Please go to **Settings** and add your Nebius API key to start."
)
try:
_ = LLM_Client(session_id, sourceAI="nebius", api_key=nebius_key)
except Exception as e:
return (
gr.update(interactive=False),
gr.update(interactive=False),
gr.update(interactive=False),
"⚠️ Please go to **Settings** and add your Nebius API key to start."
)
return (
gr.update(interactive=True),
gr.update(interactive=True),
gr.update(interactive=True),
f"Using {provider} Provider"
)
def gradio_interface():
with gr.Blocks(title="StyleMatch Assistant – Find your colors. Find your style.",
css=".custom-image { height: 200px !important; width: 300px !important; object-fit: contain; } .custom-image2 { height: 250px !important; object-fit: contain; }") as demo:
session_id_state = gr.State(str(get_or_create_session()))
client_state = gr.State()
gr.Markdown("# StyleMatch Assistant – Find your colors. Find your style.")
def message_handler(session_id, message, history, image_input, client):
if client is None:
logger.warning(f"⚠️ No client found in state for session {session_id[:5]}")
return history + [{"role": "assistant", "content": "⚠️ Please save your API key first in Settings tab."}], gr.Textbox(value=""), *[None]*5, render_face_data_html({})
return client.process_message(session_id, message, history, image_input)
with gr.Tab("Chat"):
key_status = gr.Markdown(
"⚠️ Please add OpenAI API key to continue."
)
with gr.Row(equal_height=True):
with gr.Column(scale=1):
image_input = gr.Image(
label="Face Image",
visible=True,
interactive=False,
elem_classes="custom-image2",
type="filepath",
scale=2
)
with gr.Row(scale=3):
face_data_display = gr.HTML(label="Face Analysis Result",
value=render_face_data_html({}))
with gr.Column(scale=3):
chatbot = gr.Chatbot(
value=[],
height=500,
type="messages",
show_copy_button=True,
avatar_images=("asset/avatar.png", "asset/bot.png"),
scale=4
)
with gr.Row(equal_height=True):
msg = gr.Textbox(
label="What would you like to know?",
placeholder="What color do I match with?",
scale=2,
interactive=False,
)
send_btn = gr.Button("Send", variant="primary", size="sm", interactive=False)
image_outputs = []
with gr.Row(equal_height=True):
for i in range(5):
img = gr.Image(
label=f"Product Image {i+1}",
visible=True,
interactive=False,
elem_classes="custom-image"
)
image_outputs.append(img)
def bind_message_submission(trigger):
return trigger(
fn=message_handler,
inputs=[session_id_state, msg, chatbot, image_input, client_state],
outputs=[chatbot, msg] + image_outputs + [face_data_display]
).then(
fn=show_images,
inputs=image_outputs,
outputs=image_outputs
)
bind_message_submission(msg.submit)
bind_message_submission(send_btn.click)
with gr.Tab("Settings"):
clear_data_btn = gr.Button("Clear Data", variant="primary", size="sm", interactive=True)
data_clear_status = gr.Markdown("")
gr.Markdown("## 🔧 Model Settings")
provider_selector = gr.Dropdown(
label="Provider",
choices=["OpenAI", "Nebius"],
value="OpenAI"
)
tool_call_selector = gr.Dropdown(
label="Tool Call Model (OpenAI only)",
choices=["gpt-4o-mini", "gpt-4.1-mini"],
visible=True,
value="gpt-4o-mini"
)
response_model_selector = gr.Dropdown(
label="Response Model",
choices=["gpt-4o-mini", "gpt-4.1-mini"],
value="gpt-4o-mini"
)
vllm_model_selector = gr.Dropdown(
label="VLLM Model",
choices=["gpt-4o-mini", "gpt-4.1-mini"],
value="gpt-4o-mini"
)
model_feedback = gr.Markdown("")
set_model_btn = gr.Button("Set Models")
gr.Markdown("## 🔐 API Key Settings")
with gr.Column():
openai_key_input = gr.Textbox(
label="OpenAI API Key (Required)",
placeholder="Enter your OpenAI key",
type="password"
)
nebius_key_input = gr.Textbox(
label="Nebius API Key (Optional)",
placeholder="Enter your Nebius API Key",
type="password"
)
set_keys_btn = gr.Button("🔐 Save Keys")
key_feedback = gr.Markdown("")
clear_data_btn.click(fn=clear_data, inputs=session_id_state, outputs=data_clear_status)
def set_user_keys(session_id, provider, openai_key, nebius_key, tool_model, response_model, vllm_model):
openai_key = openai_key.strip()
nebius_key = nebius_key.strip()
provider = provider.strip()
if not openai_key: # OpenAI is a must
return "⚠️ Please provide a valid OpenAI API key."
try:
_ = LLM_Client(session_id, sourceAI="openai", api_key=openai_key)
except Exception as e:
logger.warning(f"❌ Invalid OpenAI key: {e}")
return "⚠️ OpenAI key is invalid."
if provider == "Nebius": #Only if Nebius is selected
if not nebius_key:
return "⚠️ Please provide a valid Nebius API key."
try:
_ = LLM_Client(session_id, sourceAI="nebius", api_key=nebius_key)
except Exception as e:
logger.warning(f"❌ Invalid Nebius key: {e}")
return "⚠️ Nebius key is invalid."
# Save only if validation passed
session_keys[session_id] = {
"OPENAI_API_KEY": openai_key,
"NEBIUS_API_KEY": nebius_key,
"provider": provider
}
session_keys[session_id]["provider"] = provider.capitalize()
session_keys[session_id]["tool_call_model"] = tool_model
session_keys[session_id]["response_model"] = response_model
session_keys[session_id]["VLLM_model"] = vllm_model
logger.info(f"✅ API keys set for session {session_id[:5]}")
return f"✅ Keys saved for session `{session_id[:5]}`."
def init_client_after_key_save(session_id):
client = MCPClientWrapper(session_id=session_id)
connect_msg = client.connect()
logger.info(connect_msg)
return client, f"✅ Client ready for session {session_id[:5]}"
set_keys_btn.click(
fn=set_user_keys,
inputs=[session_id_state, provider_selector, openai_key_input, nebius_key_input, tool_call_selector, response_model_selector, vllm_model_selector],
outputs=[key_feedback]
).then(
fn=check_keys_and_toggle_inputs,
inputs=[session_id_state],
outputs=[msg, send_btn, image_input, key_status]
).then(
fn=init_client_after_key_save,
inputs=[session_id_state],
outputs=[client_state, key_feedback]
)
def update_model_options(provider):
if provider == "OpenAI":
return (
gr.update(
choices=["gpt-4o-mini", "gpt-4.1-mini"],
value="gpt-4o-mini"
),
gr.update(
choices=["gpt-4o-mini", "gpt-4.1-mini"],
value="gpt-4o-mini"
),
gr.update(
choices=["gpt-4o-mini", "gpt-4.1-mini"],
value="gpt-4o-mini"
)
)
else:
return (
gr.update(
choices=["gpt-4o-mini", "gpt-4.1-mini"],
value="gpt-4o-mini"
),
gr.update(
choices=["mistralai/Mistral-Nemo-Instruct-2407"],
value="mistralai/Mistral-Nemo-Instruct-2407"
),
gr.update(
choices=["Qwen/Qwen2.5-VL-72B-Instruct"],
value="Qwen/Qwen2.5-VL-72B-Instruct"
)
)
provider_selector.change(
fn=lambda provider, session_id: (*update_model_options(provider), *check_keys_and_toggle_inputs(session_id)),
inputs=[provider_selector, session_id_state],
outputs=[
tool_call_selector,
response_model_selector,
vllm_model_selector,
msg,
send_btn,
image_input,
key_status
]
)
def set_user_models(session_id, provider, tool_model, response_model, vllm_model):
openai_key = session_keys.get(session_id, {}).get("OPENAI_API_KEY")
nebius_key = session_keys.get(session_id, {}).get("NEBIUS_API_KEY")
provider = provider.strip().lower()
if not openai_key:
return "⚠️ Please enter a valid OpenAI API key before setting the model."
try:
_ = LLM_Client(session_id, sourceAI="openai", api_key=openai_key)
except Exception as e:
logger.warning(f"❌ Invalid OpenAI key when setting model: {e}")
return "⚠️ OpenAI API key is invalid."
if provider == "nebius":
if not nebius_key:
return "⚠️ Please enter a valid Nebius API key before setting the model."
try:
_ = LLM_Client(session_id, sourceAI="nebius", api_key=nebius_key)
except Exception as e:
logger.warning(f"❌ Invalid Nebius key when setting model: {e}")
return "⚠️ Nebius API key is invalid."
if session_id not in session_keys:
logger.warning(f"❌ Invalid Nebius key when setting model: {e}")
session_keys[session_id] = {}
session_keys[session_id]["provider"] = provider.capitalize()
session_keys[session_id]["tool_call_model"] = tool_model
session_keys[session_id]["response_model"] = response_model
session_keys[session_id]["VLLM_model"] = vllm_model
logger.info(
f"✅ Models set for {session_id[:5]} | Provider: {provider}, Tool: {tool_model}, "
f"Response: {response_model}, VLLM_model: {vllm_model}"
)
return f"✅ Models saved for {session_id[:5]}"
set_model_btn.click(
fn=set_user_models,
inputs=[session_id_state, provider_selector, tool_call_selector, response_model_selector, vllm_model_selector],
outputs=[model_feedback]
).then(
fn=check_keys_and_toggle_inputs,
inputs=[session_id_state],
outputs=[msg, send_btn, image_input, key_status]
)
return demo
if __name__ == "__main__":
# folder_path = f"tmp"
# if os.path.exists(folder_path):
# shutil.rmtree(folder_path)
# logger.info(f"Cleaned folder: {folder_path}")
asyncio.get_event_loop().create_task(cleanup_old_sessions(threshold_seconds=600)) # 10 min
interface = gradio_interface()
interface.launch(debug=False)