|
|
import asyncio |
|
|
import gradio as gr |
|
|
import os |
|
|
from config import session_keys |
|
|
from mcp_client_wrapper import MCPClientWrapper |
|
|
import logging |
|
|
from utils import get_or_create_session, reset_session, decode_base64_image, cleanup_old_sessions |
|
|
from custom_html_render import render_face_data_html |
|
|
import asyncio |
|
|
from llm import LLM_Client |
|
|
import numpy as np |
|
|
import shutil |
|
|
|
|
|
# Configure the root logger at INFO so the handlers' info/warning records are emitted.
logging.basicConfig(level=logging.INFO)

# Module logger; the emoji prefix is simply part of the logger's name.
logger = logging.getLogger("😎 " + __name__)
|
|
|
|
|
def clear_data(sessionId):
    """Forget a session's API keys/model choices and delete its temp folder.

    Args:
        sessionId: Session identifier. Used as the key into ``session_keys``
            and as the name of the ``tmp/<sessionId>`` working folder.

    Returns:
        str: A Markdown status message describing what was (or was not)
        cleared. The individual messages are space-separated.
    """
    messages = []

    try:
        session_keys[sessionId] = {}
    except Exception as exc:  # e.g. an unhashable id; report instead of crashing the UI
        logger.warning(f"Could not reset session keys for {sessionId!r}: {exc}")
        messages.append("Could not del API keys and model selection")
    else:
        messages.append(
            "API keys and model selection cleared! If you chat without adding new, there will be error."
        )

    tmp_root = os.path.abspath("tmp")
    folder_path = os.path.abspath(os.path.join(tmp_root, str(sessionId)))

    # Only ever delete a direct child of ./tmp. An empty, "..", or path-like
    # session id must not make rmtree remove tmp/ itself or anything outside it.
    if os.path.dirname(folder_path) != tmp_root:
        logger.warning(f"Refusing to delete unexpected path {folder_path!r}")
        messages.append("Failed to remove Image Data and Color Analysis Data")
    else:
        try:
            if os.path.exists(folder_path):
                shutil.rmtree(folder_path)
        except OSError as exc:
            logger.warning(f"Failed to remove {folder_path!r}: {exc}")
            messages.append("Failed to remove Image Data and Color Analysis Data")
        else:
            messages.append("Image Data and Color Analysis Data are cleared.")

    return " ".join(messages)
|
|
|
|
|
def show_images(*image_urls):
    """Produce one Gradio update per product-image slot (always five).

    Slots that have a corresponding positional argument are rendered through
    ``show_image``. Slots beyond the supplied arguments are hidden and cleared.
    If rendering a slot raises, the error is logged and that slot is hidden.

    Returns:
        list: Exactly five update objects, in slot order.
    """
    def render_slot(position):
        # Per-slot isolation: one bad image must not break the other slots.
        try:
            if position < len(image_urls):
                return show_image(image_urls[position])
            return gr.update(visible=False, value=None)
        except Exception as exc:
            logger.warning(f"Failed to show image at index {position}: {exc}")
            return gr.update(visible=False, value=None)

    return [render_slot(position) for position in range(5)]
|
|
|
|
|
|
|
|
def show_image(image_url, base_url="https://ysharma-sanasprint.hf.space/gradio_api/file="):
    """Return a Gradio update that makes an image slot visible.

    Args:
        image_url: Either a NumPy array (shown directly), a non-blank string
            path (appended to ``base_url`` to form a remote file URL), or
            anything else, in which case the slot's value is left unchanged.
        base_url: Prefix for string paths. Defaults to the remote Space's
            ``gradio_api/file=`` endpoint the app has always used.

    Returns:
        A ``gr.update`` with ``visible=True`` and, when applicable, a new value.
    """
    if isinstance(image_url, np.ndarray):
        return gr.update(value=image_url, visible=True)

    # Building the URL is a plain string concatenation and cannot fail, so no
    # try/except is needed here.
    if isinstance(image_url, str) and image_url.strip():
        return gr.update(value=f"{base_url}{image_url}", visible=True)

    return gr.update(visible=True)
|
|
|
|
|
# Normalised provider name -> (settings key holding its API key,
# ``sourceAI`` value passed to LLM_Client, display label used in messages).
_PROVIDER_KEY_FIELDS = {
    "openai": ("OPENAI_API_KEY", "openai", "OpenAI"),
    "nebius": ("NEBIUS_API_KEY", "nebius", "Nebius"),
}


def _locked_inputs(message):
    """Return updates disabling the message box, send button and image input, plus *message*."""
    return (
        gr.update(interactive=False),
        gr.update(interactive=False),
        gr.update(interactive=False),
        message,
    )


