"""Robot MCP server: upload a robot camera image and describe it with a VLM.

Exposes the pipeline both as a FastMCP tool (``robot_watch``) and as a Gradio
JSON endpoint (``process_json`` behind ``app``).
"""
import base64
import json
import os
import traceback
from datetime import datetime
from typing import Any, Dict, Optional

import gradio as gr
from fastmcp import FastMCP
from huggingface_hub import HfApi, InferenceClient

# Dataset repo that receives uploaded robot images (overridable via env var).
HF_DATASET_REPO = os.environ.get("HF_DATASET_REPO", "OppaAI/Robot_MCP")
# Vision-language model used to describe each image (overridable via env var).
HF_VLM_MODEL = os.environ.get("HF_VLM_MODEL", "Qwen/Qwen2.5-VL-7B-Instruct")

# NOTE(review): `robot_watch` is registered on this FastMCP instance, but the
# instance is never run in this file; the MCP endpoint actually served comes
# from `app.launch(mcp_server=True)`. Confirm whether `mcp` is used elsewhere.
mcp = FastMCP("Robot_MCP_Server")

# Instruction sent to the VLM. The template deliberately contains no `//`
# comments: models tend to echo the template verbatim, and comments are not
# valid JSON, so `json.loads` would reject the reply.
_SYSTEM_PROMPT = """
Respond in STRICT JSON ONLY, with no comments and no text outside the JSON object.
{
  "description": "...",
  "human": "...",
  "environment": "...",
  "objects": []
}
"objects" must be a JSON list of the objects detected in the image.
"""


# -------------------------------
# Upload helper
# -------------------------------
def upload_image(image_b64: str, hf_token: str):
    """Decode a base64 image, save a local copy and upload it to HF_DATASET_REPO.

    Args:
        image_b64: Base64-encoded image bytes (no ``data:`` URI header).
        hf_token: Hugging Face token used for the upload.

    Returns:
        ``(local_path, url, filename, size_bytes)`` on success, where ``url``
        is the ``resolve/main`` download URL of the uploaded file; otherwise
        ``(None, None, None, 0)``. Failures are best-effort: the traceback is
        printed and nothing is raised.
    """
    try:
        image_bytes = base64.b64decode(image_b64)
        size_bytes = len(image_bytes)

        # Microsecond resolution makes name collisions between calls unlikely.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")

        os.makedirs("/tmp", exist_ok=True)
        local_path = f"/tmp/robot_img_{timestamp}.jpg"
        with open(local_path, "wb") as f:
            f.write(image_bytes)
        # NOTE(review): this local copy is never deleted, so a long-running
        # server accumulates one file per request in /tmp. Callers in this
        # module ignore `local_path`; consider uploading `image_bytes` directly.

        filename = f"robot_{timestamp}.jpg"
        HfApi().upload_file(
            path_or_fileobj=local_path,
            path_in_repo=f"tmp/{filename}",
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
            token=hf_token,
        )
        url = (
            f"https://huggingface.co/datasets/{HF_DATASET_REPO}"
            f"/resolve/main/tmp/{filename}"
        )
        return local_path, url, filename, size_bytes
    except Exception:
        # Deliberate best-effort boundary: decode, disk and network errors all
        # collapse into the "failed" tuple that the caller checks for.
        traceback.print_exc()
        return None, None, None, 0


# -------------------------------
# Safe JSON parse
# -------------------------------
def safe_parse_json_from_text(text: str):
    """Parse JSON from model output, tolerating Markdown code fences.

    First tries to parse ``text`` as-is. Otherwise it strips surrounding
    backticks and a leading ``json`` language tag, then parses the span from
    the first ``{`` to the last ``}``.

    Returns:
        The decoded value (any JSON type on the first attempt, an object on
        the fallback path), or None if nothing parseable is found.
    """
    if not text:
        return None
    try:
        return json.loads(text)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        pass

    cleaned = text.strip().strip("`").strip()
    if cleaned.lower().startswith("json"):
        cleaned = cleaned[4:].strip()

    start = cleaned.find("{")
    end = cleaned.rfind("}")
    if start == -1 or end < start:
        return None
    try:
        return json.loads(cleaned[start:end + 1])
    except ValueError:
        return None


def _strip_data_uri(image_b64):
    """Return the base64 payload of a ``data:...;base64,`` URI; pass others through.

    Raw base64 never contains ``:``, so a ``data:`` prefix unambiguously marks
    a URI header that ``b64decode`` would otherwise silently mangle.
    """
    if isinstance(image_b64, str) and image_b64.startswith("data:"):
        return image_b64.partition(",")[2]
    return image_b64


