OppaAI's picture
Update app.py
9ecd335 verified
raw
history blame
7.59 kB
# app.py
import os
import base64
import json
import gradio as gr
from huggingface_hub import upload_file, InferenceClient
from datetime import datetime
import traceback
from typing import Optional, Dict, Any
from fastmcp import FastMCP
# --- Configuration ---
# Target dataset repo (image uploads) and VLM model id; both may be overridden
# through environment variables of the same name.
HF_DATASET_REPO = os.getenv("HF_DATASET_REPO", "OppaAI/Robot_MCP")
HF_VLM_MODEL = os.getenv("HF_VLM_MODEL", "Qwen/Qwen2.5-VL-7B-Instruct")

# MCP server instance; the robot tools below are registered on it via
# the ``@mcp.tool()`` decorator.
mcp = FastMCP("Robot_MCP")
# -----------------------------------------------------
# Register Robot Tools (MCP)
# -----------------------------------------------------
@mcp.tool()
def speak(text: str, emotion: str = "neutral"):
    """Makes the robot speak a given text with an emotion."""
    # NOTE: the docstring above is used as the MCP tool description, so it is
    # kept verbatim. No speech is produced here; the request is echoed back,
    # presumably for the robot client to act on.
    speech = dict(text=text, emotion=emotion)
    return dict(status="success", action_executed="speak", payload=speech)
@mcp.tool()
def navigate(direction: str, distance_meters: float):
    """Moves the robot a specified distance in a direction (max 5m)."""
    # NOTE: the docstring above is the MCP tool description; keep it stable.
    # Arguments usually originate from VLM output, so they are untrusted.
    max_distance_meters = 5.0  # hard safety cap, matches "(max 5m)" above
    if distance_meters > max_distance_meters:
        return {"status": "error", "message": "Safety limit exceeded"}
    # A negative distance would slip under the cap above, and NaN compares
    # False against everything (so it also passes the cap check). The
    # condition "not (d >= 0)" rejects both.
    if not distance_meters >= 0.0:
        return {"status": "error", "message": "Invalid distance"}
    return {"status": "success", "action_executed": "navigate", "payload": {"direction": direction, "distance": distance_meters}}
@mcp.tool()
def scan_hazard(hazard_type: str, severity: str):
    """Logs a potential hazard detected by the robot."""
    # NOTE: the docstring above is the MCP tool description; kept verbatim.
    # Nothing is written to a log file here: the formatted entry, stamped with
    # the local (naive) wall-clock time in ISO-8601 form, is returned instead.
    stamp = datetime.now().isoformat()
    entry = f"[{stamp}] HAZARD: {hazard_type} (Severity: {severity})"
    return dict(status="warning_logged", log=entry)
@mcp.tool()
def analyze_human(clothing_color: str, estimated_action: str):
    """Tracks human activity based on visual input."""
    # NOTE: the docstring above is the MCP tool description; kept verbatim.
    # No tracking state is kept; the observation is summarised in one line.
    summary = f"Human wearing {clothing_color} is {estimated_action}"
    return dict(status="human_tracked", details=summary)