def check_keys_and_toggle_inputs(session_id):
    """Enable the chat inputs only when the session's provider key is usable.

    Reads ``session_keys[session_id]``; the provider defaults to "OpenAI".
    For a known provider, the matching API key must be present and accepted
    by ``LLM_Client``. Otherwise the inputs are disabled and a hint is shown.

    The provider name is matched case-insensitively. The settings handlers
    store it via ``str.capitalize()``, so "OpenAI" is saved as "Openai".

    Args:
        session_id: Key into ``session_keys``.

    Returns:
        tuple: ``(msg_update, send_btn_update, image_input_update, status_markdown)``.
    """
    settings = session_keys.get(session_id, {})
    provider = settings.get("provider", "OpenAI")
    spec = _PROVIDER_KEY_FIELDS.get(str(provider).strip().lower())

    if spec is not None:
        key_field, source_ai, label = spec
        hint = f"⚠️ Please go to **Settings** and add your {label} API key to start."

        api_key = settings.get(key_field)
        if not api_key:
            return _locked_inputs(hint)

        try:
            # Constructing the client is used as the key-validity check.
            LLM_Client(session_id, sourceAI=source_ai, api_key=api_key)
        except Exception as exc:
            logger.warning(f"❌ {label} key rejected for session {str(session_id)[:5]}: {exc}")
            return _locked_inputs(hint)

        # Show the canonical spelling regardless of how it was stored.
        provider = label

    return (
        gr.update(interactive=True),
        gr.update(interactive=True),
        gr.update(interactive=True),
        f"Using {provider} Provider",
    )
