TrashCollector / app.py
Mihir Mithani
Sync Hub-enabled code to Space (no weights)
a8d4cdf
"""
FastAPI server for the Garbage Collecting Robot OpenEnv environment.
Exposes reset / reset_custom / step / state / tasks / grade / policy / configure endpoints, plus the /ui dashboard route.
Fix applied:
- /policy BFS fallback now uses env.get_observation().dict() instead of
a hand-built incomplete dict (which was missing robot_mode, home_position,
unload_station, current_storage_load, storage_capacity, distance_from_home).
- Static files and /ui route added so the HTML dashboard is served from the
same origin — required for HuggingFace Spaces deployment.
"""
import os
import sys
sys.path.insert(0, os.path.dirname(__file__))
from typing import List
from pydantic import BaseModel
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from environment import GarbageRobotEnv
from models import (
Action, StepOutput, ResetInput, ResetOutput, CustomResetInput, State, Task,
)
# Application object; title/description/version are the metadata passed to FastAPI.
app = FastAPI(
    version="1.0.0",
    title="Garbage Collecting Robot — OpenEnv",
    description=(
        "An OpenEnv-compliant robotics environment for garbage collection. "
        "AI agents must navigate a grid room to pick up garbage "
        "while managing battery constraints."
    ),
)

# CORS is left fully open: any origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Single environment instance shared by every endpoint in this module.
env = GarbageRobotEnv()
# One row per predefined task: (id, name, description, difficulty).
_TASK_SPECS = (
    (
        "task_easy",
        "Small Room Clean",
        "Navigate a small 5x5 grid to collect 1 piece of garbage.",
        "easy",
    ),
    (
        "task_medium",
        "Medium Room with Obstacles",
        "Navigate a 7x7 grid to collect 3 pieces of garbage with limited battery.",
        "medium",
    ),
    (
        "task_hard",
        "Large Maze Cleanup",
        "Navigate a 10x10 maze avoiding obstacles to collect 5 pieces of garbage with strict battery usage.",
        "hard",
    ),
)

# Task catalogue returned by GET /tasks; every task uses the same reward range.
TASKS = [
    Task(
        id=spec_id,
        name=spec_name,
        description=spec_description,
        difficulty=spec_difficulty,
        reward_range=[0.0, 1.0],
    )
    for spec_id, spec_name, spec_description, spec_difficulty in _TASK_SPECS
]

# Set of accepted task ids, used by the endpoints to validate task_id.
VALID_IDS = {task.id for task in TASKS}
@app.get("/", tags=["health"])
def health():
    """Health check: return a fixed status payload naming this environment."""
    payload = {"status": "ok", "env": "garbage-collecting-robot"}
    return payload
@app.post("/reset", response_model=ResetOutput, tags=["openenv"])
def reset(body: ResetInput = ResetInput()):
    """
    Reset the shared environment to one of the predefined tasks.

    The request body is optional; a default ``ResetInput()`` is used when it
    is omitted.

    Raises:
        HTTPException (400): ``body.task_id`` is not one of ``VALID_IDS``.

    Returns:
        ``{"observation": ...}`` built from ``env.get_observation()`` taken
        right after the reset.
    """
    if body.task_id not in VALID_IDS:
        raise HTTPException(400, f"task_id must be one of {sorted(VALID_IDS)}")
    # The value returned by env.reset() is not used: the response is built
    # from env.get_observation() below, so nothing is bound here (the old
    # local was unused and shadowed the `state` endpoint function).
    env.reset(task_id=body.task_id)
    return {"observation": env.get_observation().dict()}
@app.post("/reset_custom", response_model=ResetOutput, tags=["openenv"])
def reset_custom(body: CustomResetInput):
    """
    Reset the environment with a caller-defined layout.

    The request may set garbage positions, obstacle positions, robot start,
    grid size, battery, storage capacity, home position and unload station.
    Every field is forwarded unchanged to ``env.reset_custom``; per this
    endpoint's contract, an omitted field falls back to the base scenario's
    value (that fallback is not implemented in this function).

    Returns ``{"observation": ...}`` from ``env.get_observation()``.
    """
    overrides = {
        "task_id": body.task_id,
        "grid_size": body.grid_size,
        "robot_start": body.robot_start,
        "garbage_positions": body.garbage_positions,
        "obstacle_positions": body.obstacle_positions,
        "max_battery": body.max_battery,
        "storage_capacity": body.storage_capacity,
        "home_position": body.home_position,
        "unload_station": body.unload_station,
    }
    env.reset_custom(**overrides)
    observation = env.get_observation()
    return {"observation": observation.dict()}
