Spaces:
Sleeping
Sleeping
"""Gradio MCP client app: connects to a local MCP stdio server and routes chat
plus image input through an LLM tool-calling pipeline, rendering results in
the UI."""
import asyncio
import base64
import json
import logging
import os
from contextlib import AsyncExitStack
from io import BytesIO
from typing import List, Dict, Any, Union

import gradio as gr
import numpy as np
from PIL import Image
from gradio.components.chatbot import ChatMessage
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

from config import session_keys
from custom_html_render import render_face_data_html
from llm_node import LLMNode
from utils import (
    save_uploaded_image,
    encode_image_to_base64,
    encrypt_session_id,
    JsonFileHandle,
    mark_session_active,
    save_encrypted_session_keys,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(f"🤩 {__name__}")

# Dedicated event loop: Gradio callbacks are synchronous, so the async MCP
# calls below are driven explicitly with loop.run_until_complete().
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
class MCPClientWrapper:
    """Bridges the Gradio UI to an MCP stdio server and an LLM tool-calling flow.

    Lifecycle: call connect() once to spawn the server subprocess and discover
    its tools, then process_message() per chat turn. All async MCP work is
    driven through the module-level event loop because Gradio callbacks are
    synchronous.
    """

    def __init__(self, session_id=None):
        # Session identity; connection state is created lazily in connect().
        self.session_id = session_id
        self.session = None      # active ClientSession once connected
        self.exit_stack = None   # AsyncExitStack owning transport + session
        self.tools = []          # OpenAI-style tool specs from list_tools()

    def connect(self) -> str:
        """Synchronously connect to the bundled MCP server script."""
        server_path = "gradio_mcp_server.py"
        return loop.run_until_complete(self._connect(server_path))

    async def _connect(self, server_path: str) -> str:
        """Spawn the MCP server subprocess and cache its tool catalog.

        Returns a human-readable status string listing the available tools.
        """
        # Tear down any previous connection before opening a new one.
        if self.exit_stack:
            await self.exit_stack.aclose()
        self.exit_stack = AsyncExitStack()

        # Pick the interpreter from the file extension (.py -> python, else node).
        command = "python" if server_path.endswith('.py') else "node"
        server_params = StdioServerParameters(
            command=command,
            args=[server_path],
            env={"PYTHONIOENCODING": "utf-8", "PYTHONUNBUFFERED": "1"}
        )

        stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
        self.stdio, self.write = stdio_transport
        self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
        await self.session.initialize()

        # Re-publish the server's tools in OpenAI function-calling format.
        response = await self.session.list_tools()
        self.tools = [{
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.inputSchema
            }
        } for tool in response.tools]
        tool_names = [tool["function"]["name"] for tool in self.tools]
        return f"Connected to MCP server. Available tools: {', '.join(tool_names)}"

    def process_message(self, session_id, message: str,
                        history: List[Union[Dict[str, Any], "ChatMessage"]],
                        image_input=None) -> tuple:
        """Handle one chat turn: optional image ingest, LLM/tool round-trip,
        and persistence of newly discovered face/color-season data.

        Returns (chat_history, cleared_textbox, img1..img5, face_html) — eight
        values, matching the Gradio output bindings of the success path.
        """
        if not self.session:
            # FIX: this early return previously packed the gallery as a single
            # 5-element list (4 values total) while the success path splats
            # five separate values (8 total) — the output arity must match.
            return (history + [
                {"role": "user", "content": message},
                {"role": "assistant", "content": "Please connect to an MCP server first."}
            ], gr.Textbox(value=""), None, None, None, None, None, render_face_data_html({}))

        mark_session_active(session_id)  # refresh the session's active marker
        self.llm_node = LLMNode(session_id=session_id)

        image_base64 = None
        image = None
        user_data_path = f"tmp/{session_id}"
        save_encrypted_session_keys(session_id, session_keys[session_id])

        # Normalize whatever Gradio handed us (path / array / PIL) to a PIL image,
        # then persist it and encode it for the vision prompt. Best-effort: an
        # image failure must not kill the chat turn.
        if image_input is not None:
            try:
                if isinstance(image_input, str):
                    image = Image.open(image_input)
                elif isinstance(image_input, np.ndarray):
                    image = Image.fromarray(image_input.astype("uint8"))
                elif isinstance(image_input, Image.Image):
                    image = image_input
                if image:
                    save_uploaded_image("face", image, user_data_path)
                    image_base64 = encode_image_to_base64(image)
                    logger.info("✅ Input image saved and converted to base64")
            except Exception as e:
                logger.error(f"❌ Failed to handle input image: {e}")

        user_data = JsonFileHandle.load_json_data("user_data", user_data_path)

        # Send to LLM (tool-call step + generation step) on the module loop.
        new_messages, image_url, new_face_data, new_color_season, product_images = loop.run_until_complete(
            self._process_query(
                message=message,
                history=history,
                image_base64=image_base64,
                face_data=user_data.get("FaceData") or None,
                color_season=user_data.get("ColorSeason") or None,
                encryptId=encrypt_session_id(session_id)
            )
        )

        # Persist face/color-season data only the first time it is discovered.
        data_update = False
        if user_data.get("FaceData") is None and new_face_data:
            user_data["FaceData"] = new_face_data
            data_update = True
        if user_data.get("ColorSeason") is None and new_color_season:
            user_data["ColorSeason"] = new_color_season
            data_update = True
        if data_update:
            JsonFileHandle.save_json_data("user_data", user_data, user_data_path)

        # FIX: guard against None so the renderer always receives a dict,
        # consistent with the not-connected path above.
        html_display = render_face_data_html(new_face_data or {})

        # Pad/trim the gallery to exactly five slots.
        product_images = (product_images + [None] * 5)[:5]

        return (history + [{"role": "user", "content": message}] + new_messages,
                gr.Textbox(value=""), *product_images, html_display)

    async def _process_query(
        self,
        message: str,
        history: List[Union[Dict[str, Any], "ChatMessage"]],
        image_base64: str = None,
        face_data: dict = None,
        color_season: dict = None,
        encryptId: str = None
    ):
        """Two-step pipeline: (1) ask the model which MCP tool to call,
        (2) execute the tool and generate the final reply from its result.

        Returns (messages, image_url, face_data, color_season, product_images).
        """
        logger.info(f"Image Exist: {bool(image_base64)}")

        # Step 1: tool-suggestion prompt; the image variant also forwards the
        # encrypted session id so server-side tools can locate the upload.
        if image_base64:
            messages = self.llm_node.build_prompt(
                history, message, image_base64=image_base64, vision_enabled=True,
                type="toolcall", encryptId=encryptId,
                history_len=10, face_data=face_data, color_season=color_season)
        else:
            messages = self.llm_node.build_prompt(
                history, message, vision_enabled=True,
                type="toolcall", history_len=10,
                face_data=face_data, color_season=color_season)

        step1 = self.llm_node.call_tool_step(messages, self.tools)
        choice = step1["choices"][0]
        tool_calls = choice["message"].get("tool_calls", [])
        logger.info(f"Tool Called: {tool_calls}")

        image_url = None
        result_messages = []
        product_images = []

        # No tool requested: the step-1 answer is already the final answer.
        if not tool_calls:
            result_messages.append({
                "role": "assistant",
                "content": choice["message"]["content"]
            })
            return result_messages, image_url, face_data, color_season, product_images

        tool = tool_calls[0]  # only one tool call is honored per turn for now
        tool_name = tool["function"]["name"]
        try:
            tool_args = json.loads(tool["function"]["arguments"])
        except Exception:
            tool_args = {}

        # 🛠️ Call the actual tool via MCP.
        result = await self.session.call_tool(tool_name, tool_args)
        result_content = result.content
        logger.info(f"Tool Called result: {result_content}")
        if isinstance(result_content, list):
            # MCP returns a list of content parts; join their text payloads.
            result_content = "\n".join(str(item.text) for item in result_content)

        result_json = None
        try:
            if isinstance(result_content, dict):
                result_json = result_content
            else:
                result_json = json.loads(result_content)
        except Exception:
            # Unparseable tool output: fall back to plain generation without it.
            result_messages.extend(self.llm_node.call_generation_step(
                message=message,
                history=history,
                face_data=face_data,
                color_season=color_season
            ))

        # Dispatch on the tool result's declared type. The branches are
        # mutually exclusive, so a single elif chain is equivalent to the
        # original sequence of independent ifs.
        if result_json and isinstance(result_json, dict) and "type" in result_json:
            kind = result_json["type"]
            if kind == "product_list":
                products = result_json["products"]
                logger.info(f"Products: {products}")
                result_content = "\n\n".join(
                    f"{row['name']} ({row['color']}, {row['season']} Wear, {row['usage']}), {row['image_url']}"
                    for row in products
                )
                product_images.extend(
                    row["image_url"] for row in products
                    if isinstance(row, dict) and "image_url" in row
                )
                result_messages.extend(self.llm_node.call_generation_step(
                    message=message, history=history, tool_result=result_content,
                    face_data=face_data, color_season=color_season
                ))
            elif kind == "FaceData":
                face_data = result_json["FaceData"]
                color_season = result_json["ColorSeason"]
                result_messages.extend(self.llm_node.call_generation_step(
                    message=message, history=history, tool_result=result_content,
                    face_data=face_data, color_season=color_season
                ))
            elif kind == "image":
                # NOTE(review): source formatting was ambiguous on whether the
                # assistant note depends on "url" being present; nested here
                # since the note only makes sense with an image — confirm.
                if "url" in result_json:
                    image_url = result_json["url"]
                    result_messages.append({
                        "role": "assistant",
                        "content": "You can view and download it on the right.",
                    })
            elif kind == "text":
                # 🧠 Run the second step to finalize the response from raw text.
                result_messages.extend(self.llm_node.call_generation_step(
                    message=message, history=history, tool_result=result_content,
                    face_data=face_data, color_season=color_season
                ))

        return result_messages, image_url, face_data, color_season, product_images