# Header captured from the Hugging Face file viewer:
# author OppaAI · latest commit aca2800 ("Update app.py", verified) · 8.25 kB
# app.py
import os
import base64
import json
import gradio as gr
from huggingface_hub import upload_file, InferenceClient
from datetime import datetime
import traceback
from typing import Optional, Dict, Any, Tuple
from fastmcp import FastMCP
# --- Configuration using Environment Variables ---
# Deployment settings are read from the environment, falling back to the
# defaults shown, so they can be changed without editing the code.

# Dataset repository that receives every uploaded robot image.
HF_DATASET_REPO = os.environ.get("HF_DATASET_REPO", "OppaAI/Robot_MCP")
# Vision-language model queried with each image.
HF_VLM_MODEL = os.environ.get("HF_VLM_MODEL", "Qwen/Qwen2.5-VL-7B-Instruct")

# The Hugging Face token is deliberately NOT read from the environment:
# every request must carry it as "hf_token" in its JSON payload.

# MCP server instance; the @mcp.tool() functions below register on it.
mcp = FastMCP("Robot_MCP")
# -----------------------------------------------------
# Register Robot Tools (MCP)
# -----------------------------------------------------
@mcp.tool()
def speak(text: str, emotion: str = "neutral"):
    """Makes the robot speak a given text with an emotion."""
    # The docstring above is kept verbatim: FastMCP presumably exposes it as
    # the tool description. No audio is produced here; the function only
    # returns a record echoing what should be said and how.
    utterance = {"text": text, "emotion": emotion}
    result = {"status": "success", "action_executed": "speak"}
    result["payload"] = utterance
    return result
@mcp.tool()
def navigate(direction: str, distance_meters: float):
    """Moves the robot a specified distance in a direction (max 5m)."""
    # Safety gate. The distance comes from model output, so treat it as
    # untrusted. A bare `distance_meters > 5.0` check lets NaN through (every
    # comparison with NaN is False) and also accepts arbitrarily large
    # negative values. The chained comparison rejects NaN, negatives and
    # anything above 5 m (including +/-inf).
    if not 0.0 <= distance_meters <= 5.0:
        return {"status": "error", "message": "Safety limit exceeded"}
    return {
        "status": "success",
        "action_executed": "navigate",
        "payload": {"direction": direction, "distance": distance_meters},
    }
@mcp.tool()
def scan_hazard(hazard_type: str, severity: str):
    """Logs a potential hazard detected by the robot."""
    # Despite the wording, nothing is written to a log sink here. A single
    # entry is formatted with a local, timezone-naive ISO-8601 timestamp and
    # returned to the caller.
    stamp = datetime.now().isoformat()
    entry = f"[{stamp}] HAZARD: {hazard_type} (Severity: {severity})"
    return {"status": "warning_logged", "log": entry}
@mcp.tool()
def analyze_human(clothing_color: str, estimated_action: str):
    """Tracks human activity based on visual input."""
    # Stateless: no tracking data is kept between calls. The two observations
    # are folded into one sentence and returned.
    sentence = f"Human wearing {clothing_color} is {estimated_action}"
    return dict(status="human_tracked", details=sentence)
# -----------------------------------------------------
# Save + Upload
# -----------------------------------------------------
def save_and_upload_image(image_b64: str, hf_token: str):
    """Write a base64-encoded image to /tmp and upload it to the HF dataset repo.

    Args:
        image_b64: Base64 text of the image. It is passed straight to
            ``base64.b64decode``; no ``data:`` URL prefix is removed.
        hf_token: Hugging Face access token used for the upload.

    Returns:
        ``(local_path, url, filename, size_bytes)`` on success, where ``url``
        points at the file under ``HF_DATASET_REPO``. If any step fails
        (decode, write, upload), the error is printed and
        ``(None, None, None, 0)`` is returned.
    """
    try:
        raw = base64.b64decode(image_b64)
        size_bytes = len(raw)
        os.makedirs("/tmp", exist_ok=True)
        # Microsecond resolution keeps names of back-to-back frames distinct.
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        # NOTE: the local copy is never deleted after the upload.
        local_path = f"/tmp/robot_img_{stamp}.jpg"
        with open(local_path, "wb") as handle:
            handle.write(raw)
        repo_filename = f"robot_{stamp}.jpg"
        upload_file(
            path_or_fileobj=local_path,
            path_in_repo=repo_filename,
            repo_id=HF_DATASET_REPO,
            token=hf_token,
            repo_type="dataset",
        )
    except Exception as exc:
        # Best-effort: report and signal failure through the sentinel tuple.
        print(f"Error during image upload: {exc}")
        traceback.print_exc()
        return None, None, None, 0
    else:
        public_url = f"https://huggingface.co/datasets/{HF_DATASET_REPO}/resolve/main/{repo_filename}"
        return local_path, public_url, repo_filename, size_bytes
