OppaAI commited on
Commit
aca2800
·
verified ·
1 Parent(s): dceeed5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -39
app.py CHANGED
@@ -6,14 +6,17 @@ import gradio as gr
6
  from huggingface_hub import upload_file, InferenceClient
7
  from datetime import datetime
8
  import traceback
9
- import threading
10
  from typing import Optional, Dict, Any, Tuple
11
 
12
  from fastmcp import FastMCP
13
 
14
-
15
- HF_DATASET_REPO = "OppaAI/Robot_MCP"
16
- HF_VLM_MODEL = "Qwen/Qwen2.5-VL-7B-Instruct"
 
 
 
 
17
 
18
  mcp = FastMCP("Robot_MCP")
19
 
@@ -22,6 +25,7 @@ mcp = FastMCP("Robot_MCP")
22
  # -----------------------------------------------------
23
  @mcp.tool()
24
  def speak(text: str, emotion: str = "neutral"):
 
25
  return {
26
  "status": "success",
27
  "action_executed": "speak",
@@ -31,6 +35,7 @@ def speak(text: str, emotion: str = "neutral"):
31
 
32
  @mcp.tool()
33
  def navigate(direction: str, distance_meters: float):
 
34
  if distance_meters > 5.0:
35
  return {"status": "error", "message": "Safety limit exceeded"}
36
  return {
@@ -42,6 +47,7 @@ def navigate(direction: str, distance_meters: float):
42
 
43
  @mcp.tool()
44
  def scan_hazard(hazard_type: str, severity: str):
 
45
  timestamp = datetime.now().isoformat()
46
  return {
47
  "status": "warning_logged",
@@ -51,6 +57,7 @@ def scan_hazard(hazard_type: str, severity: str):
51
 
52
  @mcp.tool()
53
  def analyze_human(clothing_color: str, estimated_action: str):
 
54
  return {
55
  "status": "human_tracked",
56
  "details": f"Human wearing {clothing_color} is {estimated_action}",
@@ -60,10 +67,13 @@ def analyze_human(clothing_color: str, estimated_action: str):
60
  # Save + Upload
61
  # -----------------------------------------------------
62
  def save_and_upload_image(image_b64: str, hf_token: str):
 
63
  try:
64
  image_bytes = base64.b64decode(image_b64)
65
  size_bytes = len(image_bytes)
66
 
 
 
67
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
68
  local_path = f"/tmp/robot_img_{timestamp}.jpg"
69
 
@@ -83,44 +93,47 @@ def save_and_upload_image(image_b64: str, hf_token: str):
83
  url = f"https://huggingface.co/datasets/{HF_DATASET_REPO}/resolve/main/{filename}"
84
  return local_path, url, filename, size_bytes
85
 
86
- except Exception:
 
87
  traceback.print_exc()
88
  return None, None, None, 0
89
 
90
  # -----------------------------------------------------
91
  # JSON Parse
92
  # -----------------------------------------------------
93
- def safe_parse_json_from_text(text: str):
 
94
  if not text:
95
  return None
96
  try:
97
  return json.loads(text)
98
- except:
99
- pass
 
 
 
 
 
100
 
101
- cleaned = text.strip().strip("`")
102
  try:
103
  start = cleaned.find("{")
104
  end = cleaned.rfind("}")
105
  if start >= 0 and end > start:
106
  return json.loads(cleaned[start : end + 1])
107
- except:
108
  pass
109
 
110
  return None
111
 
112
  # -----------------------------------------------------
113
- # FIXED: correct MCP tool registry access (v2)
114
  # -----------------------------------------------------
115
- def validate_and_call_tool(tool_name: str, tool_args: dict):
116
- # old: if tool_name not in mcp.tools:
117
- # ✔ new:
118
  if tool_name not in mcp._tools:
119
  return {"error": f"Unknown or unauthorized tool '{tool_name}'"}
120
 
121
  try:
122
- # ❌ old: mcp.tools[name](...)
123
- # ✔ new:
124
  tool_fn = mcp._tools[tool_name]["function"]
125
  return tool_fn(**tool_args)
126
 
@@ -131,39 +144,39 @@ def validate_and_call_tool(tool_name: str, tool_args: dict):
131
  # -----------------------------------------------------
132
  # Main Pipeline
133
  # -----------------------------------------------------
134
- def process_and_describe(payload):
135
-
 
 
136
  if isinstance(payload, str):
137
  try:
138
  payload = json.loads(payload)
139
- except:
140
- return {"error": "Invalid JSON payload"}
141
 
142
  hf_token = payload.get("hf_token")
143
  if not hf_token:
144
- return {"error": "hf_token missing"}
145
 
146
  robot_id = payload.get("robot_id", "unknown")
147
  image_b64 = payload.get("image_b64")
148
  if not image_b64:
149
- return {"error": "image_b64 missing"}
150
 
151
  # Save + Upload
152
- local_tmp_path, hf_url, filename, size_bytes = save_and_upload_image(
153
- image_b64, hf_token
154
- )
155
 
156
  if not hf_url:
157
  return {"error": "Image upload failed"}
158
 
159
  # VLM system prompt
160
- system_prompt = """
161
  Respond in STRICT JSON ONLY:
162
- {
163
  "description": "short visual description",
164
- "tool_name": "speak | navigate | scan_hazard | analyze_human",
165
- "arguments": { ... }
166
- }
167
  """
168
 
169
  messages = [
@@ -182,12 +195,16 @@ Respond in STRICT JSON ONLY:
182
 
183
  client = InferenceClient(token=hf_token)
184
 
185
- response = client.chat.completions.create(
186
- model=HF_VLM_MODEL,
187
- messages=messages,
188
- max_tokens=300,
189
- temperature=0.1,
190
- )
 
 
 
 
191
 
192
  vlm_output = response.choices[0].message.content.strip()
193
 
@@ -199,7 +216,7 @@ Respond in STRICT JSON ONLY:
199
  "robot_id": robot_id,
200
  "image_url": hf_url,
201
  "vlm_raw": vlm_output,
202
- "message": "VLM returned invalid JSON",
203
  }
204
 
205
  tool_name = parsed.get("tool_name")
@@ -224,8 +241,8 @@ Respond in STRICT JSON ONLY:
224
  # ------------------------------
225
  iface = gr.Interface(
226
  fn=process_and_describe,
227
- inputs=gr.JSON(label="Input JSON"),
228
- outputs=gr.JSON(label="Output JSON"),
229
  api_name="predict",
230
  flagging_mode="never"
231
  )
@@ -234,5 +251,8 @@ iface = gr.Interface(
234
  # Main Entry
235
  # ------------------------------
236
  if __name__ == "__main__":
 
 
237
  print("[Gradio] Launching interface...")
238
  iface.launch(server_name="0.0.0.0", server_port=7860)
 
 
6
  from huggingface_hub import upload_file, InferenceClient
7
  from datetime import datetime
8
  import traceback
 
9
  from typing import Optional, Dict, Any, Tuple
10
 
11
  from fastmcp import FastMCP
12
 
13
+ # --- Configuration using Environment Variables ---
14
+ # It is best practice to manage sensitive info outside of the code.
15
+ # Use os.environ.get() to safely retrieve these values.
16
+ HF_DATASET_REPO = os.environ.get("HF_DATASET_REPO", "OppaAI/Robot_MCP")
17
+ HF_VLM_MODEL = os.environ.get("HF_VLM_MODEL", "Qwen/Qwen2.5-VL-7B-Instruct")
18
+ # The token will be required in the payload, but we define the env var name here.
19
+ # HF_TOKEN_ENV_VAR_NAME = "HF_TOKEN"
20
 
21
  mcp = FastMCP("Robot_MCP")
22
 
 
25
  # -----------------------------------------------------
26
  @mcp.tool()
27
  def speak(text: str, emotion: str = "neutral"):
28
+ """Makes the robot speak a given text with an emotion."""
29
  return {
30
  "status": "success",
31
  "action_executed": "speak",
 
35
 
36
  @mcp.tool()
37
  def navigate(direction: str, distance_meters: float):
38
+ """Moves the robot a specified distance in a direction (max 5m)."""
39
  if distance_meters > 5.0:
40
  return {"status": "error", "message": "Safety limit exceeded"}
41
  return {
 
47
 
48
  @mcp.tool()
49
  def scan_hazard(hazard_type: str, severity: str):
50
+ """Logs a potential hazard detected by the robot."""
51
  timestamp = datetime.now().isoformat()
52
  return {
53
  "status": "warning_logged",
 
57
 
58
  @mcp.tool()
59
  def analyze_human(clothing_color: str, estimated_action: str):
60
+ """Tracks human activity based on visual input."""
61
  return {
62
  "status": "human_tracked",
63
  "details": f"Human wearing {clothing_color} is {estimated_action}",
 
67
  # Save + Upload
68
  # -----------------------------------------------------
69
  def save_and_upload_image(image_b64: str, hf_token: str):
70
+ """Decodes a base64 image, saves it locally, and uploads to Hugging Face Hub."""
71
  try:
72
  image_bytes = base64.b64decode(image_b64)
73
  size_bytes = len(image_bytes)
74
 
75
+ # Ensure the /tmp directory exists
76
+ os.makedirs("/tmp", exist_ok=True)
77
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
78
  local_path = f"/tmp/robot_img_{timestamp}.jpg"
79
 
 
93
  url = f"https://huggingface.co/datasets/{HF_DATASET_REPO}/resolve/main/{filename}"
94
  return local_path, url, filename, size_bytes
95
 
96
+ except Exception as e:
97
+ print(f"Error during image upload: {e}")
98
  traceback.print_exc()
99
  return None, None, None, 0
100
 
101
  # -----------------------------------------------------
102
  # JSON Parse
103
  # -----------------------------------------------------
104
+ def safe_parse_json_from_text(text: str) -> Optional[Dict[str, Any]]:
105
+ """Attempts to safely parse JSON from potentially messy text output."""
106
  if not text:
107
  return None
108
  try:
109
  return json.loads(text)
110
+ except json.JSONDecodeError:
111
+ pass # Try heuristic approach
112
+
113
+ cleaned = text.strip().strip("`").strip()
114
+ # Remove leading 'json' if present after stripping backticks
115
+ if cleaned.lower().startswith("json"):
116
+ cleaned = cleaned[4:].strip()
117
 
 
118
  try:
119
  start = cleaned.find("{")
120
  end = cleaned.rfind("}")
121
  if start >= 0 and end > start:
122
  return json.loads(cleaned[start : end + 1])
123
+ except json.JSONDecodeError:
124
  pass
125
 
126
  return None
127
 
128
  # -----------------------------------------------------
129
+ # Validate and Call Tool
130
  # -----------------------------------------------------
131
+ def validate_and_call_tool(tool_name: str, tool_args: dict) -> Dict[str, Any]:
132
+ """Validates tool access and executes the corresponding function."""
 
133
  if tool_name not in mcp._tools:
134
  return {"error": f"Unknown or unauthorized tool '{tool_name}'"}
135
 
136
  try:
 
 
137
  tool_fn = mcp._tools[tool_name]["function"]
138
  return tool_fn(**tool_args)
139
 
 
144
  # -----------------------------------------------------
145
  # Main Pipeline
146
  # -----------------------------------------------------
147
+ def process_and_describe(payload: Dict[str, Any]) -> Dict[str, Any]:
148
+ """Main pipeline function to process image, call VLM, and execute tool."""
149
+
150
+ # Input handling for gradio.JSON input which sometimes arrives as a string
151
  if isinstance(payload, str):
152
  try:
153
  payload = json.loads(payload)
154
+ except json.JSONDecodeError:
155
+ return {"error": "Invalid JSON payload provided to the function"}
156
 
157
  hf_token = payload.get("hf_token")
158
  if not hf_token:
159
+ return {"error": "hf_token missing in payload. Cannot authenticate with HF Hub."}
160
 
161
  robot_id = payload.get("robot_id", "unknown")
162
  image_b64 = payload.get("image_b64")
163
  if not image_b64:
164
+ return {"error": "image_b64 missing in payload"}
165
 
166
  # Save + Upload
167
+ _, hf_url, _, size_bytes = save_and_upload_image(image_b64, hf_token)
 
 
168
 
169
  if not hf_url:
170
  return {"error": "Image upload failed"}
171
 
172
  # VLM system prompt
173
+ system_prompt = f"""
174
  Respond in STRICT JSON ONLY:
175
+ {{
176
  "description": "short visual description",
177
+ "tool_name": "{' | '.join(mcp._tools.keys())}",
178
+ "arguments": {{ ... }}
179
+ }}
180
  """
181
 
182
  messages = [
 
195
 
196
  client = InferenceClient(token=hf_token)
197
 
198
+ try:
199
+ response = client.chat.completions.create(
200
+ model=HF_VLM_MODEL,
201
+ messages=messages,
202
+ max_tokens=300,
203
+ temperature=0.1,
204
+ )
205
+ except Exception as e:
206
+ return {"status": "error", "message": f"Inference API call failed: {str(e)}"}
207
+
208
 
209
  vlm_output = response.choices[0].message.content.strip()
210
 
 
216
  "robot_id": robot_id,
217
  "image_url": hf_url,
218
  "vlm_raw": vlm_output,
219
+ "message": "VLM returned invalid JSON format",
220
  }
221
 
222
  tool_name = parsed.get("tool_name")
 
241
  # ------------------------------
242
  iface = gr.Interface(
243
  fn=process_and_describe,
244
+ inputs=gr.JSON(label="Input JSON Payload (must contain hf_token and image_b64)"),
245
+ outputs=gr.JSON(label="Output JSON Result"),
246
  api_name="predict",
247
  flagging_mode="never"
248
  )
 
251
  # Main Entry
252
  # ------------------------------
253
  if __name__ == "__main__":
254
+ print(f"[Config] HF_DATASET_REPO: {HF_DATASET_REPO}")
255
+ print(f"[Config] HF_VLM_MODEL: {HF_VLM_MODEL}")
256
  print("[Gradio] Launching interface...")
257
  iface.launch(server_name="0.0.0.0", server_port=7860)
258
+