SalexAI commited on
Commit
8033e09
·
verified ·
1 Parent(s): cd032a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +272 -79
app.py CHANGED
@@ -1,16 +1,42 @@
1
- from fastapi import FastAPI, HTTPException, Header
2
- from pydantic import BaseModel, Field
3
- from typing import Optional, Dict, Any, List
4
  from uuid import uuid4
5
  from datetime import datetime, timezone
6
- import asyncio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- app = FastAPI(title="Create3 Robot Command Bridge", version="1.0.0")
9
 
10
- # ====== Simple shared secret auth (optional but recommended) ======
11
- API_TOKEN = "change-me-please" # set this to your own token
12
 
13
- def check_token(auth_header: Optional[str]):
 
14
  if not API_TOKEN:
15
  return
16
  if not auth_header or not auth_header.startswith("Bearer "):
@@ -19,30 +45,6 @@ def check_token(auth_header: Optional[str]):
19
  if token != API_TOKEN:
20
  raise HTTPException(status_code=403, detail="Invalid token")
21
 
22
- # ====== In-memory command store ======
23
- # For HF Space demo/testing this is OK.
24
- # If you restart the app, queue clears.
25
- commands: Dict[str, Dict[str, Any]] = {}
26
- queue: List[str] = []
27
- queue_lock = asyncio.Lock()
28
-
29
- # ====== Models ======
30
- class CommandCreate(BaseModel):
31
- command: str = Field(..., description="Robot command name")
32
- args: Dict[str, Any] = Field(default_factory=dict, description="Command arguments")
33
- source: Optional[str] = Field(default="ai", description="Who sent the command")
34
- priority: int = Field(default=0, description="Higher = earlier in queue")
35
- ttl_seconds: Optional[int] = Field(default=120, description="Optional expiration")
36
-
37
- class CommandStatusUpdate(BaseModel):
38
- status: str = Field(..., description="queued | running | done | failed | expired")
39
- result: Optional[Dict[str, Any]] = None
40
- error: Optional[str] = None
41
- robot_id: Optional[str] = None
42
-
43
- # ====== Helpers ======
44
- def now_iso() -> str:
45
- return datetime.now(timezone.utc).isoformat()
46
 
47
  def command_expired(cmd: Dict[str, Any]) -> bool:
48
  ttl = cmd.get("ttl_seconds")
@@ -52,74 +54,270 @@ def command_expired(cmd: Dict[str, Any]) -> bool:
52
  age = (datetime.now(timezone.utc) - created).total_seconds()
53
  return age > ttl
54
 
55
- # ====== API ======
56
-
57
- @app.get("/")
58
- async def root():
59
- return {"ok": True, "service": "create3-command-bridge"}
60
-
61
- @app.post("/commands")
62
- async def create_command(cmd: CommandCreate, authorization: Optional[str] = Header(default=None)):
63
- check_token(authorization)
64
 
 
 
 
 
 
 
 
 
65
  cmd_id = str(uuid4())
66
  record = {
67
  "id": cmd_id,
68
- "command": cmd.command,
69
- "args": cmd.args,
70
- "source": cmd.source,
71
- "priority": cmd.priority,
72
- "ttl_seconds": cmd.ttl_seconds,
73
  "status": "queued",
74
  "result": None,
75
  "error": None,
76
  "created_at": now_iso(),
77
  "updated_at": now_iso(),
78
  "claimed_by": None,
 
79
  }
80
 
81
  async with queue_lock:
82
  commands[cmd_id] = record
83
 
84
- # Insert by priority (higher first)
85
  inserted = False
86
  for i, queued_id in enumerate(queue):
87
- if record["priority"] > commands[queued_id]["priority"]:
 
 
 
88
  queue.insert(i, cmd_id)
89
  inserted = True
90
  break
91
  if not inserted:
92
  queue.append(cmd_id)
93
 