@app.post("/step", response_model=StepOutput, tags=["openenv"])
def step(body: Action):
    """Forward ``body.command`` to ``env.step`` and return its result unchanged."""
    return env.step(command=body.command)
@app.get("/state", response_model=State, tags=["openenv"])
def state():
    """Return whatever ``env.state()`` reports for the shared environment."""
    current_state = env.state()
    return current_state
@app.get("/tasks", response_model=list[Task], tags=["openenv"])
def tasks():
    """Return the module-level catalogue of predefined tasks."""
    catalogue = TASKS
    return catalogue
@app.get("/grade/{task_id}", tags=["grading"])
def grade(task_id: str):
    """
    Return the score ``env.grade`` assigns for ``task_id``.

    Responds with HTTP 400 when ``task_id`` is not one of ``VALID_IDS``.
    The response always reports a reward range of ``[0.0, 1.0]``.
    """
    if task_id in VALID_IDS:
        return {
            "task_id": task_id,
            "score": env.grade(task_id),
            "reward_range": [0.0, 1.0],
        }
    raise HTTPException(400, f"task_id must be one of {sorted(VALID_IDS)}")
# ── Policy endpoint (fine-tuned LLM) ──────────────────────────────────────
# Where the policy weights are loaded from; the LOCAL_MODEL_PATH environment
# variable overrides the default identifier below.
LOCAL_MODEL_PATH = os.environ.get("LOCAL_MODEL_PATH", "TechAvenger/GarbageBot-Weights")

# Filled in lazily by _load_policy(); both stay None when loading fails.
_policy_model = _policy_tokenizer = None
# True once a load has been attempted, whether or not it succeeded.
_policy_loaded = False
def _load_policy():
    """
    Load the policy model and tokenizer from LOCAL_MODEL_PATH, at most once.

    The first call attempts the load; every later call returns immediately,
    including after a failed attempt (failures are not retried).

    Loading is best effort: any exception is printed and swallowed, leaving
    both ``_policy_model`` and ``_policy_tokenizer`` as ``None``. The two
    globals are only assigned after the whole load succeeded, so a failure
    part-way through (e.g. tokenizer loads but the model does not) can never
    leave them half-initialised.
    """
    global _policy_model, _policy_tokenizer, _policy_loaded
    if _policy_loaded:
        return
    try:
        # Imported here, inside the try, so a missing transformers/torch
        # install is handled like any other load failure.
        from transformers import AutoModelForCausalLM, AutoTokenizer
        import torch

        tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_PATH)
        model = AutoModelForCausalLM.from_pretrained(
            LOCAL_MODEL_PATH, torch_dtype=torch.float16, device_map="auto"
        )
        model.eval()
    except Exception as e:
        # Deliberate best effort: report the problem and leave both globals None.
        _policy_model = None
        _policy_tokenizer = None
        print(f"[Policy] Model unavailable: {e}")
    else:
        # Publish both objects together, only once everything above worked.
        _policy_tokenizer = tokenizer
        _policy_model = model
        print(f"[Policy] Fine-tuned model loaded from {LOCAL_MODEL_PATH}")
    _policy_loaded = True
class PolicyInput(BaseModel):
    """Request body for POST /policy."""

    # The obs.message string from the environment.
    message: str
def _first_action_in(text, candidates):
    """Return the candidate occurring earliest in *text*, or None if none occurs."""
    hits = [(text.find(candidate), candidate) for candidate in candidates if candidate in text]
    if not hits:
        return None
    # Earliest position wins; on a tie min() keeps the first one in candidate order.
    return min(hits, key=lambda hit: hit[0])[1]