# -----------------------------------------------------
# Save and upload image to HF
# -----------------------------------------------------
def save_and_upload_image(image_b64: str, hf_token: str):
    """Decode a base64 image, save it under /tmp and upload it to the HF dataset.

    The image is stored in the ``HF_DATASET_REPO`` dataset repo at
    ``tmp/robot_<timestamp>.jpg``. The ``.jpg`` extension is always used,
    regardless of the actual image format.

    Args:
        image_b64: Base64-encoded image bytes (no ``data:`` URI prefix).
        hf_token: Hugging Face token used for the upload.

    Returns:
        ``(local_path, url, filename, size_bytes)`` on success, where ``url``
        is the ``resolve/main`` URL of the uploaded file. On any failure
        (including an empty image) it returns ``(None, None, None, 0)``; the
        error is printed with a traceback rather than raised.
    """
    local_path = None
    try:
        image_bytes = base64.b64decode(image_b64)
        # b64decode silently discards non-alphabet characters, so junk input
        # can decode to nothing; don't upload an empty file.
        if not image_bytes:
            raise ValueError("decoded image is empty")
        size_bytes = len(image_bytes)
        # Ensure the /tmp directory exists
        os.makedirs("/tmp", exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        local_path = f"/tmp/robot_img_{timestamp}.jpg"
        with open(local_path, "wb") as f:
            f.write(image_bytes)
        filename = f"robot_{timestamp}.jpg"
        # One variable for the repo path, so the returned URL always points
        # to where the file was actually uploaded (previously the "tmp/"
        # prefix was missing from the URL).
        path_in_repo = f"tmp/{filename}"
        from huggingface_hub import HfApi
        HfApi().upload_file(
            path_or_fileobj=local_path,
            path_in_repo=path_in_repo,
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
            token=hf_token,
        )
        url = f"https://huggingface.co/datasets/{HF_DATASET_REPO}/resolve/main/{path_in_repo}"
        return local_path, url, filename, size_bytes
    except Exception as e:
        print(f"[Error] during image upload: {e}")
        traceback.print_exc()
        # The caller gets no path on failure and so can't clean up; remove the
        # orphaned temp file here (best-effort).
        if local_path is not None:
            try:
                os.remove(local_path)
            except OSError:
                pass
        return None, None, None, 0
# -----------------------------------------------------
# JSON parsing helper
# -----------------------------------------------------
def safe_parse_json_from_text(text: str) -> Optional[Dict[str, Any]]:
    """Extract a JSON object from possibly messy VLM output.

    Two attempts are made:

    1. Parse the raw text as JSON.
    2. Strip surrounding whitespace and backticks plus a leading ``json``
       language tag, then parse the span from the first ``{`` to the last ``}``.

    Args:
        text: Raw model output; may be empty or None.

    Returns:
        The decoded dict, or None when no JSON *object* can be recovered.
        Valid JSON that is not an object (list, number, string) yields None,
        because callers call ``.get`` on the result.
    """
    if not text:
        return None
    try:
        parsed = json.loads(text)
    except (TypeError, ValueError, RecursionError):  # JSONDecodeError is a ValueError
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    cleaned = text.strip().strip("`").strip()
    # Markdown fences often carry a language tag: ```json { ... } ```
    if cleaned.lower().startswith("json"):
        cleaned = cleaned[4:].strip()
    start = cleaned.find("{")
    end = cleaned.rfind("}")
    if start < 0 or end <= start:
        return None
    try:
        candidate = json.loads(cleaned[start:end + 1])
    except (ValueError, RecursionError):
        return None
    return candidate if isinstance(candidate, dict) else None
# -----------------------------------------------------
# Call MCP tool safely using public API
# -----------------------------------------------------
# Tool names the VLM may invoke: the tools registered with ``@mcp.tool()`` in
# this file, which are also the choices listed in the VLM system prompt.
_KNOWN_TOOLS = frozenset({"speak", "navigate", "scan_hazard", "analyze_human"})


def validate_and_call_tool(tool_name: str, tool_args: dict) -> Dict[str, Any]:
    """Validate a VLM-chosen tool call and dispatch it to a registered tool.

    ``tool_name`` and ``tool_args`` come straight from model output, so they
    are validated before anything is invoked. Only names in ``_KNOWN_TOOLS``
    are accepted, and the arguments must be a dict.

    Args:
        tool_name: Name of the tool selected by the VLM.
        tool_args: Keyword arguments for the tool.

    Returns:
        The tool's result, or ``{"error": ...}`` explaining why the call was
        rejected or failed. Ordinary exceptions are caught and reported here
        rather than raised.
    """
    import asyncio
    import inspect

    # Whitelist check before touching ``mcp``: an untrusted name must never
    # be able to select an arbitrary attribute or method of the server.
    if not isinstance(tool_name, str) or tool_name not in _KNOWN_TOOLS:
        return {"error": f"Unknown tool '{tool_name}'"}
    if not isinstance(tool_args, dict):
        return {"error": "Tool arguments must be a JSON object"}
    try:
        # NOTE(review): a public ``call_tool`` exists on some FastMCP versions
        # (the original comment cites v2.11.2); confirm its return type is
        # JSON-serialisable for the Gradio output.
        if hasattr(mcp, "call_tool"):
            result = mcp.call_tool(tool_name, tool_args)
        else:
            # The tools are module-level names, not attributes of ``mcp``.
            # Depending on the FastMCP version, the decorator may return a
            # wrapper object that exposes the original function as ``.fn``.
            tool = globals().get(tool_name)
            tool_fn = getattr(tool, "fn", tool)
            if not callable(tool_fn):
                return {"error": f"Unknown tool '{tool_name}'"}
            result = tool_fn(**tool_args)
        # If call_tool is async, the bare call only creates a coroutine, which
        # is neither the result nor serialisable; run it to completion.
        if inspect.iscoroutine(result):
            result = asyncio.run(result)
        return result
    except Exception as e:
        traceback.print_exc()
        return {"error": f"Tool execution error: {str(e)}"}
# -----------------------------------------------------
# Main pipeline: image → VLM → tool
# -----------------------------------------------------
# System prompt for the VLM. The argument names listed here must match the
# parameter names of the registered tools; otherwise the chosen call fails on
# unexpected keyword arguments.
_VLM_SYSTEM_PROMPT = """
Respond in STRICT JSON ONLY:
{
"description": "short visual description",
"tool_name": "speak | navigate | scan_hazard | analyze_human",
"arguments": { ... }
}
Allowed "arguments" per tool:
- speak: {"text": string, "emotion": string (optional)}
- navigate: {"direction": string, "distance_meters": number between 0 and 5}
- scan_hazard: {"hazard_type": string, "severity": string}
- analyze_human: {"clothing_color": string, "estimated_action": string}
"""


def process_and_describe(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Run the pipeline: upload the image, ask the VLM, execute the chosen tool.

    Args:
        payload: JSON object (or its string encoding) with keys
            ``hf_token`` (required), ``image_b64`` (required; base64 image
            data without a ``data:`` prefix, sent to the VLM as
            ``image/jpeg``) and optionally ``robot_id``.

    Returns:
        A JSON-serialisable dict:

        - input or upload failure: ``{"error": ...}`` only;
        - inference failure: ``status == "error"``;
        - VLM reply is not a JSON object: ``status == "model_no_json"``;
        - otherwise ``status == "success"``, with the chosen tool, its
          arguments and its result.
    """
    if isinstance(payload, str):
        try:
            payload = json.loads(payload)
        except json.JSONDecodeError:
            return {"error": "Invalid JSON payload"}
    # gr.JSON may deliver None, a list or a scalar; only an object is usable.
    if not isinstance(payload, dict):
        return {"error": "Invalid JSON payload"}

    hf_token = payload.get("hf_token")
    if not hf_token:
        return {"error": "hf_token missing"}
    robot_id = payload.get("robot_id", "unknown")
    image_b64 = payload.get("image_b64")
    if not image_b64:
        return {"error": "image_b64 missing"}

    # Save + upload; the URL doubles as the success flag.
    _, hf_url, _, size_bytes = save_and_upload_image(image_b64, hf_token)
    if not hf_url:
        return {"error": "Image upload failed"}

    messages = [
        {"role": "system", "content": _VLM_SYSTEM_PROMPT},
        {"role": "user", "content": [
            {"type": "text", "text": "Analyze the image and choose ONE tool."},
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}},
        ]},
    ]
    client = InferenceClient(token=hf_token)
    try:
        response = client.chat.completions.create(
            model=HF_VLM_MODEL,
            messages=messages,
            max_tokens=300,
            temperature=0.1,
        )
    except Exception as e:
        return {"status": "error", "message": f"Inference API call failed: {e}"}

    # ``content`` may be None and ``choices`` may be empty; treat either case
    # as a reply that contains no JSON instead of crashing.
    try:
        vlm_output = (response.choices[0].message.content or "").strip()
    except (AttributeError, IndexError, TypeError):
        vlm_output = ""

    parsed = safe_parse_json_from_text(vlm_output)
    if not isinstance(parsed, dict):
        return {"status": "model_no_json", "robot_id": robot_id, "image_url": hf_url, "vlm_raw": vlm_output, "message": "VLM returned invalid JSON"}

    tool_name = parsed.get("tool_name")
    tool_args = parsed.get("arguments") or {}
    tool_result = validate_and_call_tool(tool_name, tool_args)
    return {
        "status": "success",
        "robot_id": robot_id,
        "file_size_bytes": size_bytes,
        "vlm_description": parsed.get("description"),
        "chosen_tool": tool_name,
        "tool_arguments": tool_args,
        "tool_execution_result": tool_result,
        "vlm_raw": vlm_output,
    }
# ------------------------------
# Gradio Interface
# ------------------------------
# Single JSON-in / JSON-out endpoint, exposed to API clients as "predict".
# Components are created input-first, as in a plain inline call.
_payload_component = gr.JSON(label="Input JSON Payload (must include hf_token & image_b64)")
_result_component = gr.JSON(label="Output JSON Result")
iface = gr.Interface(
    fn=process_and_describe,
    inputs=_payload_component,
    outputs=_result_component,
    api_name="predict",
    flagging_mode="never",  # disable Gradio's flagging feature for this app
)
# ------------------------------
# Main Entry
# ------------------------------
if __name__ == "__main__":
    # Report the effective configuration before starting the server.
    print(f"[Config] HF_DATASET_REPO: {HF_DATASET_REPO}")
    print(f"[Config] HF_VLM_MODEL: {HF_VLM_MODEL}")
    print("[Gradio] Launching interface...")
    # Listen on all network interfaces, port 7860.
    launch_options = dict(server_name="0.0.0.0", server_port=7860)
    iface.launch(**launch_options)