# SpandanM110's picture
# Fix OpenEnv validation: add uv.lock, scripts entry, openenv-core dep, main()
# 59d6790
"""FastAPI server for OCR Table RL Environment (with Gradio UI mounted at /ui)."""
from __future__ import annotations
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
import gradio as gr
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from typing import Optional
from env.environment import OCREnvironment
from env.models import OCRAction, OCRObservation, OCRState
from env.tasks import TASK_METADATA
# ---------------------------------------------------------------------------
# App setup
# ---------------------------------------------------------------------------
# ASGI application object; the interactive API docs are served at /docs.
app = FastAPI(
    version="1.0.0",
    docs_url="/docs",
    title="OCR Table RL Environment",
    description=(
        "OpenEnv-compatible RL environment for structured table extraction. "
        "Agents extract complex tables into Markdown + JSON KPIs from synthetic document images."
    ),
)
# CORS: the API is open to every origin, method and header.
# Credentials are deliberately disabled: a wildcard origin must not be combined
# with allow_credentials=True (the CORS rules forbid it, and honouring it would
# let any website issue credentialed requests). The handlers visible in this
# module do not read cookies or authorization headers, so nothing is lost.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global environment instance (one session at a time — sufficient for eval).
# The /reset, /step and /state handlers in this module all operate on this one
# module-level object, so every client of the process shares the same episode.
env = OCREnvironment()
# ---------------------------------------------------------------------------
# Request/Response schemas
# ---------------------------------------------------------------------------
class ResetRequest(BaseModel):
    # Request schema naming the task to reset to.
    # (Documented with comments rather than a docstring so the generated
    # JSON schema of the model is left untouched.)
    # NOTE(review): the /reset handler visible in this module parses the raw
    # request body itself and does not bind this model — confirm whether any
    # other caller still relies on it.

    # Task identifier; "clean_table" is used when the field is omitted.
    task: str = "clean_table"
class StepResponse(BaseModel):
    # Response schema of POST /step: the four values unpacked from env.step().
    # (Documented with comments rather than a docstring so the generated
    # JSON schema of the model is left untouched.)

    # Observation returned by the environment for this step.
    observation: OCRObservation
    # Scalar reward returned by the environment for this step.
    reward: float
    # Third element of the env.step() result — presumably the
    # episode-finished flag; confirm against OCREnvironment.step.
    done: bool
    # Free-form extra data returned alongside the step result.
    info: dict
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/")
async def root():
    # Visitors of the bare root URL are redirected to the UI path.
    ui_path = "/ui"
    return RedirectResponse(url=ui_path)
@app.get("/health")
async def health():
    # Static liveness payload; the environment itself is not inspected.
    payload = dict(status="healthy", env="ocr-table-rl", version="1.0.0")
    return payload
@app.get("/tasks")
async def list_tasks():
    # The imported task metadata is exposed unchanged under the "tasks" key.
    listing = {"tasks": TASK_METADATA}
    return listing
@app.post("/reset")
async def reset(request: Request):
    """Reset the environment for a given task and return the initial observation."""
    # The request body is optional by design: a missing, empty, malformed or
    # non-object payload, as well as an unrecognised task name, all fall back
    # to the default task instead of producing an error response.
    valid_tasks = {"clean_table", "noisy_financial", "degraded_report"}
    task = "clean_table"

    # Only the JSON parse is guarded, and only against parse failures
    # (json.JSONDecodeError and UnicodeDecodeError are both ValueError
    # subclasses); unrelated errors are no longer silently swallowed.
    try:
        body = await request.json()
    except ValueError:
        body = None

    requested = body.get("task") if isinstance(body, dict) else None
    # The str check keeps unhashable JSON values (lists, objects) from raising
    # TypeError in the set-membership test; they simply select the default.
    if isinstance(requested, str) and requested in valid_tasks:
        task = requested

    obs = env.reset(task=task)
    return obs.model_dump()
@app.post("/step", response_model=StepResponse)
async def step(action: OCRAction):
    """Execute one environment step with the given action."""
    # env.step() is unpacked as (observation, reward, done, info), in that order.
    outcome = env.step(action)
    observation, reward, done, info = outcome
    return StepResponse(observation=observation, reward=reward, done=done, info=info)
@app.get("/state", response_model=OCRState)
async def get_state():
    """Return the current environment state."""
    # Delegates to the shared environment; FastAPI serialises via OCRState.
    snapshot = env.state()
    return snapshot
# ---------------------------------------------------------------------------
# Mount Gradio UI at /ui
# ---------------------------------------------------------------------------
try:
    import importlib.util

    # The Gradio script is loaded by file path: "app.py" in the parent of this
    # file's directory, registered under the module name "gradio_app".
    spec = importlib.util.spec_from_file_location(
        name="gradio_app",
        location=os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "app.py",
        ),
    )
    gradio_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(gradio_module)

    # The loaded script must expose its Gradio app as ``demo``.
    demo = gradio_module.demo
    app = gr.mount_gradio_app(app, demo, path="/ui")
except Exception as exc:
    # Best effort: any failure is reported and the REST API keeps running
    # without the UI.
    print(f"Warning: Could not mount Gradio UI: {exc}")
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def main():
    """Serve ``app`` with uvicorn on all network interfaces.

    The listening port is taken from the ``PORT`` environment variable and
    falls back to 7860 when it is unset; auto-reload is disabled.
    """
    import uvicorn

    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port, reload=False)


if __name__ == "__main__":
    main()