File size: 3,667 Bytes
0bf71ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b02956e
 
0bf71ce
b02956e
 
 
 
0bf71ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b02956e
 
 
 
 
 
0bf71ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""
Python client for the Invoice Processing Pipeline environment.

Usage:
    from client import InvoiceEnvClient
    from models import InvoiceAction

    client = InvoiceEnvClient(base_url="http://localhost:7860")
    result = client.reset(task_id="easy")
    print(result["observation"]["raw_text"])

    result = client.step({"vendor": "Acme Corp", "date": "2024-06-15", ...})
    print(result["reward"])
"""

from __future__ import annotations

from typing import Any, Dict, Optional

import httpx


class InvoiceEnvClient:
    """Blocking HTTP client for the Invoice Processing Pipeline server.

    Each public endpoint method sends one request to ``base_url``, lets
    ``raise_for_status()`` surface HTTP error statuses, and otherwise returns
    the decoded JSON body. The instance is a context manager; leaving the
    ``with`` block closes the underlying ``httpx.Client``.
    """

    def __init__(self, base_url: str = "http://localhost:7860", timeout: float = 30.0):
        # Drop trailing slashes so endpoint paths can be appended as "/name".
        self.base_url = base_url.rstrip("/")
        self._client = httpx.Client(timeout=timeout)

    def _call(self, verb: str, endpoint: str, **request_kwargs: Any) -> Dict[str, Any]:
        """Issue ``verb`` ("get"/"post") against ``endpoint`` and return the JSON reply."""
        send = getattr(self._client, verb)
        response = send(self.base_url + endpoint, **request_kwargs)
        response.raise_for_status()
        return response.json()

    def reset(self, task_id: str = "easy") -> Dict[str, Any]:
        """Begin a new episode for ``task_id`` and return the server's reply."""
        return self._call("post", "/reset", json={"task_id": task_id})

    def step(self, extracted_data: Dict[str, Any], explanation: str = "",
             episode_id: Optional[str] = None) -> Dict[str, Any]:
        """Send extracted/cleaned data; the reply carries reward and feedback.

        ``episode_id`` is included in the request body only when provided.
        """
        payload: Dict[str, Any] = dict(extracted_data=extracted_data, explanation=explanation)
        if episode_id is not None:
            payload["episode_id"] = episode_id
        return self._call("post", "/step", json=payload)

    def state(self) -> Dict[str, Any]:
        """Fetch the state of the current episode."""
        return self._call("get", "/state")

    def tasks(self) -> Dict[str, Any]:
        """Fetch the list of available tasks and their schemas."""
        return self._call("get", "/tasks")

    def health(self) -> Dict[str, Any]:
        """Query the server's health endpoint."""
        return self._call("get", "/health")

    def close(self):
        """Release the underlying HTTP client."""
        self._client.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Always close; returning None lets any in-flight exception propagate.
        self.close()


class AsyncInvoiceEnvClient:
    """Async HTTP client for the Invoice Processing Pipeline.

    Exposes the same endpoints as ``InvoiceEnvClient`` (reset, step, state,
    tasks, health) as coroutines. Each call raises via ``raise_for_status()``
    on an HTTP error status and otherwise returns the decoded JSON body.
    Use as ``async with`` so the underlying ``httpx.AsyncClient`` is closed.
    """

    def __init__(self, base_url: str = "http://localhost:7860", timeout: float = 30.0):
        # Drop trailing slashes so endpoint paths can be appended as "/name".
        self.base_url = base_url.rstrip("/")
        self._client = httpx.AsyncClient(timeout=timeout)

    async def reset(self, task_id: str = "easy") -> Dict[str, Any]:
        """Reset the environment for a new episode."""
        resp = await self._client.post(f"{self.base_url}/reset", json={"task_id": task_id})
        resp.raise_for_status()
        return resp.json()

    async def step(self, extracted_data: Dict[str, Any], explanation: str = "",
                   episode_id: Optional[str] = None) -> Dict[str, Any]:
        """Submit extracted/cleaned data and get reward + feedback.

        ``episode_id`` is added to the request body only when it is not None.
        """
        body: Dict[str, Any] = {"extracted_data": extracted_data, "explanation": explanation}
        if episode_id is not None:
            body["episode_id"] = episode_id
        resp = await self._client.post(f"{self.base_url}/step", json=body)
        resp.raise_for_status()
        return resp.json()

    async def state(self) -> Dict[str, Any]:
        """Get current episode state."""
        resp = await self._client.get(f"{self.base_url}/state")
        resp.raise_for_status()
        return resp.json()

    async def tasks(self) -> Dict[str, Any]:
        """List available tasks and schemas."""
        # Added for parity with the synchronous client's /tasks endpoint.
        resp = await self._client.get(f"{self.base_url}/tasks")
        resp.raise_for_status()
        return resp.json()

    async def health(self) -> Dict[str, Any]:
        """Check server health."""
        # Added for parity with the synchronous client's /health endpoint.
        resp = await self._client.get(f"{self.base_url}/health")
        resp.raise_for_status()
        return resp.json()

    async def close(self) -> None:
        """Close the underlying async HTTP client."""
        await self._client.aclose()

    async def __aenter__(self) -> "AsyncInvoiceEnvClient":
        return self

    async def __aexit__(self, *args: object) -> None:
        # Always close; returning None lets any in-flight exception propagate.
        await self.close()