@app.post("/policy", tags=["openenv"])
def policy(body: PolicyInput):
    """
    Ask the fine-tuned LLM for the next action, falling back to the heuristic.

    LLM path: returns
    {"action": "UP|DOWN|LEFT|RIGHT|COLLECT", "source": "llm", "raw": <decoded text>}.
    Fallback path (model unavailable, inference error, or no valid action in
    the model output): returns {"action": <heuristic_action result>, "source": "bfs"}.
    """
    _load_policy()
    VALID = ["UP", "DOWN", "LEFT", "RIGHT", "COLLECT"]
    if _policy_model is not None and _policy_tokenizer is not None:
        try:
            import torch

            instruction = (
                "You are an AI brain controlling a garbage collecting robot.\n"
                "Reply with EXACTLY ONE of: UP DOWN LEFT RIGHT COLLECT"
            )
            prompt = (
                f"### Instruction:\n{instruction}\n\n"
                f"### Input:\nENVIRONMENT STATUS:\n{body.message}\n\n"
                f"### Response:\n"
            )
            inputs = _policy_tokenizer(
                prompt, return_tensors="pt", truncation=True, max_length=512
            ).to(_policy_model.device)
            with torch.no_grad():
                outputs = _policy_model.generate(
                    **inputs, max_new_tokens=6, do_sample=False,
                    pad_token_id=_policy_tokenizer.eos_token_id
                )
            # Keep only the tokens generated after the prompt.
            new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
            raw = _policy_tokenizer.decode(new_tokens, skip_special_tokens=True).strip().upper()
            # Pick the action that appears FIRST in the model output. Scanning
            # VALID in list order would turn e.g. "RIGHT ... UP" into "UP".
            action = _first_action_in(raw, VALID)
            if action is not None:
                return {"action": action, "source": "llm", "raw": raw}
        except Exception as e:
            # Best effort: report the failure and fall through to the heuristic.
            print(f"[Policy] Inference error: {e}")
    # Fallback: build the observation from env.get_observation().dict() so
    # heuristic_action() receives every observation field (robot_mode,
    # home_position, unload_station, etc.) rather than a hand-built dict.
    from inference import heuristic_action

    obs_dict = env.get_observation().dict()
    obs_dict["message"] = body.message  # use the caller's message for context
    return {"action": heuristic_action(obs_dict), "source": "bfs"}
# ── Dynamic garbage placement ──────────────────────────────────────────────
class ConfigureInput(BaseModel):
    """Request body for POST /configure."""

    # Task the environment is reset to before the garbage is overridden.
    task_id: str = "task_easy"
    # Garbage cells to place, one pair per cell: [[x, y], ...]
    garbage_positions: List[List[int]]
@app.post("/configure", tags=["openenv"])
def configure(body: ConfigureInput):
    """
    Reset the environment for task_id, then override garbage positions
    with whatever the caller supplies.

    Each position must be an [x, y] pair inside the grid and must not be
    listed in env.obstacle_positions.

    Raises:
        HTTPException (400): unknown task_id, a position that is not a pair,
            is out of bounds, or is an obstacle.

    Returns:
        {"observation": ...} from env.get_observation() after the override.
    """
    if body.task_id not in VALID_IDS:
        raise HTTPException(400, f"task_id must be one of {sorted(VALID_IDS)}")
    # The reset happens before validation because the checks below read
    # env.grid_size and env.obstacle_positions from the freshly reset env.
    # NOTE(review): as a consequence, a request rejected below still leaves
    # the environment reset, with env.garbage_positions as reset() set it.
    env.reset(task_id=body.task_id)
    # Grid dimensions do not change inside the loop, so unpack them once.
    gw, gh = env.grid_size
    validated = []
    for pos in body.garbage_positions:
        if len(pos) != 2:
            raise HTTPException(400, f"Each position must be [x, y], got {pos}")
        x, y = pos
        if not (0 <= x < gw and 0 <= y < gh):
            raise HTTPException(400, f"Position {pos} out of bounds for grid {env.grid_size}")
        if [x, y] in env.obstacle_positions:
            raise HTTPException(400, f"Position {pos} is an obstacle")
        validated.append([x, y])
    # Only assigned once every position passed validation.
    env.garbage_positions = validated
    return {"observation": env.get_observation().dict()}
# ── Serve HTML dashboard ───────────────────────────────────────────────────
# This makes the frontend accessible at /ui on the same origin as the API,
# which is required for HuggingFace Spaces (no localhost cross-origin issues).
@app.get("/ui", include_in_schema=False)
def ui():
    """
    Serve the dashboard HTML.

    Raises:
        HTTPException (404): frontend/index.html does not exist. Without
            this check the missing file only surfaces as a server error when
            the response is sent.
    """
    dashboard = "frontend/index.html"
    if not os.path.isfile(dashboard):
        raise HTTPException(404, "Dashboard not found: frontend/index.html is missing")
    return FileResponse(dashboard)
# Mount static assets (style.css, script.js) at /static.
# The mount is skipped when neither asset file is present.
if any(
    os.path.exists(asset)
    for asset in ("frontend/style.css", "frontend/script.js")
):
    app.mount("/static", StaticFiles(directory="frontend"), name="static")

if __name__ == "__main__":
    # uvicorn is imported only when this file is run directly as a script.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")