|
|
|
|
|
def gradio_interface():
    """Build the StyleMatch Assistant UI and wire up all of its event handlers.

    The app has two tabs:

    * **Chat** - face image, face-analysis panel, chatbot, message box and five
      product-image slots. The message box, send button and image input start
      disabled and are toggled by ``check_keys_and_toggle_inputs``.
    * **Settings** - provider/model selection, API key entry and a
      "Clear Data" button.

    Returns:
        gr.Blocks: The assembled (not yet launched) Gradio app.
    """
    openai_models = ("gpt-4o-mini", "gpt-4.1-mini")
    default_openai_model = "gpt-4o-mini"
    nebius_response_model = "mistralai/Mistral-Nemo-Instruct-2407"
    nebius_vllm_model = "Qwen/Qwen2.5-VL-72B-Instruct"

    with gr.Blocks(
        title="StyleMatch Assistant – Find your colors. Find your style.",
        css=".custom-image { height: 200px !important; width: 300px !important; object-fit: contain; } .custom-image2 { height: 250px !important; object-fit: contain; }",
    ) as demo:
        # A callable initial value is evaluated on every page load, so each
        # visitor gets their own session id. A plain value would be computed
        # once when the UI is built and shared, together with the API keys
        # stored under it in ``session_keys``, by every user of the app.
        session_id_state = gr.State(lambda: str(get_or_create_session()))
        client_state = gr.State()
        # Set by ``set_user_keys``: True only if the latest "Save Keys" passed validation.
        keys_valid_state = gr.State(False)

        gr.Markdown("# StyleMatch Assistant – Find your colors. Find your style.")

        def message_handler(session_id, message, history, image_input, client):
            """Forward one chat turn to the session's MCP client.

            Without a client, an assistant warning is appended to *history*
            and the other outputs are cleared. Otherwise the result of
            ``client.process_message`` is returned as-is. It presumably matches
            the outputs wired in ``bind_message_submission``: chatbot, textbox,
            five images and face HTML.
            """
            if client is None:
                logger.warning(f"⚠️ No client found in state for session {session_id[:5]}")
                return (
                    history + [{"role": "assistant", "content": "⚠️ Please save your API key first in Settings tab."}],
                    gr.Textbox(value=""),
                    *([None] * 5),
                    render_face_data_html({}),
                )
            return client.process_message(session_id, message, history, image_input)

        with gr.Tab("Chat"):
            key_status = gr.Markdown(
                "⚠️ <span style='color:orange; font-size: 18px; font-weight: bold;'>Please add OpenAI API key to continue.</span>"
            )

            with gr.Row(equal_height=True):
                with gr.Column(scale=1):
                    image_input = gr.Image(
                        label="Face Image",
                        visible=True,
                        interactive=False,
                        elem_classes="custom-image2",
                        type="filepath",
                        scale=2,
                    )
                    with gr.Row(scale=3):
                        face_data_display = gr.HTML(
                            label="Face Analysis Result",
                            value=render_face_data_html({}),
                        )

                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(
                        value=[],
                        height=500,
                        type="messages",
                        show_copy_button=True,
                        avatar_images=("asset/avatar.png", "asset/bot.png"),
                        scale=4,
                    )
                    with gr.Row(equal_height=True):
                        msg = gr.Textbox(
                            label="What would you like to know?",
                            placeholder="What color do I match with?",
                            scale=2,
                            interactive=False,
                        )
                        send_btn = gr.Button("Send", variant="primary", size="sm", interactive=False)

            image_outputs = []
            with gr.Row(equal_height=True):
                for i in range(5):
                    img = gr.Image(
                        label=f"Product Image {i+1}",
                        visible=True,
                        interactive=False,
                        elem_classes="custom-image",
                    )
                    image_outputs.append(img)

            def bind_message_submission(trigger):
                """Attach the chat pipeline to *trigger* (``msg.submit`` or ``send_btn.click``)."""
                return trigger(
                    fn=message_handler,
                    inputs=[session_id_state, msg, chatbot, image_input, client_state],
                    outputs=[chatbot, msg] + image_outputs + [face_data_display],
                ).then(
                    # Second step: turn whatever landed in the image slots into
                    # visible/hidden updates.
                    fn=show_images,
                    inputs=image_outputs,
                    outputs=image_outputs,
                )

            bind_message_submission(msg.submit)
            bind_message_submission(send_btn.click)

        with gr.Tab("Settings"):
            clear_data_btn = gr.Button("Clear Data", variant="primary", size="sm", interactive=True)
            data_clear_status = gr.Markdown("")
            gr.Markdown("## 🔧 Model Settings")

            provider_selector = gr.Dropdown(
                label="Provider",
                choices=["OpenAI", "Nebius"],
                value="OpenAI",
            )
            tool_call_selector = gr.Dropdown(
                label="Tool Call Model (OpenAI only)",
                choices=list(openai_models),
                visible=True,
                value=default_openai_model,
            )
            response_model_selector = gr.Dropdown(
                label="Response Model",
                choices=list(openai_models),
                value=default_openai_model,
            )
            vllm_model_selector = gr.Dropdown(
                label="VLLM Model",
                choices=list(openai_models),
                value=default_openai_model,
            )

            model_feedback = gr.Markdown("")
            set_model_btn = gr.Button("Set Models")

            gr.Markdown("## 🔐 API Key Settings")

            with gr.Column():
                openai_key_input = gr.Textbox(
                    label="OpenAI API Key (Required)",
                    placeholder="Enter your OpenAI key",
                    type="password",
                )
                nebius_key_input = gr.Textbox(
                    label="Nebius API Key (Optional)",
                    placeholder="Enter your Nebius API Key",
                    type="password",
                )
                set_keys_btn = gr.Button("🔐 Save Keys")
                key_feedback = gr.Markdown("")

        clear_data_btn.click(fn=clear_data, inputs=session_id_state, outputs=data_clear_status)

        def set_user_keys(session_id, provider, openai_key, nebius_key, tool_model, response_model, vllm_model):
            """Validate the submitted keys and, if they pass, store them with the model choices.

            The OpenAI key is always required. The Nebius key is required only
            when the provider is "Nebius". Nothing is stored on failure.

            Returns:
                tuple: ``(feedback_markdown, keys_valid)``. ``keys_valid`` gates
                the client creation step of the Save-Keys chain.
            """
            openai_key = (openai_key or "").strip()
            nebius_key = (nebius_key or "").strip()
            provider = (provider or "").strip()

            if not openai_key:
                return "⚠️ <span style='color:orange'>Please provide a valid OpenAI API key.</span>", False
            try:
                LLM_Client(session_id, sourceAI="openai", api_key=openai_key)
            except Exception as e:
                logger.warning(f"❌ Invalid OpenAI key: {e}")
                return "⚠️ <span style='color:orange'>OpenAI key is invalid.</span>", False

            if provider == "Nebius":
                if not nebius_key:
                    return "⚠️ <span style='color:orange'>Please provide a valid Nebius API key.</span>", False
                try:
                    LLM_Client(session_id, sourceAI="nebius", api_key=nebius_key)
                except Exception as e:
                    logger.warning(f"❌ Invalid Nebius key: {e}")
                    return "⚠️ <span style='color:orange'>Nebius key is invalid.</span>", False

            session_keys[session_id] = {
                "OPENAI_API_KEY": openai_key,
                "NEBIUS_API_KEY": nebius_key,
                # Stored via capitalize() ("Openai"/"Nebius"), the same form
                # set_user_models writes.
                "provider": provider.capitalize(),
                "tool_call_model": tool_model,
                "response_model": response_model,
                "VLLM_model": vllm_model,
            }
            logger.info(f"✅ API keys set for session {session_id[:5]}")
            return f"✅ Keys saved for session `{session_id[:5]}`.", True

        def init_client_after_key_save(session_id, keys_valid, client, feedback):
            """Create and connect a fresh MCP client, but only after a successful key save.

            When validation failed, the current client and the warning already
            shown in *feedback* are returned unchanged. Otherwise the warning
            would be overwritten by "Client ready".
            """
            if not keys_valid:
                return client, feedback
            new_client = MCPClientWrapper(session_id=session_id)
            connect_msg = new_client.connect()
            logger.info(connect_msg)
            return new_client, f"✅ Client ready for session {session_id[:5]}"

        set_keys_btn.click(
            fn=set_user_keys,
            inputs=[session_id_state, provider_selector, openai_key_input, nebius_key_input,
                    tool_call_selector, response_model_selector, vllm_model_selector],
            outputs=[key_feedback, keys_valid_state],
        ).then(
            fn=check_keys_and_toggle_inputs,
            inputs=[session_id_state],
            outputs=[msg, send_btn, image_input, key_status],
        ).then(
            fn=init_client_after_key_save,
            inputs=[session_id_state, keys_valid_state, client_state, key_feedback],
            outputs=[client_state, key_feedback],
        )

        def update_model_options(provider):
            """Return dropdown updates (tool-call, response, VLLM) for *provider*.

            Tool calls always use an OpenAI model. Any provider other than
            "OpenAI" gets the Nebius-hosted response and VLLM models.
            """
            def openai_update():
                return gr.update(choices=list(openai_models), value=default_openai_model)

            if provider == "OpenAI":
                return openai_update(), openai_update(), openai_update()
            return (
                openai_update(),
                gr.update(choices=[nebius_response_model], value=nebius_response_model),
                gr.update(choices=[nebius_vllm_model], value=nebius_vllm_model),
            )

        provider_selector.change(
            fn=lambda provider, session_id: (*update_model_options(provider), *check_keys_and_toggle_inputs(session_id)),
            inputs=[provider_selector, session_id_state],
            outputs=[
                tool_call_selector,
                response_model_selector,
                vllm_model_selector,
                msg,
                send_btn,
                image_input,
                key_status,
            ],
        )

        def set_user_models(session_id, provider, tool_model, response_model, vllm_model):
            """Re-validate the stored keys, then save the provider and model choices.

            Returns:
                str: Markdown feedback for ``model_feedback``.
            """
            stored = session_keys.get(session_id, {})
            openai_key = stored.get("OPENAI_API_KEY")
            nebius_key = stored.get("NEBIUS_API_KEY")
            provider = (provider or "").strip().lower()

            if not openai_key:
                return "⚠️ <span style='color:orange'>Please enter a valid OpenAI API key before setting the model.</span>"
            try:
                LLM_Client(session_id, sourceAI="openai", api_key=openai_key)
            except Exception as e:
                logger.warning(f"❌ Invalid OpenAI key when setting model: {e}")
                return "⚠️ <span style='color:orange'>OpenAI API key is invalid.</span>"

            if provider == "nebius":
                if not nebius_key:
                    return "⚠️ <span style='color:orange'>Please enter a valid Nebius API key before setting the model.</span>"
                try:
                    LLM_Client(session_id, sourceAI="nebius", api_key=nebius_key)
                except Exception as e:
                    logger.warning(f"❌ Invalid Nebius key when setting model: {e}")
                    return "⚠️ <span style='color:orange'>Nebius API key is invalid.</span>"

            # An OpenAI key was found above, so the entry normally exists
            # already. setdefault just keeps this write safe either way.
            settings = session_keys.setdefault(session_id, {})
            settings["provider"] = provider.capitalize()
            settings["tool_call_model"] = tool_model
            settings["response_model"] = response_model
            settings["VLLM_model"] = vllm_model

            logger.info(
                f"✅ Models set for {session_id[:5]} | Provider: {provider}, Tool: {tool_model}, "
                f"Response: {response_model}, VLLM_model: {vllm_model}"
            )
            return f"✅ Models saved for {session_id[:5]}"

        set_model_btn.click(
            fn=set_user_models,
            inputs=[session_id_state, provider_selector, tool_call_selector, response_model_selector, vllm_model_selector],
            outputs=[model_feedback],
        ).then(
            fn=check_keys_and_toggle_inputs,
            inputs=[session_id_state],
            outputs=[msg, send_btn, image_input, key_status],
        )

    return demo
|
|
|
|
|
if __name__ == "__main__":
    import threading

    # The cleanup coroutine needs an event loop that actually runs. launch()
    # blocks the main thread and Gradio serves requests from its own loop, so
    # a task created on asyncio.get_event_loop() here would never execute.
    # Give the cleanup its own loop on a daemon thread instead.
    threading.Thread(
        target=asyncio.run,
        args=(cleanup_old_sessions(threshold_seconds=600),),
        name="session-cleanup",
        daemon=True,
    ).start()

    interface = gradio_interface()
    interface.launch(debug=False)
|
|
|