# -----------------------------------------------------
# JSON Parse
# -----------------------------------------------------
def safe_parse_json_from_text(text: str) -> Optional[Dict[str, Any]]:
    """Extract a JSON object from (possibly messy) model output.

    Two attempts are made:
      1. parse the text as-is;
      2. strip surrounding backticks and a leading ``json`` tag (a Markdown
         code fence), then parse the span from the first ``{`` to the last ``}``.

    Args:
        text: Raw text, e.g. a VLM reply.

    Returns:
        The parsed dict, or None if no JSON *object* can be recovered. Valid
        JSON that is not an object (array, string, number, ...) is never
        returned, so callers can safely use ``.get`` on the result.
    """
    if not text:
        return None

    def _load_object(candidate: str) -> Optional[Dict[str, Any]]:
        # Decode the candidate and keep the result only if it is a JSON object.
        try:
            value = json.loads(candidate)
        except json.JSONDecodeError:
            return None
        return value if isinstance(value, dict) else None

    direct = _load_object(text)
    if direct is not None:
        return direct

    # Heuristic pass for fenced or chatty replies such as ```json {...} ```.
    cleaned = text.strip().strip("`").strip()
    if cleaned.lower().startswith("json"):
        cleaned = cleaned[4:].strip()
    start = cleaned.find("{")
    end = cleaned.rfind("}")
    if start >= 0 and end > start:
        return _load_object(cleaned[start : end + 1])
    return None
# -----------------------------------------------------
# Validate and Call Tool
# -----------------------------------------------------
def _lookup_tool_function(tool_name: str):
    """Return the callable registered on ``mcp`` under ``tool_name``, or None."""
    # NOTE(review): this reads private FastMCP state. The original code assumed
    # ``mcp._tools[name]["function"]``. FastMCP releases we know of instead
    # keep Tool objects (callable on ``.fn``) in ``mcp._tool_manager._tools``.
    # Both layouts are tried; confirm against the pinned fastmcp version.
    legacy = getattr(mcp, "_tools", None)
    if isinstance(legacy, dict):
        entry = legacy.get(tool_name)
        return entry.get("function") if isinstance(entry, dict) else None
    registry = getattr(getattr(mcp, "_tool_manager", None), "_tools", None)
    if isinstance(registry, dict) and tool_name in registry:
        return getattr(registry[tool_name], "fn", None)
    return None


def validate_and_call_tool(tool_name: str, tool_args: dict) -> Dict[str, Any]:
    """Execute the MCP tool chosen by the VLM.

    Only tools registered on ``mcp`` can be called, so the registry acts as
    an allow-list for model-generated names.

    Args:
        tool_name: Tool name taken from the model's JSON (may be any JSON value).
        tool_args: Keyword arguments for the tool (model-generated; may be
            None or a non-object JSON value).

    Returns:
        The tool's return value, or ``{"error": ...}`` when the name is
        unknown, the arguments are not a JSON object, or the tool raises
        (the traceback is printed).
    """
    # Only a string can name a tool. A list/dict would make the registry
    # lookup itself raise, and null means "no tool chosen".
    if not isinstance(tool_name, str):
        return {"error": f"Unknown or unauthorized tool '{tool_name}'"}
    tool_fn = _lookup_tool_function(tool_name)
    if tool_fn is None:
        return {"error": f"Unknown or unauthorized tool '{tool_name}'"}
    # Treat a missing "arguments" value as no arguments. Reject anything that
    # cannot be expanded as keyword arguments.
    if tool_args is None:
        tool_args = {}
    if not isinstance(tool_args, dict):
        return {"error": "Tool error: arguments must be a JSON object"}
    try:
        return tool_fn(**tool_args)
    except Exception as e:
        # Wrong or missing argument names from the model surface here as TypeError.
        traceback.print_exc()
        return {"error": f"Tool error: {str(e)}"}
# -----------------------------------------------------
# Main Pipeline
# -----------------------------------------------------
def _registered_tool_names() -> list:
    """Return the names of the tools registered on the module-level ``mcp``."""
    # NOTE(review): this reads private FastMCP state. The original code used
    # ``mcp._tools``. FastMCP releases we know of keep tools in
    # ``mcp._tool_manager._tools`` instead, so both layouts are tried.
    # Confirm against the pinned fastmcp version.
    registry = getattr(mcp, "_tools", None)
    if not isinstance(registry, dict):
        registry = getattr(getattr(mcp, "_tool_manager", None), "_tools", None)
    return list(registry) if isinstance(registry, dict) else []


def _build_system_prompt(tool_names: list) -> str:
    """Return the system prompt asking the VLM for a strict-JSON tool choice."""
    choices = " | ".join(tool_names)
    return (
        "\n"
        "Respond in STRICT JSON ONLY:\n"
        "{\n"
        '"description": "short visual description",\n'
        f'"tool_name": "{choices}",\n'
        '"arguments": { ... }\n'
        "}\n"
    )