# -------------------------------
# TRUE CORE FUNCTION (with objects)
# -------------------------------
def robot_watch_core(payload: Dict[str, Any]):
    """Upload a robot camera image and describe it with the configured VLM.

    Args:
        payload: A dict, or a JSON string encoding one, with keys
            ``hf_token`` (required), ``image_b64`` (required; raw base64 or a
            ``data:...;base64,`` URI) and ``robot_id`` (optional, defaults to
            ``"unknown"``).

    Returns:
        A JSON-serialisable dict:
        - input or upload problems: ``{"error": <message>}``;
        - VLM call failure: ``{"status": "error", "message": <str>}``;
        - reply that is not a JSON object: ``{"status": "model_no_json", ...}``;
        - otherwise ``{"status": "success", ...}`` with the image URL, the
          parsed description/human/environment fields, an ``objects`` list
          and the raw model reply.
    """
    if isinstance(payload, str):
        try:
            payload = json.loads(payload)
        except ValueError:
            return {"error": "Invalid JSON payload"}
    # Anything other than a dict (e.g. null or a JSON array) cannot be read below.
    if not isinstance(payload, dict):
        return {"error": "Invalid JSON payload"}

    hf_token = payload.get("hf_token")
    if not hf_token:
        return {"error": "hf_token missing"}

    robot_id = payload.get("robot_id", "unknown")
    image_b64 = _strip_data_uri(payload.get("image_b64"))
    if not image_b64:
        return {"error": "image_b64 missing"}

    # Upload
    _, hf_url, _, size_bytes = upload_image(image_b64, hf_token)
    if not hf_url:
        return {"error": "Image upload failed"}

    # VLM
    messages = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Analyze the image."},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"},
                },
            ],
        },
    ]

    client = InferenceClient(token=hf_token)
    try:
        resp = client.chat.completions.create(
            model=HF_VLM_MODEL,
            messages=messages,
            max_tokens=500,
            temperature=0.1,
        )
    except Exception as e:
        # API/network failures are reported to the caller rather than raised.
        return {"status": "error", "message": str(e)}

    # The message content is typed Optional and `choices` may be empty;
    # normalise both to "" so they fall through to the model_no_json branch.
    content = resp.choices[0].message.content if resp.choices else None
    vlm_output = (content or "").strip()

    parsed = safe_parse_json_from_text(vlm_output)
    if parsed is None:
        return {
            "status": "model_no_json",
            "vlm_raw": vlm_output,
            "message": "Invalid JSON returned",
        }
    if not isinstance(parsed, dict):
        # Valid JSON but not an object (e.g. a bare list): fields can't be read.
        return {
            "status": "model_no_json",
            "vlm_raw": vlm_output,
            "message": "Expected a JSON object",
        }

    # Ensure "objects" is a list
    objects = parsed.get("objects", [])
    if not isinstance(objects, list):
        objects = []

    return {
        "status": "success",
        "robot_id": robot_id,
        "file_size_bytes": size_bytes,
        "image_url": hf_url,
        "description": parsed.get("description"),
        "human": parsed.get("human"),
        "environment": parsed.get("environment"),
        "objects": objects,
        "vlm_raw": vlm_output,
    }


# -------------------------------
# REGISTER MCP TOOL (wrapper)
# -------------------------------
@mcp.tool()
def robot_watch(payload: Dict[str, Any]):
    """Analyze a robot camera image: upload it and describe it with a VLM.

    ``payload`` must contain ``hf_token`` and ``image_b64`` (a base64 image,
    labelled as JPEG) and may contain ``robot_id``. On success the result
    carries ``description``, ``human``, ``environment``, ``objects`` and the
    uploaded ``image_url``.
    """
    return robot_watch_core(payload)


# -------------------------------
# Gradio wrapper
# -------------------------------
def process_json(payload):
    """Gradio entry point: forward the JSON input unchanged to robot_watch_core."""
    return robot_watch_core(payload)


app = gr.Interface(
    fn=process_json,
    inputs=gr.JSON(label="Input JSON"),
    outputs=gr.JSON(label="Result JSON"),
    title="Robot MCP Server",
    description="JSON endpoint for robot vision pipeline.",
    api_name="predict",
)

if __name__ == "__main__":
    app.launch(mcp_server=True)