94
- return {"ok": True, "command_id": cmd_id, "status": "queued"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  @app.get("/commands")
97
  async def list_commands(limit: int = 20, authorization: Optional[str] = Header(default=None)):
98
- check_token(authorization)
99
- # latest updated first
100
- items = sorted(commands.values(), key=lambda x: x["updated_at"], reverse=True)[:limit]
101
  return {"ok": True, "items": items}
102
 
 
103
  @app.get("/commands/{command_id}")
104
- async def get_command(command_id: str, authorization: Optional[str] = Header(default=None)):
105
- check_token(authorization)
106
  if command_id not in commands:
107
  raise HTTPException(status_code=404, detail="Command not found")
108
  return {"ok": True, "item": commands[command_id]}
109
 
 
110
  @app.post("/commands/next")
111
- async def claim_next_command(
112
- robot_id: str,
113
- authorization: Optional[str] = Header(default=None),
114
- ):
115
  """
116
- Client polls this endpoint to claim the next command.
117
- Returns one command and marks it running.
118
  """
119
- check_token(authorization)
120
 
121
  async with queue_lock:
122
- # Clean expired queued commands while scanning
123
  while queue:
124
  cmd_id = queue.pop(0)
125
  cmd = commands.get(cmd_id)
@@ -141,58 +339,53 @@ async def claim_next_command(
141
 
142
  return {"ok": True, "item": None}
143
 
 
144
  @app.post("/commands/{command_id}/status")
145
  async def update_command_status(
146
  command_id: str,
147
  update: CommandStatusUpdate,
148
  authorization: Optional[str] = Header(default=None),
149
  ):
150
- check_token(authorization)
151
 
152
  cmd = commands.get(command_id)
153
  if not cmd:
154
  raise HTTPException(status_code=404, detail="Command not found")
155
 
156
- # Basic state update
157
  cmd["status"] = update.status
158
  cmd["result"] = update.result
159
  cmd["error"] = update.error
160
  if update.robot_id:
161
  cmd["claimed_by"] = update.robot_id
162
  cmd["updated_at"] = now_iso()
163
-
164
  return {"ok": True}
165
 
 
166
  @app.post("/commands/{command_id}/cancel")
167
  async def cancel_command(command_id: str, authorization: Optional[str] = Header(default=None)):
168
- check_token(authorization)
169
 
170
  cmd = commands.get(command_id)
171
  if not cmd:
172
  raise HTTPException(status_code=404, detail="Command not found")
173
 
174
  if cmd["status"] in ("done", "failed", "expired"):
175
- return {"ok": True, "status": cmd["status"], "message": "Already finished"}
176
 
177
  cmd["status"] = "failed"
178
  cmd["error"] = "Cancelled by operator"
179
  cmd["updated_at"] = now_iso()
180
 
181
- # Remove from queue if still queued
182
  async with queue_lock:
183
  if command_id in queue:
184
  queue.remove(command_id)
185
 
186
  return {"ok": True, "status": "failed"}
187
 
188
- @app.get("/health")
189
- async def health():
190
- queued = sum(1 for c in commands.values() if c["status"] == "queued")
191
- running = sum(1 for c in commands.values() if c["status"] == "running")
192
- return {
193
- "ok": True,
194
- "queued": queued,
195
- "running": running,
196
- "total": len(commands),
197
- "time": now_iso(),
198
- }
 
1
+ import os
2
+ import asyncio
 
3
  from uuid import uuid4
4
  from datetime import datetime, timezone
5
+ from typing import Optional, Dict, Any, List
6
+
7
+ from fastapi import FastAPI, HTTPException, Header, Request
8
+ from fastapi.responses import JSONResponse
9
+ from pydantic import BaseModel, Field
10
+ from starlette.routing import Mount
11
+
12
+ # MCP (official Python SDK)
13
+ from mcp.server.fastmcp import FastMCP, Context
14
+
15
+ # ============================================================
16
+ # CONFIG
17
+ # ============================================================
18
+ API_TOKEN = os.getenv("ROBOT_API_TOKEN", "change-me-please")
19
+ ALLOWED_ORIGINS = {
20
+ # Add your domains here (HF space URL, localhost during testing, etc.)
21
+ "https://your-space-name.hf.space",
22
+ "http://localhost:7860",
23
+ "http://127.0.0.1:7860",
24
+ }
25
+
26
+ # ============================================================
27
+ # IN-MEMORY COMMAND STORE
28
+ # (Good for MVP / HF Space demos. Use Redis later if you want persistence.)
29
+ # ============================================================
30
+ commands: Dict[str, Dict[str, Any]] = {}
31
+ queue: List[str] = []
32
+ queue_lock = asyncio.Lock()
33
 
 
34
 
35
+ def now_iso() -> str:
36
+ return datetime.now(timezone.utc).isoformat()
37
 
38
+
39
+ def check_bearer(auth_header: Optional[str]):
40
  if not API_TOKEN:
41
  return
42
  if not auth_header or not auth_header.startswith("Bearer "):
 
45
  if token != API_TOKEN:
46
  raise HTTPException(status_code=403, detail="Invalid token")
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  def command_expired(cmd: Dict[str, Any]) -> bool:
50
  ttl = cmd.get("ttl_seconds")
 
54
  age = (datetime.now(timezone.utc) - created).total_seconds()
55
  return age > ttl
56
 
 
 
 
 
 
 
 
 
 
57
 
58
+ async def enqueue_command(
59
+ command: str,
60
+ args: Optional[Dict[str, Any]] = None,
61
+ source: str = "mcp",
62
+ priority: int = 0,
63
+ ttl_seconds: int = 120,
64
+ meta: Optional[Dict[str, Any]] = None,
65
+ ) -> Dict[str, Any]:
66
  cmd_id = str(uuid4())
67
  record = {
68
  "id": cmd_id,
69
+ "command": command,
70
+ "args": args or {},
71
+ "source": source,
72
+ "priority": priority,
73
+ "ttl_seconds": ttl_seconds,
74
  "status": "queued",
75
  "result": None,
76
  "error": None,
77
  "created_at": now_iso(),
78
  "updated_at": now_iso(),
79
  "claimed_by": None,
80
+ "meta": meta or {},
81
  }
82
 
83
  async with queue_lock:
84
  commands[cmd_id] = record
85
 
86
+ # priority insert (higher first)
87
  inserted = False
88
  for i, queued_id in enumerate(queue):
89
+ q = commands.get(queued_id)
90
+ if not q:
91
+ continue
92
+ if record["priority"] > q.get("priority", 0):
93
  queue.insert(i, cmd_id)
94
  inserted = True
95
  break
96
  if not inserted:
97
  queue.append(cmd_id)
98
 
99
+ return record
100
+
101
+
102
+ # ============================================================
103
+ # FASTAPI MODELS (robot runner REST side)
104
+ # ============================================================
105
+ class CommandStatusUpdate(BaseModel):
106
+ status: str = Field(..., description="queued | running | done | failed | expired")
107
+ result: Optional[Dict[str, Any]] = None
108
+ error: Optional[str] = None
109
+ robot_id: Optional[str] = None
110
+
111
+
112
+ class CommandCreateREST(BaseModel):
113
+ command: str
114
+ args: Dict[str, Any] = Field(default_factory=dict)
115
+ source: str = "api"
116
+ priority: int = 0
117
+ ttl_seconds: int = 120
118
+
119
+
120
+ # ============================================================
121
+ # MCP SERVER (tools for Claude/ChatGPT/etc)
122
+ # ============================================================
123
+ mcp = FastMCP(
124
+ "Create3 Robot Bridge",
125
+ instructions=(
126
+ "Send safe, short Create 3 robot commands by enqueueing them for a local robot runner. "
127
+ "This server does not execute commands directly; it queues them for a trusted client."
128
+ ),
129
+ stateless_http=True,
130
+ json_response=True,
131
+ )
132
+
133
+ # Optional: simple helper tool
134
+ @mcp.tool()
135
+ async def robot_status_summary(ctx: Context) -> dict:
136
+ """Get queue and command status summary."""
137
+ queued = sum(1 for c in commands.values() if c["status"] == "queued")
138
+ running = sum(1 for c in commands.values() if c["status"] == "running")
139
+ done = sum(1 for c in commands.values() if c["status"] == "done")
140
+ failed = sum(1 for c in commands.values() if c["status"] == "failed")
141
+ expired = sum(1 for c in commands.values() if c["status"] == "expired")
142
+ return {
143
+ "ok": True,
144
+ "server": "Create3 Robot Bridge",
145
+ "queued": queued,
146
+ "running": running,
147
+ "done": done,
148
+ "failed": failed,
149
+ "expired": expired,
150
+ "total": len(commands),
151
+ "time": now_iso(),
152
+ }
153
+
154
+
155
+ @mcp.tool()
156
+ async def list_recent_commands(limit: int = 20, ctx: Optional[Context] = None) -> dict:
157
+ """List recent robot commands and their status."""
158
+ items = sorted(commands.values(), key=lambda x: x["updated_at"], reverse=True)[: max(1, min(limit, 100))]
159
+ return {"ok": True, "items": items}
160
+
161
+
162
+ @mcp.tool()
163
+ async def get_command(command_id: str, ctx: Optional[Context] = None) -> dict:
164
+ """Get the status of one command by ID."""
165
+ cmd = commands.get(command_id)
166
+ if not cmd:
167
+ return {"ok": False, "error": "Command not found", "command_id": command_id}
168
+ return {"ok": True, "item": cmd}
169
+
170
+
171
+ # ---- Robot command tools (enqueue only) ----
172
+ @mcp.tool()
173
+ async def undock(priority: int = 0, ttl_seconds: int = 120, ctx: Optional[Context] = None) -> dict:
174
+ """Undock the Create 3 from its dock."""
175
+ rec = await enqueue_command("undock", {}, source="mcp", priority=priority, ttl_seconds=ttl_seconds)
176
+ return {"ok": True, "queued": rec}
177
+
178
+
179
+ @mcp.tool()
180
+ async def dock(priority: int = 0, ttl_seconds: int = 120, ctx: Optional[Context] = None) -> dict:
181
+ """Dock the Create 3 to its dock."""
182
+ rec = await enqueue_command("dock", {}, source="mcp", priority=priority, ttl_seconds=ttl_seconds)
183
+ return {"ok": True, "queued": rec}
184
+
185
+
186
+ @mcp.tool()
187
+ async def move_cm(cm: float, priority: int = 0, ttl_seconds: int = 120, ctx: Optional[Context] = None) -> dict:
188
+ """Move forward (or backward if negative) in centimeters."""
189
+ rec = await enqueue_command("move_cm", {"cm": cm}, source="mcp", priority=priority, ttl_seconds=ttl_seconds)
190
+ return {"ok": True, "queued": rec}
191
+
192
+
193
+ @mcp.tool()
194
+ async def turn_left(deg: float = 90, priority: int = 0, ttl_seconds: int = 120, ctx: Optional[Context] = None) -> dict:
195
+ """Turn the robot left by degrees."""
196
+ rec = await enqueue_command("turn_left", {"deg": deg}, source="mcp", priority=priority, ttl_seconds=ttl_seconds)
197
+ return {"ok": True, "queued": rec}
198
+
199
+
200
+ @mcp.tool()
201
+ async def turn_right(deg: float = 90, priority: int = 0, ttl_seconds: int = 120, ctx: Optional[Context] = None) -> dict:
202
+ """Turn the robot right by degrees."""
203
+ rec = await enqueue_command("turn_right", {"deg": deg}, source="mcp", priority=priority, ttl_seconds=ttl_seconds)
204
+ return {"ok": True, "queued": rec}
205
+
206
+
207
+ @mcp.tool()
208
+ async def drive(left: float, right: float, seconds: float = 1.0, priority: int = 0, ttl_seconds: int = 120, ctx: Optional[Context] = None) -> dict:
209
+ """Direct wheel drive command. Use small values for safety."""
210
+ rec = await enqueue_command(
211
+ "drive",
212
+ {"left": left, "right": right, "seconds": seconds},
213
+ source="mcp",
214
+ priority=priority,
215
+ ttl_seconds=ttl_seconds,
216
+ )
217
+ return {"ok": True, "queued": rec}
218
+
219
+
220
+ @mcp.tool()
221
+ async def stop(priority: int = 100, ttl_seconds: int = 30, ctx: Optional[Context] = None) -> dict:
222
+ """Emergency stop. High priority."""
223
+ rec = await enqueue_command("stop", {}, source="mcp", priority=priority, ttl_seconds=ttl_seconds)
224
+ return {"ok": True, "queued": rec}
225
+
226
+
227
+ @mcp.tool()
228
+ async def say(text: str, priority: int = 0, ttl_seconds: int = 60, ctx: Optional[Context] = None) -> dict:
229
+ """Queue a local console 'say' message (placeholder for TTS if you add it client-side)."""
230
+ rec = await enqueue_command("say", {"text": text}, source="mcp", priority=priority, ttl_seconds=ttl_seconds)
231
+ return {"ok": True, "queued": rec}
232
+
233
+
234
+ # ============================================================
235
+ # FASTAPI APP (REST endpoints for robot client + MCP mount)
236
+ # ============================================================
237
+ app = FastAPI(title="Create3 MCP + Queue Bridge", version="2.0.0")
238
+
239
+
240
+ # --- Security middleware for MCP HTTP endpoint (Origin + Bearer) ---
241
+ @app.middleware("http")
242
+ async def mcp_security_middleware(request: Request, call_next):
243
+ # MCP is mounted under /mcp (via Starlette mount below)
244
+ if request.url.path.startswith("/mcp"):
245
+ # Origin validation for browser-based calls (recommended by MCP transport spec)
246
+ origin = request.headers.get("origin")
247
+ if origin:
248
+ if origin not in ALLOWED_ORIGINS:
249
+ return JSONResponse(status_code=403, content={"detail": "Origin not allowed"})
250
+
251
+ # Bearer token auth
252
+ try:
253
+ check_bearer(request.headers.get("authorization"))
254
+ except HTTPException as e:
255
+ return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
256
+
257
+ # Also protect REST command endpoints
258
+ if request.url.path.startswith("/commands"):
259
+ try:
260
+ check_bearer(request.headers.get("authorization"))
261
+ except HTTPException as e:
262
+ return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
263
+
264
+ return await call_next(request)
265
+
266
+
267
+ @app.get("/")
268
+ async def root():
269
+ return {
270
+ "ok": True,
271
+ "service": "Create3 MCP + Queue Bridge",
272
+ "mcp_endpoint": "/mcp",
273
+ "time": now_iso(),
274
+ }
275
+
276
+
277
+ @app.get("/health")
278
+ async def health():
279
+ queued = sum(1 for c in commands.values() if c["status"] == "queued")
280
+ running = sum(1 for c in commands.values() if c["status"] == "running")
281
+ return {"ok": True, "queued": queued, "running": running, "total": len(commands), "time": now_iso()}
282
+
283
+
284
+ # ---- REST enqueue endpoint (optional, for testing / non-MCP senders) ----
285
+ @app.post("/commands")
286
+ async def create_command(cmd: CommandCreateREST, authorization: Optional[str] = Header(default=None)):
287
+ check_bearer(authorization)
288
+ rec = await enqueue_command(
289
+ cmd.command,
290
+ args=cmd.args,
291
+ source=cmd.source,
292
+ priority=cmd.priority,
293
+ ttl_seconds=cmd.ttl_seconds,
294
+ )
295
+ return {"ok": True, "queued": rec}
296
+
297
 
298
  @app.get("/commands")
299
  async def list_commands(limit: int = 20, authorization: Optional[str] = Header(default=None)):
300
+ check_bearer(authorization)
301
+ items = sorted(commands.values(), key=lambda x: x["updated_at"], reverse=True)[: max(1, min(limit, 100))]
 
302
  return {"ok": True, "items": items}
303
 
304
+
305
  @app.get("/commands/{command_id}")
306
+ async def get_command_rest(command_id: str, authorization: Optional[str] = Header(default=None)):
307
+ check_bearer(authorization)
308
  if command_id not in commands:
309
  raise HTTPException(status_code=404, detail="Command not found")
310
  return {"ok": True, "item": commands[command_id]}
311
 
312
+
313
  @app.post("/commands/next")
314
+ async def claim_next_command(robot_id: str, authorization: Optional[str] = Header(default=None)):
 
 
 
315
  """
316
+ Robot client polls this endpoint to claim one command.
 
317
  """
318
+ check_bearer(authorization)
319
 
320
  async with queue_lock:
 
321
  while queue:
322
  cmd_id = queue.pop(0)
323
  cmd = commands.get(cmd_id)
 
339
 
340
  return {"ok": True, "item": None}
341
 
342
+
343
  @app.post("/commands/{command_id}/status")
344
  async def update_command_status(
345
  command_id: str,
346
  update: CommandStatusUpdate,
347
  authorization: Optional[str] = Header(default=None),
348
  ):
349
+ check_bearer(authorization)
350
 
351
  cmd = commands.get(command_id)
352
  if not cmd:
353
  raise HTTPException(status_code=404, detail="Command not found")
354
 
 
355
  cmd["status"] = update.status
356
  cmd["result"] = update.result
357
  cmd["error"] = update.error
358
  if update.robot_id:
359
  cmd["claimed_by"] = update.robot_id
360
  cmd["updated_at"] = now_iso()
 
361
  return {"ok": True}
362
 
363
+
364
  @app.post("/commands/{command_id}/cancel")
365
  async def cancel_command(command_id: str, authorization: Optional[str] = Header(default=None)):
366
+ check_bearer(authorization)
367
 
368
  cmd = commands.get(command_id)
369
  if not cmd:
370
  raise HTTPException(status_code=404, detail="Command not found")
371
 
372
  if cmd["status"] in ("done", "failed", "expired"):
373
+ return {"ok": True, "status": cmd["status"]}
374
 
375
  cmd["status"] = "failed"
376
  cmd["error"] = "Cancelled by operator"
377
  cmd["updated_at"] = now_iso()
378
 
 
379
  async with queue_lock:
380
  if command_id in queue:
381
  queue.remove(command_id)
382
 
383
  return {"ok": True, "status": "failed"}
384
 
385
+
386
+ # ============================================================
387
+ # MOUNT MCP ASGI APP AT /mcp (streamable HTTP transport)
388
+ # ============================================================
389
+ # FastMCP's streamable_http_app() returns an ASGI app implementing MCP over HTTP.
390
+ # This is what Claude/ChatGPT-style MCP clients connect to.
391
+ app.router.routes.append(Mount("/mcp", app=mcp.streamable_http_app()))