def process_and_describe(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Run the robot-vision pipeline for one request.

    Steps:
      1. Validate the payload.
      2. Upload the image to ``HF_DATASET_REPO``.
      3. Ask ``HF_VLM_MODEL`` for a JSON tool decision.
      4. Execute the chosen tool.

    Args:
        payload: A dict, or a JSON string encoding one, with required keys
            ``hf_token`` and ``image_b64`` and optional ``robot_id``.

    Returns:
        - ``{"error": ...}`` for an invalid payload, missing fields, or a
          failed upload.
        - ``{"status": "error", ...}`` if the inference call raises.
        - ``{"status": "model_no_json", ...}`` if no JSON object can be
          parsed from the reply.
        - Otherwise ``{"status": "success", ...}``. The tool's own outcome,
          which may itself be an ``{"error": ...}`` dict, is under
          ``tool_execution_result``.
    """
    # gr.JSON input sometimes arrives as a raw JSON string; decode it first.
    if isinstance(payload, str):
        try:
            payload = json.loads(payload)
        except json.JSONDecodeError:
            return {"error": "Invalid JSON payload provided to the function"}
    # An empty input (None) or a JSON array/scalar has no fields to read;
    # without this guard ``payload.get`` below would raise AttributeError.
    if not isinstance(payload, dict):
        return {"error": "Payload must be a JSON object"}

    hf_token = payload.get("hf_token")
    if not hf_token:
        return {"error": "hf_token missing in payload. Cannot authenticate with HF Hub."}
    robot_id = payload.get("robot_id", "unknown")
    image_b64 = payload.get("image_b64")
    if not image_b64:
        return {"error": "image_b64 missing in payload"}

    # Archive the frame in the dataset repo before querying the model.
    _, hf_url, _, size_bytes = save_and_upload_image(image_b64, hf_token)
    if not hf_url:
        return {"error": "Image upload failed"}

    messages = [
        {"role": "system", "content": _build_system_prompt(_registered_tool_names())},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Analyze the image and choose ONE tool."},
                {
                    "type": "image_url",
                    # The JPEG MIME type is assumed here, not checked.
                    "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"},
                },
            ],
        },
    ]

    client = InferenceClient(token=hf_token)
    try:
        response = client.chat.completions.create(
            model=HF_VLM_MODEL,
            messages=messages,
            max_tokens=300,
            temperature=0.1,
        )
    except Exception as e:
        return {"status": "error", "message": f"Inference API call failed: {str(e)}"}

    # The reply may have no choices, or a None message content. Both would
    # raise here, so treat them as empty output and let the
    # "model_no_json" branch report it.
    content = response.choices[0].message.content if response.choices else None
    vlm_output = (content or "").strip()
    parsed = safe_parse_json_from_text(vlm_output)
    # Check for a dict rather than just None: a JSON array or scalar has no .get().
    if not isinstance(parsed, dict):
        return {
            "status": "model_no_json",
            "robot_id": robot_id,
            "image_url": hf_url,
            "vlm_raw": vlm_output,
            "message": "VLM returned invalid JSON format",
        }

    tool_name = parsed.get("tool_name")
    tool_args = parsed.get("arguments") or {}
    tool_result = validate_and_call_tool(tool_name, tool_args)
    return {
        "status": "success",
        "robot_id": robot_id,
        "image_url": hf_url,
        "file_size_bytes": size_bytes,
        "vlm_description": parsed.get("description"),
        "chosen_tool": tool_name,
        "tool_arguments": tool_args,
        "tool_execution_result": tool_result,
        "vlm_raw": vlm_output,
    }
# ------------------------------
# Gradio Interface
# ------------------------------
# Single JSON-in / JSON-out endpoint wrapping process_and_describe.
# api_name names the API endpoint ("predict"); flagging is disabled.
_request_box = gr.JSON(label="Input JSON Payload (must contain hf_token and image_b64)")
_response_box = gr.JSON(label="Output JSON Result")
iface = gr.Interface(
    fn=process_and_describe,
    inputs=_request_box,
    outputs=_response_box,
    api_name="predict",
    flagging_mode="never",
)
# ------------------------------
# Main Entry
# ------------------------------
if __name__ == "__main__":
    # Echo the effective configuration, then serve the app on all interfaces
    # at port 7860.
    for _setting, _value in (
        ("HF_DATASET_REPO", HF_DATASET_REPO),
        ("HF_VLM_MODEL", HF_VLM_MODEL),
    ):
        print(f"[Config] {_setting}: {_value}")
    print("[Gradio] Launching interface...")
    iface.launch(server_name="0.0.0.0", server_port=7860)