File size: 4,056 Bytes
5d245c2
abe7bd6
 
 
 
 
 
5d245c2
 
8702c9b
abe7bd6
715d086
abe7bd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
715d086
 
 
 
 
abe7bd6
 
 
 
 
 
 
 
 
 
8702c9b
 
abe7bd6
 
8702c9b
 
 
 
 
 
 
abe7bd6
8702c9b
abe7bd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d245c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abe7bd6
 
 
 
59d6790
 
abe7bd6
5d245c2
59d6790
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""FastAPI server for OCR Table RL Environment (with Gradio UI mounted at /ui)."""
from __future__ import annotations
import sys
import os

# Put the parent of this file's directory at the front of sys.path before the
# project imports below run (presumably so the `env` package resolves when this
# file is executed directly — TODO confirm against the repo layout).
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

import gradio as gr

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from typing import Optional

from env.environment import OCREnvironment
from env.models import OCRAction, OCRObservation, OCRState
from env.tasks import TASK_METADATA

# ---------------------------------------------------------------------------
# App setup
# ---------------------------------------------------------------------------

# FastAPI application object. The title, description and version are passed to
# FastAPI as API metadata, and the interactive docs are served at /docs.
app = FastAPI(
    title="OCR Table RL Environment",
    description=(
        "OpenEnv-compatible RL environment for structured table extraction. "
        "Agents extract complex tables into Markdown + JSON KPIs from synthetic document images."
    ),
    version="1.0.0",
    docs_url="/docs",
)

# Accept cross-origin requests from any origin, with any method and any header.
# NOTE(review): wildcard origins together with allow_credentials=True is a very
# permissive combination — confirm it is intended for this deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global environment instance (one session at a time — sufficient for eval).
# The /reset, /step and /state handlers below all operate on this single
# object, so every client shares whatever state it holds.
env = OCREnvironment()


# ---------------------------------------------------------------------------
# Request/Response schemas
# ---------------------------------------------------------------------------

class ResetRequest(BaseModel):
    """Body schema for a reset request: the name of the task to load.

    NOTE(review): no endpoint visible in this file references this model; the
    /reset handler reads the raw request body instead.
    """

    # Task identifier; defaults to "clean_table" when omitted.
    task: str = "clean_table"


class StepResponse(BaseModel):
    """Response schema of the /step endpoint: the result of one env step."""

    # Observation returned by the environment for this step.
    observation: OCRObservation
    # Scalar reward returned by the environment for this step.
    reward: float
    # Completion flag returned by the environment for this step.
    done: bool
    # Additional information dictionary returned by the environment.
    info: dict


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------

@app.get("/")
async def root():
    """Send requests for the bare root URL on to the UI path."""
    ui_redirect = RedirectResponse(url="/ui")
    return ui_redirect


@app.get("/health")
async def health():
    """Health probe: return a fixed status payload."""
    payload = dict(status="healthy", env="ocr-table-rl", version="1.0.0")
    return payload


@app.get("/tasks")
async def list_tasks():
    """Expose the imported task metadata under the "tasks" key."""
    catalogue = TASK_METADATA
    return {"tasks": catalogue}


@app.post("/reset")
async def reset(request: Request):
    """Reset the environment for a given task and return the initial observation.

    The request body is optional. When it is a JSON object whose ``"task"``
    value is one of the known task names, that task is loaded. In every other
    case (no body, malformed JSON, non-object JSON, unknown or non-string
    task) the handler deliberately falls back to ``"clean_table"`` instead of
    rejecting the request.

    Args:
        request: The raw incoming request; its body is parsed here rather than
            by a pydantic model so that an empty body is accepted.

    Returns:
        The initial observation, serialised with ``model_dump()``.
    """
    valid_tasks = {"clean_table", "noisy_financial", "degraded_report"}
    task = "clean_table"

    try:
        body = await request.json()
    except ValueError:
        # Empty or malformed body: json.JSONDecodeError and UnicodeDecodeError
        # are both ValueError subclasses. Any other exception is a genuine
        # failure and is allowed to propagate instead of being swallowed.
        body = None

    if isinstance(body, dict):
        requested = body.get("task")
        # The str check keeps unhashable values (lists, dicts) away from the
        # set-membership test, which would otherwise raise TypeError.
        if isinstance(requested, str) and requested in valid_tasks:
            task = requested

    obs = env.reset(task=task)
    return obs.model_dump()


@app.post("/step", response_model=StepResponse)
async def step(action: OCRAction):
    """Apply one action to the environment and report the outcome."""
    observation, step_reward, finished, extra = env.step(action)
    transition = StepResponse(
        observation=observation,
        reward=step_reward,
        done=finished,
        info=extra,
    )
    return transition


@app.get("/state", response_model=OCRState)
async def get_state():
    """Report the state currently held by the environment."""
    current_state = env.state()
    return current_state


# ---------------------------------------------------------------------------
# Mount Gradio UI at /ui
# ---------------------------------------------------------------------------

# Best-effort mount: any failure below is reported as a warning on stdout and
# the API keeps running without the UI.
try:
    import importlib.util
    # Load "app.py" from the parent of this file's directory as a module named
    # "gradio_app".
    spec = importlib.util.spec_from_file_location(
        "gradio_app",
        os.path.join(os.path.dirname(os.path.dirname(__file__)), "app.py"),
    )
    gradio_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(gradio_module)
    # The loaded module is expected to expose its Gradio app as `demo`.
    demo = gradio_module.demo
    # Rebind `app` to whatever mount_gradio_app returns after attaching the UI
    # at /ui.
    app = gr.mount_gradio_app(app, demo, path="/ui")
except Exception as e:
    # NOTE(review): when this branch runs, GET / still redirects to /ui, which
    # then has nothing mounted behind it.
    print(f"Warning: Could not mount Gradio UI: {e}")


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

def main():
    """Launch the API with uvicorn, listening on 0.0.0.0.

    The port is read from the ``PORT`` environment variable and falls back to
    7860 when that variable is unset.
    """
    import uvicorn

    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port, reload=False)


if __name__ == "__main__":
    main()