ricalanis commited on
Commit
84ca609
·
verified ·
1 Parent(s): 0de736f

Upload folder using huggingface_hub

Browse files
Dockerfile ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Multi-stage build for the DataSage cleaning environment.
# Stage 1 ("builder") resolves dependencies with uv; stage 2 copies only
# the finished virtualenv and sources into a fresh base image.
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
FROM ${BASE_IMAGE} AS builder

WORKDIR /app

# git is required so uv can resolve any VCS dependencies; the apt list
# cleanup keeps the layer small.
RUN apt-get update && \
    apt-get install -y --no-install-recommends git && \
    rm -rf /var/lib/apt/lists/*

# Build-time knobs (currently informational; not referenced below).
ARG BUILD_MODE=in-repo
ARG ENV_NAME=datasage_cleaning

COPY . /app/env

WORKDIR /app/env

# Install uv only when the base image does not already ship it.
# NOTE(review): assumes curl is available in the base image — confirm.
RUN if ! command -v uv >/dev/null 2>&1; then \
        curl -LsSf https://astral.sh/uv/install.sh | sh && \
        mv /root/.local/bin/uv /usr/local/bin/uv && \
        mv /root/.local/bin/uvx /usr/local/bin/uvx; \
    fi

# Dependency layer: prefer the locked resolution when uv.lock exists.
# The cache mount keeps uv's download cache across builds.
RUN --mount=type=cache,target=/root/.cache/uv \
    if [ -f uv.lock ]; then \
        uv sync --frozen --no-install-project --no-editable; \
    else \
        uv sync --no-install-project --no-editable; \
    fi

# Project layer: installs the project itself on top of the cached deps.
RUN --mount=type=cache,target=/root/.cache/uv \
    if [ -f uv.lock ]; then \
        uv sync --frozen --no-editable; \
    else \
        uv sync --no-editable; \
    fi

# Runtime stage: only the virtualenv and sources come over from builder.
FROM ${BASE_IMAGE}

WORKDIR /app

COPY --from=builder /app/env/.venv /app/.venv
COPY --from=builder /app/env /app/env

ENV PATH="/app/.venv/bin:$PATH"
ENV PYTHONPATH="/app/env:$PYTHONPATH"

# NOTE(review): assumes curl exists in the runtime image and that the app
# serves GET /health on port 8000 — confirm against server.app.
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

ENV ENABLE_WEB_INTERFACE=true
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
README.md CHANGED
@@ -1,10 +1,34 @@
1
  ---
2
- title: Datasage Cleaning
3
- emoji: 🦀
4
- colorFrom: red
5
- colorTo: red
6
  sdk: docker
7
  pinned: false
 
 
 
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: DataSage Cleaning Environment
3
+ emoji: 🧹
4
+ colorFrom: green
5
+ colorTo: blue
6
  sdk: docker
7
  pinned: false
8
+ app_port: 8000
9
+ base_path: /web
10
+ tags:
11
+ - openenv
12
  ---
13
 
14
+ # DataSage Cleaning Environment
15
+
16
+ An RL environment for training agents to clean enterprise data across four domains (HR, Sales, Project Management, IT Operations).
17
+
18
+ The agent receives a corrupted 50-row data batch and must apply cleaning operations (fill nulls, fix types, remove duplicates, standardize values, trim whitespace, correct typos) to maximise a composite data quality score. Episodes end when DQ > 0.95 or after 15 steps.
19
+
20
+ ## Quick Start
21
+
22
+ ```python
23
+ from environments.cleaning.models import CleaningAction
24
+ from environments.cleaning.client import CleaningEnv
25
+
26
+ with CleaningEnv(base_url="http://localhost:8000") as env:
27
+ result = env.reset()
28
+ print(f"Domain: {result.observation.domain}, DQ: {result.observation.dq_score}")
29
+
30
+ result = env.step(CleaningAction(
31
+ operation="fill_null", column="Age", value="median"
32
+ ))
33
+ print(f"DQ after step: {result.observation.dq_score}")
34
+ ```
__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """DataSage Cleaning Environment."""
client.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DataSage Cleaning Environment Client."""
2
+
3
+ from typing import Dict
4
+
5
+ from openenv.core import EnvClient
6
+ from openenv.core.client_types import StepResult
7
+ from openenv.core.env_server.types import State
8
+
9
+ from .models import CleaningAction, CleaningObservation
10
+
11
+
12
class CleaningEnv(EnvClient[CleaningAction, CleaningObservation, State]):
    """
    Client for the DataSage Cleaning Environment.

    Translates CleaningAction objects into JSON payloads for the server
    and server responses back into typed results.

    Example:
        >>> with CleaningEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset()
        ...     print(result.observation.dq_score)
        ...     result = client.step(CleaningAction(
        ...         operation="fill_null", column="Age", value="median"
        ...     ))
        ...     print(result.observation.dq_score)
    """

    def _step_payload(self, action: CleaningAction) -> Dict:
        """Serialize a CleaningAction into the JSON body for a step call."""
        return {
            field: getattr(action, field)
            for field in ("operation", "column", "value", "params")
        }

    def _parse_result(self, payload: Dict) -> StepResult[CleaningObservation]:
        """Parse a server response into StepResult[CleaningObservation]."""
        obs_data = payload.get("observation", {})
        done = payload.get("done", False)
        reward = payload.get("reward")
        # done/reward are mirrored onto the observation for convenience.
        observation = CleaningObservation(
            domain=obs_data.get("domain", ""),
            data_preview=obs_data.get("data_preview", ""),
            dq_report=obs_data.get("dq_report", ""),
            dq_score=obs_data.get("dq_score", 0.0),
            columns_info=obs_data.get("columns_info", ""),
            step_number=obs_data.get("step_number", 0),
            max_steps=obs_data.get("max_steps", 15),
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata", {}),
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict) -> State:
        """Parse a server response into a State object."""
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )
environments/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # DataSage environments package
environments/shared/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared utilities for DataSage multi-domain enterprise environments."""
2
+
3
+ from .domains import DOMAINS, DomainConfig
4
+ from .personas import PERSONAS, Persona, score_persona_alignment
5
+ from .reward_utils import cleaning_reward, enrichment_reward, answering_reward
6
+
7
+ __all__ = [
8
+ "DOMAINS", "DomainConfig",
9
+ "PERSONAS", "Persona", "score_persona_alignment",
10
+ "cleaning_reward", "enrichment_reward", "answering_reward",
11
+ ]
environments/shared/domains.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Domain registry for the 4 enterprise data domains."""
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
class DomainConfig(BaseModel):
    """Schema description for one enterprise data domain."""

    name: str                        # registry key, e.g. "hr"
    display_name: str                # human-readable label
    dataset_key: str                 # config name within the HF dataset
    columns: list[str]               # expected column set for the domain
    numeric_columns: list[str]       # columns whose values should parse as numbers
    categorical_columns: list[str]   # columns drawn from a closed value set
    possible_enrichments: list[str]  # enrichment source names usable for this domain
    example_questions: list[str]     # sample analyst questions for QA tasks
15
+
16
+
17
+ DOMAINS = {
18
+ "hr": DomainConfig(
19
+ name="hr",
20
+ display_name="HR & People",
21
+ dataset_key="hr",
22
+ columns=[
23
+ "EmployeeID", "Age", "Department", "JobRole", "MonthlyIncome",
24
+ "YearsAtCompany", "Attrition", "JobSatisfaction", "OverTime",
25
+ "DistanceFromHome", "Education", "PerformanceRating",
26
+ ],
27
+ numeric_columns=["Age", "MonthlyIncome", "YearsAtCompany", "DistanceFromHome"],
28
+ categorical_columns=["Department", "JobRole", "Attrition", "OverTime"],
29
+ possible_enrichments=[
30
+ "salary_band", "tenure_risk", "satisfaction_index",
31
+ "industry_benchmark", "flight_risk_score",
32
+ ],
33
+ example_questions=[
34
+ "Which departments have the highest attrition rates?",
35
+ "What factors correlate most with employee turnover?",
36
+ "How does overtime affect job satisfaction?",
37
+ "What is the salary distribution across job roles?",
38
+ "Which employees are at highest flight risk?",
39
+ ],
40
+ ),
41
+ "sales": DomainConfig(
42
+ name="sales",
43
+ display_name="Sales & Revenue",
44
+ dataset_key="sales",
45
+ columns=[
46
+ "DealID", "AccountName", "Stage", "Amount", "CloseDate",
47
+ "Rep", "Product", "Region", "LeadSource", "DaysInStage",
48
+ "Probability", "ForecastCategory",
49
+ ],
50
+ numeric_columns=["Amount", "DaysInStage", "Probability"],
51
+ categorical_columns=["Stage", "Region", "Product", "ForecastCategory"],
52
+ possible_enrichments=[
53
+ "deal_size_category", "velocity_score", "win_probability_model",
54
+ "industry_code", "competitive_risk",
55
+ ],
56
+ example_questions=[
57
+ "What's our pipeline health for this quarter?",
58
+ "Which deals are at risk of slipping?",
59
+ "What's the average deal velocity by region?",
60
+ "Which reps are below quota?",
61
+ "What's the conversion rate by lead source?",
62
+ ],
63
+ ),
64
+ "pm": DomainConfig(
65
+ name="pm",
66
+ display_name="Project Management",
67
+ dataset_key="pm",
68
+ columns=[
69
+ "TaskID", "ProjectName", "Assignee", "Status", "Priority",
70
+ "DueDate", "EstimatedHours", "ActualHours", "Dependencies",
71
+ "Milestone", "RiskFlag", "CompletionPct",
72
+ ],
73
+ numeric_columns=["EstimatedHours", "ActualHours", "CompletionPct"],
74
+ categorical_columns=["Status", "Priority", "RiskFlag"],
75
+ possible_enrichments=[
76
+ "schedule_risk_score", "resource_utilization",
77
+ "dependency_chain_depth", "burndown_rate", "delay_probability",
78
+ ],
79
+ example_questions=[
80
+ "Which projects are at risk of missing deadlines?",
81
+ "How is resource utilization across teams?",
82
+ "What's the burndown rate for the current sprint?",
83
+ "Which tasks are blocking the most downstream work?",
84
+ "What's our on-time delivery rate?",
85
+ ],
86
+ ),
87
+ "it_ops": DomainConfig(
88
+ name="it_ops",
89
+ display_name="IT Operations",
90
+ dataset_key="it_ops",
91
+ columns=[
92
+ "TicketID", "Category", "Priority", "Status", "Assignee",
93
+ "CreatedDate", "ResolvedDate", "SLATarget", "EscalationLevel",
94
+ "AffectedSystem", "ResolutionType", "CustomerImpact",
95
+ ],
96
+ numeric_columns=["SLATarget", "EscalationLevel"],
97
+ categorical_columns=["Category", "Priority", "Status", "ResolutionType"],
98
+ possible_enrichments=[
99
+ "sla_compliance_flag", "mttr_band", "escalation_path",
100
+ "incident_severity_score", "recurring_pattern_flag",
101
+ ],
102
+ example_questions=[
103
+ "What's our SLA compliance rate this month?",
104
+ "Which systems have the most incidents?",
105
+ "What's the mean time to resolution trend?",
106
+ "How many tickets are breaching SLA?",
107
+ "What are the most common root causes?",
108
+ ],
109
+ ),
110
+ }
environments/shared/enrichment_sources.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Static enrichment lookup tables per domain (no API calls)."""
2
+
3
+ import numpy as np
4
+
5
+ # Enrichment registry: domain -> source -> lookup function or static data
6
+ ENRICHMENT_REGISTRY: dict[str, dict[str, dict]] = {
7
+ "hr": {
8
+ "salary_band": {
9
+ "description": "BLS salary band classification based on monthly income",
10
+ "type": "derived",
11
+ "logic": "classify_salary_band",
12
+ },
13
+ "tenure_risk": {
14
+ "description": "Tenure-based flight risk score",
15
+ "type": "derived",
16
+ "logic": "compute_tenure_risk",
17
+ },
18
+ "satisfaction_index": {
19
+ "description": "Composite satisfaction index from multiple factors",
20
+ "type": "derived",
21
+ "logic": "compute_satisfaction_index",
22
+ },
23
+ "industry_benchmark": {
24
+ "description": "Industry benchmark salary percentile",
25
+ "type": "lookup",
26
+ "data": {
27
+ "Sales Executive": 65000, "Research Scientist": 72000,
28
+ "Manager": 85000, "Lab Technician": 45000,
29
+ "Manufacturing Director": 95000, "Healthcare Representative": 55000,
30
+ "Human Resources": 60000,
31
+ },
32
+ },
33
+ "flight_risk_score": {
34
+ "description": "Combined flight risk from satisfaction, tenure, overtime",
35
+ "type": "derived",
36
+ "logic": "compute_flight_risk",
37
+ },
38
+ },
39
+ "sales": {
40
+ "deal_size_category": {
41
+ "description": "Categorize deal by amount: Small/Medium/Large/Enterprise",
42
+ "type": "derived",
43
+ "logic": "classify_deal_size",
44
+ },
45
+ "velocity_score": {
46
+ "description": "Deal velocity based on days in stage vs benchmark",
47
+ "type": "derived",
48
+ "logic": "compute_velocity_score",
49
+ },
50
+ "win_probability_model": {
51
+ "description": "Heuristic win probability based on stage + days",
52
+ "type": "derived",
53
+ "logic": "compute_win_probability",
54
+ },
55
+ "industry_code": {
56
+ "description": "Industry classification code from account name patterns",
57
+ "type": "lookup",
58
+ "data": {
59
+ "Tech": "SIC-7372", "Healthcare": "SIC-8011",
60
+ "Finance": "SIC-6020", "Retail": "SIC-5311",
61
+ "Manufacturing": "SIC-3559", "default": "SIC-9999",
62
+ },
63
+ },
64
+ "competitive_risk": {
65
+ "description": "Competitive risk score based on deal stage and velocity",
66
+ "type": "derived",
67
+ "logic": "compute_competitive_risk",
68
+ },
69
+ },
70
+ "pm": {
71
+ "schedule_risk_score": {
72
+ "description": "Risk of schedule slippage based on progress vs due date",
73
+ "type": "derived",
74
+ "logic": "compute_schedule_risk",
75
+ },
76
+ "resource_utilization": {
77
+ "description": "Resource utilization ratio: actual/estimated hours",
78
+ "type": "derived",
79
+ "logic": "compute_resource_utilization",
80
+ },
81
+ "dependency_chain_depth": {
82
+ "description": "Depth of dependency chain for task",
83
+ "type": "derived",
84
+ "logic": "compute_dependency_depth",
85
+ },
86
+ "burndown_rate": {
87
+ "description": "Task completion rate relative to plan",
88
+ "type": "derived",
89
+ "logic": "compute_burndown_rate",
90
+ },
91
+ "delay_probability": {
92
+ "description": "Probability of delay based on current trajectory",
93
+ "type": "derived",
94
+ "logic": "compute_delay_probability",
95
+ },
96
+ },
97
+ "it_ops": {
98
+ "sla_compliance_flag": {
99
+ "description": "Whether ticket meets SLA target",
100
+ "type": "derived",
101
+ "logic": "compute_sla_compliance",
102
+ },
103
+ "mttr_band": {
104
+ "description": "Mean time to resolution band: Fast/Normal/Slow/Critical",
105
+ "type": "derived",
106
+ "logic": "classify_mttr",
107
+ },
108
+ "escalation_path": {
109
+ "description": "Recommended escalation path based on category and priority",
110
+ "type": "lookup",
111
+ "data": {
112
+ "P1-Critical": "L3 -> Manager -> VP",
113
+ "P2-High": "L2 -> L3 -> Manager",
114
+ "P3-Medium": "L1 -> L2",
115
+ "P4-Low": "L1",
116
+ },
117
+ },
118
+ "incident_severity_score": {
119
+ "description": "Computed severity score from priority and customer impact",
120
+ "type": "derived",
121
+ "logic": "compute_severity_score",
122
+ },
123
+ "recurring_pattern_flag": {
124
+ "description": "Flag indicating likely recurring issue",
125
+ "type": "derived",
126
+ "logic": "detect_recurring_pattern",
127
+ },
128
+ },
129
+ }
130
+
131
+
132
def lookup(domain: str, source: str, row: dict) -> object:
    """Resolve an enrichment value for *row*.

    "lookup" sources scan the row's values (in column order) for the first
    one present as a key in the static table, falling back to the table's
    "default" entry; "derived" sources dispatch to the registered compute
    function. Returns None for an unregistered domain/source pair.
    """
    config = ENRICHMENT_REGISTRY.get(domain, {}).get(source)
    if not config:
        return None

    if config["type"] == "lookup":
        table = config["data"]
        # First row value that matches a table key wins.
        for cell in row.values():
            key = str(cell)
            if key in table:
                return table[key]
        return table.get("default")

    # Derived: dispatch on the registered logic name.
    compute = _COMPUTE_FUNCTIONS.get(config["logic"])
    return compute(row) if compute else None
155
+
156
+
157
+ # --- Computation functions ---
158
+
159
+ def _classify_salary_band(row: dict) -> str:
160
+ try:
161
+ income = float(row.get("MonthlyIncome", 0))
162
+ except (ValueError, TypeError):
163
+ return "Unknown"
164
+ if income < 3000:
165
+ return "Entry"
166
+ elif income < 6000:
167
+ return "Mid"
168
+ elif income < 10000:
169
+ return "Senior"
170
+ return "Executive"
171
+
172
+
173
+ def _compute_tenure_risk(row: dict) -> float:
174
+ try:
175
+ years = float(row.get("YearsAtCompany", 0))
176
+ except (ValueError, TypeError):
177
+ return 0.5
178
+ # Short tenure = higher risk, very long = moderate risk
179
+ if years < 2:
180
+ return 0.8
181
+ elif years < 5:
182
+ return 0.4
183
+ elif years < 10:
184
+ return 0.2
185
+ return 0.3
186
+
187
+
188
+ def _compute_satisfaction_index(row: dict) -> float:
189
+ try:
190
+ satisfaction = float(row.get("JobSatisfaction", 3))
191
+ except (ValueError, TypeError):
192
+ satisfaction = 3
193
+ return round(satisfaction / 4.0, 2)
194
+
195
+
196
def _compute_flight_risk(row: dict) -> float:
    """Blend tenure risk (40%), dissatisfaction (40%) and overtime (20%)."""
    works_overtime = str(row.get("OverTime", "No")).lower() == "yes"
    overtime_component = 0.3 if works_overtime else 0.0
    dissatisfaction = 1 - _compute_satisfaction_index(row)
    combined = (
        0.4 * _compute_tenure_risk(row)
        + 0.4 * dissatisfaction
        + 0.2 * overtime_component
    )
    return round(combined, 2)
201
+
202
+
203
+ def _classify_deal_size(row: dict) -> str:
204
+ try:
205
+ amount = float(row.get("Amount", 0))
206
+ except (ValueError, TypeError):
207
+ return "Unknown"
208
+ if amount < 5000:
209
+ return "Small"
210
+ elif amount < 25000:
211
+ return "Medium"
212
+ elif amount < 100000:
213
+ return "Large"
214
+ return "Enterprise"
215
+
216
+
217
+ def _compute_velocity_score(row: dict) -> float:
218
+ try:
219
+ days = float(row.get("DaysInStage", 0))
220
+ except (ValueError, TypeError):
221
+ return 0.5
222
+ # Benchmark: 30 days per stage
223
+ if days < 15:
224
+ return 1.0
225
+ elif days < 30:
226
+ return 0.7
227
+ elif days < 60:
228
+ return 0.4
229
+ return 0.1
230
+
231
+
232
def _compute_win_probability(row: dict) -> float:
    """Blend a stage-based prior (70%) with deal velocity (30%)."""
    stage_priors = {
        "Prospecting": 0.10, "Qualification": 0.25, "Proposal": 0.50,
        "Negotiation": 0.75, "Won": 1.0, "Lost": 0.0,
    }
    # Unknown stages get a 0.3 prior.
    prior = stage_priors.get(str(row.get("Stage", "")), 0.3)
    return round(0.7 * prior + 0.3 * _compute_velocity_score(row), 2)
241
+
242
+
243
def _compute_competitive_risk(row: dict) -> float:
    """Competitive risk is the complement of deal velocity."""
    return round(1.0 - _compute_velocity_score(row), 2)
246
+
247
+
248
+ def _compute_schedule_risk(row: dict) -> float:
249
+ try:
250
+ pct = float(row.get("CompletionPct", 0))
251
+ except (ValueError, TypeError):
252
+ pct = 0
253
+ # Simple: lower completion = higher risk
254
+ return round(1.0 - (pct / 100.0), 2)
255
+
256
+
257
+ def _compute_resource_utilization(row: dict) -> float:
258
+ try:
259
+ estimated = float(row.get("EstimatedHours", 1))
260
+ actual = float(row.get("ActualHours", 0))
261
+ except (ValueError, TypeError):
262
+ return 0.0
263
+ if estimated == 0:
264
+ return 0.0
265
+ return round(actual / estimated, 2)
266
+
267
+
268
+ def _compute_dependency_depth(row: dict) -> int:
269
+ deps = row.get("Dependencies", "")
270
+ if not deps or str(deps) in ("nan", "None", ""):
271
+ return 0
272
+ return len(str(deps).split(","))
273
+
274
+
275
+ def _compute_burndown_rate(row: dict) -> float:
276
+ try:
277
+ pct = float(row.get("CompletionPct", 0))
278
+ estimated = float(row.get("EstimatedHours", 1))
279
+ actual = float(row.get("ActualHours", 0))
280
+ except (ValueError, TypeError):
281
+ return 0.5
282
+ if actual == 0:
283
+ return 0.0
284
+ expected_rate = pct / 100.0
285
+ time_rate = actual / max(estimated, 1)
286
+ return round(expected_rate / max(time_rate, 0.01), 2)
287
+
288
+
289
def _compute_delay_probability(row: dict) -> float:
    """Estimate the probability that a task slips its schedule.

    Combines schedule risk with the inverse burndown rate: a task that is
    behind schedule AND burning down slowly is the most likely to slip.
    The result is clamped to [0.0, 1.0] so it is a valid probability —
    the raw product can reach 10.0 when burndown hits its 0.1 floor, and
    can go negative when CompletionPct exceeds 100.
    """
    schedule_risk = _compute_schedule_risk(row)
    burndown = _compute_burndown_rate(row)
    raw_estimate = schedule_risk * (1.0 / max(burndown, 0.1))
    return round(min(max(raw_estimate, 0.0), 1.0), 2)
293
+
294
+
295
+ def _compute_sla_compliance(row: dict) -> str:
296
+ try:
297
+ sla = float(row.get("SLATarget", 24))
298
+ escalation = float(row.get("EscalationLevel", 0))
299
+ except (ValueError, TypeError):
300
+ return "Unknown"
301
+ if escalation > 2:
302
+ return "Breached"
303
+ return "Compliant"
304
+
305
+
306
+ def _classify_mttr(row: dict) -> str:
307
+ try:
308
+ escalation = float(row.get("EscalationLevel", 0))
309
+ except (ValueError, TypeError):
310
+ return "Normal"
311
+ if escalation == 0:
312
+ return "Fast"
313
+ elif escalation <= 1:
314
+ return "Normal"
315
+ elif escalation <= 3:
316
+ return "Slow"
317
+ return "Critical"
318
+
319
+
320
+ def _compute_severity_score(row: dict) -> float:
321
+ priority_scores = {"P1-Critical": 1.0, "P2-High": 0.7, "P3-Medium": 0.4, "P4-Low": 0.1}
322
+ priority = str(row.get("Priority", "P3-Medium"))
323
+ return priority_scores.get(priority, 0.4)
324
+
325
+
326
+ def _detect_recurring_pattern(row: dict) -> bool:
327
+ category = str(row.get("Category", ""))
328
+ # Simple heuristic: certain categories tend to recur
329
+ recurring_cats = {"Network", "Email", "Access"}
330
+ return category in recurring_cats
331
+
332
+
333
# Dispatch table mapping the "logic" names used in ENRICHMENT_REGISTRY to
# their implementations; lookup() consults this for "derived" sources.
_COMPUTE_FUNCTIONS = {
    "classify_salary_band": _classify_salary_band,
    "compute_tenure_risk": _compute_tenure_risk,
    "compute_satisfaction_index": _compute_satisfaction_index,
    "compute_flight_risk": _compute_flight_risk,
    "classify_deal_size": _classify_deal_size,
    "compute_velocity_score": _compute_velocity_score,
    "compute_win_probability": _compute_win_probability,
    "compute_competitive_risk": _compute_competitive_risk,
    "compute_schedule_risk": _compute_schedule_risk,
    "compute_resource_utilization": _compute_resource_utilization,
    "compute_dependency_depth": _compute_dependency_depth,
    "compute_burndown_rate": _compute_burndown_rate,
    "compute_delay_probability": _compute_delay_probability,
    "compute_sla_compliance": _compute_sla_compliance,
    "classify_mttr": _classify_mttr,
    "compute_severity_score": _compute_severity_score,
    "detect_recurring_pattern": _detect_recurring_pattern,
}
352
+
353
+
354
def get_available_enrichments(domain: str) -> list[str]:
    """List the enrichment source names registered for *domain*."""
    return [name for name in ENRICHMENT_REGISTRY.get(domain, {})]
357
+
358
+
359
def get_enrichment_description(domain: str, source: str) -> str:
    """Return the human-readable description of an enrichment source."""
    entry = ENRICHMENT_REGISTRY.get(domain, {}).get(source, {})
    return entry.get("description", "Unknown enrichment source")
environments/shared/enterprise_data.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Multi-domain dataset loading, corruption injection, and DQ scoring."""
2
+
3
+ import random
4
+ import string
5
+ from typing import Optional
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
+ from .domains import DOMAINS, DomainConfig
11
+
12
+
13
def load_domain_data(domain: str, sample_size: Optional[int] = None) -> pd.DataFrame:
    """Load domain data from HF dataset or generate synthetic fallback.

    Args:
        domain: Domain key (e.g. "hr", "sales", "pm", "it_ops").
        sample_size: If set and the frame is larger, down-sample to this
            many rows (fixed seed, so the sample is reproducible).

    Returns:
        A pandas DataFrame of raw (uncorrupted) domain records.
    """
    try:
        # Imported lazily so the module also works without `datasets`.
        from datasets import load_dataset
        ds = load_dataset("ricalanis/datasage-enterprise-raw", domain, split="train")
        df = ds.to_pandas()
    except Exception:
        # Deliberate broad catch: any import/network/parse failure falls
        # back to the synthetic generator instead of crashing the env.
        df = _generate_synthetic(domain)

    if sample_size and len(df) > sample_size:
        df = df.sample(n=sample_size, random_state=42).reset_index(drop=True)
    return df
25
+
26
+
27
def _generate_synthetic(domain: str, n: int = 200) -> pd.DataFrame:
    """Generate synthetic data as fallback when HF dataset unavailable.

    Values are derived from the domain config: numeric columns from
    N(50, 20), categoricals from the domain's category lists, and
    ID/date/name columns via per-column-name heuristics. The RNG is
    seeded, so output is deterministic for a given (domain, n).
    """
    config = DOMAINS[domain]
    rng = np.random.default_rng(42)
    data = {}

    for col in config.columns:
        if col in config.numeric_columns:
            data[col] = rng.normal(50, 20, n).round(2)
        elif col in config.categorical_columns:
            categories = _get_categories(domain, col)
            data[col] = rng.choice(categories, n).tolist()
        elif "ID" in col:
            # Synthetic primary keys, e.g. "EMP-0007".
            data[col] = [f"{col[:3].upper()}-{i:04d}" for i in range(n)]
        elif "Date" in col:
            # Random dates within calendar year 2024.
            base = pd.Timestamp("2024-01-01")
            data[col] = [(base + pd.Timedelta(days=int(d))).strftime("%Y-%m-%d")
                         for d in rng.integers(0, 365, n)]
        elif "Name" in col or "Assignee" in col or "Rep" in col:
            names = ["Alice", "Bob", "Carol", "Dan", "Eve", "Frank", "Grace", "Hank"]
            data[col] = rng.choice(names, n).tolist()
        else:
            # Fallback: opaque but unique per-row string values.
            data[col] = [f"{col}_val_{i}" for i in range(n)]

    return pd.DataFrame(data)
52
+
53
+
54
+ def _get_categories(domain: str, col: str) -> list[str]:
55
+ """Return realistic category values per domain and column."""
56
+ cat_map = {
57
+ "hr": {
58
+ "Department": ["Sales", "Research & Development", "Human Resources"],
59
+ "JobRole": ["Sales Executive", "Research Scientist", "Manager", "Lab Technician",
60
+ "Manufacturing Director", "Healthcare Representative"],
61
+ "Attrition": ["Yes", "No"],
62
+ "OverTime": ["Yes", "No"],
63
+ },
64
+ "sales": {
65
+ "Stage": ["Prospecting", "Qualification", "Proposal", "Negotiation", "Won", "Lost"],
66
+ "Region": ["East", "West", "Central", "North", "South"],
67
+ "Product": ["GTX Pro", "GTX Basic", "GTX Plus", "MG Special", "MG Advanced"],
68
+ "ForecastCategory": ["Pipeline", "Best Case", "Commit", "Closed"],
69
+ },
70
+ "pm": {
71
+ "Status": ["Not Started", "In Progress", "Completed", "On Hold", "Cancelled"],
72
+ "Priority": ["Critical", "High", "Medium", "Low"],
73
+ "RiskFlag": ["High", "Medium", "Low", "None"],
74
+ },
75
+ "it_ops": {
76
+ "Category": ["Hardware", "Software", "Network", "Access", "Email"],
77
+ "Priority": ["P1-Critical", "P2-High", "P3-Medium", "P4-Low"],
78
+ "Status": ["Open", "In Progress", "Resolved", "Closed", "Pending"],
79
+ "ResolutionType": ["Fix Applied", "Workaround", "No Fix", "Duplicate", "User Error"],
80
+ },
81
+ }
82
+ return cat_map.get(domain, {}).get(col, ["A", "B", "C"])
83
+
84
+
85
def inject_corruption(df: pd.DataFrame, domain_config: DomainConfig,
                      rate: float = 0.15) -> pd.DataFrame:
    """Inject realistic data quality issues into a DataFrame.

    Five corruption passes are applied (nulls, type mismatches, typos,
    duplicated rows, stray whitespace), each scaled off *rate*. The RNG
    is seeded, so corruption is deterministic for a given input.

    Returns a new DataFrame (the input is not modified); the result has
    more rows than the input because duplicate rows are appended.
    """
    corrupted = df.copy()
    n_rows = len(corrupted)
    rng = np.random.default_rng(42)

    # 1. Inject nulls into numeric columns
    for col in domain_config.numeric_columns:
        if col in corrupted.columns:
            null_mask = rng.random(n_rows) < rate
            corrupted.loc[null_mask, col] = np.nan

    # 2. Inject type mismatches (strings in numeric columns)
    for col in domain_config.numeric_columns:
        if col in corrupted.columns:
            n_bad = max(1, int(n_rows * rate * 0.3))
            bad_idx = rng.choice(n_rows, n_bad, replace=False)
            # Object dtype lets string sentinels coexist with numbers.
            corrupted[col] = corrupted[col].astype(object)
            for idx in bad_idx:
                corrupted.iloc[idx, corrupted.columns.get_loc(col)] = rng.choice(
                    ["N/A", "unknown", "#REF!", "TBD", "-"]
                )

    # 3. Inject typos in categorical columns
    for col in domain_config.categorical_columns:
        if col in corrupted.columns:
            n_typos = max(1, int(n_rows * rate * 0.2))
            typo_idx = rng.choice(n_rows, n_typos, replace=False)
            for idx in typo_idx:
                val = str(corrupted.iloc[idx, corrupted.columns.get_loc(col)])
                corrupted.iloc[idx, corrupted.columns.get_loc(col)] = _add_typo(val, rng)

    # 4. Inject duplicates (copies of randomly chosen rows, appended)
    n_dupes = max(1, int(n_rows * rate * 0.1))
    dupe_idx = rng.choice(n_rows, n_dupes, replace=False)
    dupes = corrupted.iloc[dupe_idx].copy()
    corrupted = pd.concat([corrupted, dupes], ignore_index=True)

    # 5. Inject whitespace issues (only the first two categorical columns;
    #    indexed over the grown frame, so appended duplicates may be hit)
    for col in domain_config.categorical_columns[:2]:
        if col in corrupted.columns:
            n_ws = max(1, int(n_rows * rate * 0.2))
            ws_idx = rng.choice(len(corrupted), n_ws, replace=False)
            for idx in ws_idx:
                val = str(corrupted.iloc[idx, corrupted.columns.get_loc(col)])
                corrupted.iloc[idx, corrupted.columns.get_loc(col)] = f" {val} "

    return corrupted
+ return corrupted
134
+
135
+
136
+ def _add_typo(text: str, rng: np.random.Generator) -> str:
137
+ """Add a realistic typo to a string."""
138
+ if len(text) < 2:
139
+ return text
140
+ typo_type = rng.choice(["swap", "delete", "insert", "case"])
141
+ idx = rng.integers(0, len(text))
142
+ if typo_type == "swap" and idx < len(text) - 1:
143
+ return text[:idx] + text[idx + 1] + text[idx] + text[idx + 2:]
144
+ elif typo_type == "delete":
145
+ return text[:idx] + text[idx + 1:]
146
+ elif typo_type == "insert":
147
+ char = rng.choice(list(string.ascii_lowercase))
148
+ return text[:idx] + char + text[idx:]
149
+ else:
150
+ return text[:idx] + text[idx].swapcase() + text[idx + 1:]
151
+
152
+
153
def compute_dq_score(df: pd.DataFrame, domain_config: DomainConfig) -> dict:
    """Score data quality on completeness, consistency and uniqueness.

    Returns a dict with the three component scores plus a weighted
    "overall" (40% completeness, 35% consistency, 25% uniqueness),
    each rounded to 4 decimals.
    """
    expected_cols = [c for c in domain_config.columns if c in df.columns]

    # Completeness: share of non-null cells across the expected columns.
    completeness = 1.0
    if expected_cols:
        total_cells = len(df) * len(expected_cols)
        completeness = 1.0 - df[expected_cols].isnull().sum().sum() / total_cells

    # Consistency: fraction of numeric-column cells that parse as numbers
    # (nulls count as valid — they are penalized under completeness).
    per_column = [
        df[col].apply(_is_numeric).mean()
        for col in domain_config.numeric_columns
        if col in df.columns
    ]
    consistency = float(np.mean(per_column)) if per_column else 1.0

    # Uniqueness: share of rows that are not duplicates, keyed on the
    # first five expected columns.
    uniqueness = 1.0
    if len(df) > 0:
        dupe_count = df.duplicated(subset=expected_cols[:5], keep='first').sum()
        uniqueness = 1.0 - dupe_count / len(df)

    overall = 0.40 * completeness + 0.35 * consistency + 0.25 * uniqueness

    return {
        "completeness": round(completeness, 4),
        "consistency": round(consistency, 4),
        "uniqueness": round(uniqueness, 4),
        "overall": round(overall, 4),
    }
+ }
187
+
188
+
189
+ def _is_numeric(val) -> bool:
190
+ """Check if a value is numeric (or null, which is valid)."""
191
+ if pd.isna(val):
192
+ return True
193
+ try:
194
+ float(val)
195
+ return True
196
+ except (ValueError, TypeError):
197
+ return False
198
+
199
+
200
def compute_dq_score_with_lfs(df: pd.DataFrame, domain: str,
                              lfs: list) -> float:
    """Aggregate row quality via Snorkel-style labeling functions.

    Each LF votes GOOD (1), BAD (0) or ABSTAIN (-1) per row; a row's
    score is the fraction of non-abstaining votes that are GOOD (0.5
    when every LF abstains or errors). Falls back to the heuristic DQ
    score when no LFs are supplied or the frame is empty.
    """
    if not lfs or len(df) == 0:
        config = DOMAINS.get(domain)
        return compute_dq_score(df, config)["overall"] if config else 0.5

    ABSTAIN, GOOD = -1, 1
    per_row = []
    for _, record in df.iterrows():
        ballots = []
        for labeler in lfs:
            try:
                verdict = labeler(record)
            except Exception:
                # A crashing LF simply does not vote on this row.
                continue
            if verdict != ABSTAIN:
                ballots.append(verdict)
        if ballots:
            per_row.append(sum(b == GOOD for b in ballots) / len(ballots))
        else:
            per_row.append(0.5)

    return float(np.mean(per_row))
227
+
228
+
229
def format_preview(df: pd.DataFrame, n: int = 5) -> str:
    """Render the first *n* rows as a plain-text table (no index)."""
    head = df.head(n)
    return head.to_string(index=False, max_colwidth=30)
232
+
233
+
234
+ def format_columns_info(df: pd.DataFrame, domain_config: DomainConfig) -> str:
235
+ """Format column info: name, dtype, null count."""
236
+ lines = []
237
+ for col in df.columns:
238
+ null_count = df[col].isnull().sum()
239
+ dtype = str(df[col].dtype)
240
+ expected = "expected" if col in domain_config.columns else "extra"
241
+ lines.append(f"{col}: {dtype}, nulls={null_count} ({expected})")
242
+ return "\n".join(lines)
environments/shared/personas.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generalized personas for domain-independent question answering."""
2
+
3
+ import re
4
+
5
+ from pydantic import BaseModel
6
+
7
+
8
class Persona(BaseModel):
    """A stakeholder archetype used to tailor and score generated answers."""

    # Human-readable display name, e.g. "Executive".
    name: str
    # Short machine key used for lookups (see PERSONA_MAP / get_persona).
    role: str
    # Topics this persona cares about.
    focus_areas: list[str]
    # Expected tone; matched by _check_formality().
    language_style: str
    # Terms whose presence in an answer indicates good alignment.
    keywords: list[str]
    # Terms that clash with this persona; their presence is penalised.
    anti_keywords: list[str]
15
+
16
+
17
# The three generalized personas. keywords/anti_keywords drive
# score_persona_alignment(); language_style is matched by _check_formality().
PERSONAS = [
    Persona(
        name="Executive",
        role="executive",
        focus_areas=["costs", "ROI", "strategic risk", "portfolio trends", "year-over-year"],
        language_style="strategic-financial",
        keywords=["revenue", "cost", "ROI", "risk", "trend", "quarter",
                  "year-over-year", "impact", "budget", "margin", "growth"],
        anti_keywords=["I think", "maybe", "um", "idk"],
    ),
    Persona(
        name="Manager",
        role="manager",
        focus_areas=["team performance", "operational health", "process bottlenecks", "capacity"],
        language_style="operational-actionable",
        keywords=["team", "performance", "bottleneck", "capacity", "SLA",
                  "process", "action", "priority", "escalation", "delivery"],
        anti_keywords=["shareholder", "valuation", "IPO"],
    ),
    Persona(
        name="Individual Contributor",
        role="ic",
        focus_areas=["personal tasks", "deadlines", "what to do next", "simple explanations"],
        language_style="plain-personal",
        keywords=["my", "I should", "next step", "deadline", "help",
                  "understand", "priority", "task", "assigned"],
        anti_keywords=["KPI", "ROI", "portfolio", "strategic", "EBITDA"],
    ),
]

# Role-keyed index over PERSONAS, used by get_persona().
PERSONA_MAP = {p.role: p for p in PERSONAS}
48
+
49
+
50
def get_persona(role: str) -> Persona:
    """Look up a persona by its role key ("executive", "manager", "ic").

    Raises:
        KeyError: if *role* is not a known persona role. The message lists
            the valid roles instead of the bare key the plain dict lookup
            would report.
    """
    try:
        return PERSONA_MAP[role]
    except KeyError:
        raise KeyError(
            f"Unknown persona role {role!r}; expected one of {sorted(PERSONA_MAP)}"
        ) from None
53
+
54
+
55
def score_persona_alignment(answer: str, persona: Persona) -> float:
    """Score how well an answer aligns with a persona's communication style.

    Returns a float in [0, 1] combining:
      - keyword density (presence of the persona's expected terms),
      - an anti-keyword penalty (presence of terms that don't fit),
      - a formality check against the persona's language style.

    Fix: the original computed ``words``/``word_count`` via ``re.findall``
    but never used them; the dead locals (and the regex pass) are removed.
    """
    answer_lower = answer.lower()

    # Keyword scoring: full credit once ~30% of the persona keywords appear.
    keyword_hits = sum(1 for kw in persona.keywords if kw.lower() in answer_lower)
    keyword_score = min(keyword_hits / max(len(persona.keywords) * 0.3, 1), 1.0)

    # Each anti-keyword found costs 0.15, capped at a 0.5 total penalty.
    anti_hits = sum(1 for akw in persona.anti_keywords if akw.lower() in answer_lower)
    anti_penalty = min(anti_hits * 0.15, 0.5)

    # Style/formality match in [0, 1].
    formality_score = _check_formality(answer, persona.language_style)

    # Combine: 50% keywords, 20% formality, 30% base (minus anti-penalty),
    # then clamp to [0, 1] and round for stable reporting.
    raw_score = 0.50 * keyword_score + 0.20 * formality_score + 0.30 - anti_penalty
    return round(max(0.0, min(1.0, raw_score)), 4)
81
+
82
+
83
+ def _check_formality(text: str, style: str) -> float:
84
+ """Check if text formality matches the expected language style."""
85
+ text_lower = text.lower()
86
+
87
+ if style == "strategic-financial":
88
+ indicators = ["percent", "%", "million", "billion", "quarter", "fiscal",
89
+ "forecast", "benchmark", "metric"]
90
+ hits = sum(1 for ind in indicators if ind in text_lower)
91
+ return min(hits / 2.0, 1.0)
92
+
93
+ elif style == "operational-actionable":
94
+ indicators = ["action", "recommend", "should", "priority", "next steps",
95
+ "immediate", "plan", "schedule"]
96
+ hits = sum(1 for ind in indicators if ind in text_lower)
97
+ return min(hits / 2.0, 1.0)
98
+
99
+ elif style == "plain-personal":
100
+ # Plain style rewards shorter sentences and simple language
101
+ sentences = text.split(".")
102
+ avg_len = sum(len(s.split()) for s in sentences) / max(len(sentences), 1)
103
+ return 1.0 if avg_len < 20 else max(0.0, 1.0 - (avg_len - 20) / 30)
104
+
105
+ return 0.5
environments/shared/reward_utils.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Reward computation with cached downstream signals."""
2
+
3
+ # Cached downstream signals mapping quality buckets to historical scores.
4
+ # These represent how well downstream stages perform given upstream quality.
5
+ DOWNSTREAM_CACHE: dict[str, float] = {
6
+ "excellent": 0.95, # DQ > 0.90 or coverage > 0.80
7
+ "good": 0.75, # DQ 0.70-0.90 or coverage 0.50-0.80
8
+ "fair": 0.50, # DQ 0.50-0.70 or coverage 0.30-0.50
9
+ "poor": 0.20, # DQ < 0.50 or coverage < 0.30
10
+ }
11
+
12
+
13
+ def _get_downstream_bucket(score: float) -> str:
14
+ """Map a score to a downstream quality bucket."""
15
+ if score > 0.90:
16
+ return "excellent"
17
+ elif score > 0.70:
18
+ return "good"
19
+ elif score > 0.50:
20
+ return "fair"
21
+ return "poor"
22
+
23
+
24
def cleaning_reward(dq_score: float, downstream_bucket: str = "") -> float:
    """Compute the cleaning-stage reward.

    Blend: 0.70 * dq_score + 0.30 * downstream_signal. When no bucket is
    supplied, one is derived from the DQ score itself; unrecognised bucket
    names fall back to a neutral 0.5 signal.
    """
    bucket = downstream_bucket or _get_downstream_bucket(dq_score)
    signal = DOWNSTREAM_CACHE.get(bucket, 0.5)
    return round(0.70 * dq_score + 0.30 * signal, 4)
33
+
34
+
35
def enrichment_reward(coverage: float, downstream_bucket: str = "") -> float:
    """Compute the enrichment-stage reward.

    Blend: 0.50 * coverage + 0.50 * downstream_signal. A missing bucket is
    derived from the coverage value itself; unknown bucket names map to a
    neutral 0.5 signal.
    """
    bucket = downstream_bucket or _get_downstream_bucket(coverage)
    signal = DOWNSTREAM_CACHE.get(bucket, 0.5)
    return round(0.50 * coverage + 0.50 * signal, 4)
44
+
45
+
46
def answering_reward(faithfulness: float, persona_relevance: float,
                     patronus_score: float | None = None) -> float:
    """Compute the answering-stage reward.

    When a Patronus evaluation is available the external faithfulness score
    is trusted more: 0.40 * patronus + 0.60 * persona_relevance. Otherwise
    the internal signal is used: 0.30 * faithfulness + 0.70 * persona_relevance.
    """
    if patronus_score is None:
        blended = 0.30 * faithfulness + 0.70 * persona_relevance
    else:
        blended = 0.40 * patronus_score + 0.60 * persona_relevance
    return round(blended, 4)
models.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Data models for the DataSage Cleaning Environment."""
2
+
3
+ from openenv.core.env_server.types import Action, Observation
4
+ from pydantic import Field
5
+ from typing import Optional
6
+
7
+
8
class CleaningAction(Action):
    """Action for the Cleaning environment - a data cleaning operation."""

    # Which cleaning operation to apply; the environment treats unknown
    # operation names as a no-op step rather than raising.
    operation: str = Field(
        ...,
        description="Cleaning operation: fill_null|fix_type|remove_duplicate|standardize|trim|correct_typo",
    )
    # Column the operation targets (ignored by remove_duplicate).
    column: str = Field(..., description="Target column name")
    # Optional rule or replacement value, e.g. "median"/"mode" for fill_null,
    # "lower"/"title" for standardize, or the correct value for correct_typo.
    value: Optional[str] = Field(
        None,
        description="Replacement value or rule (e.g., 'median', 'mode', a specific value)",
    )
    # Extra operation parameters, e.g. {"wrong": "<typo>"} for correct_typo.
    params: dict = Field(default_factory=dict)
21
+
22
+
23
class CleaningObservation(Observation):
    """Observation from the Cleaning environment - data quality state."""

    # Domain the current batch was drawn from.
    domain: str = Field(default="", description="Current domain: hr|sales|pm|it_ops")
    # First rows of the working batch rendered as plain text.
    data_preview: str = Field(default="", description="First 5 rows as text table")
    # Human-readable breakdown of the DQ sub-scores.
    dq_report: str = Field(
        default="",
        description="Completeness, consistency, uniqueness breakdown",
    )
    # Weighted overall data-quality score in [0, 1].
    dq_score: float = Field(default=0.0, description="Overall DQ score 0-1")
    # Per-column summary: dtype, null count, expected-vs-extra flag.
    columns_info: str = Field(
        default="", description="Column names, types, null counts"
    )
    # Steps taken so far this episode and the episode's step budget.
    step_number: int = Field(default=0)
    max_steps: int = Field(default=15)
openenv.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ spec_version: 1
2
+ name: datasage_cleaning
3
+ type: space
4
+ runtime: fastapi
5
+ app: server.app:app
6
+ port: 8000
pyproject.toml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=45", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "openenv-datasage-cleaning"
7
+ version = "0.1.0"
8
+ description = "DataSage Cleaning environment for OpenEnv"
9
+ requires-python = ">=3.10"
10
+ dependencies = [
11
+ "openenv-core[core]>=0.2.1",
12
+ "pandas>=2.0",
13
+ "numpy>=1.24",
14
+ ]
15
+
16
+ [project.optional-dependencies]
17
+ dev = [
18
+ "pytest>=8.0.0",
19
+ "pytest-cov>=4.0.0",
20
+ ]
21
+
22
+ [project.scripts]
23
+ server = "datasage_cleaning.server.app:main"
24
+
25
+ [tool.setuptools]
26
+ include-package-data = true
27
+ packages = ["datasage_cleaning", "datasage_cleaning.server"]
28
+ package-dir = { "datasage_cleaning" = ".", "datasage_cleaning.server" = "server" }
server/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Cleaning environment server components."""
2
+
3
+ from .cleaning_environment import CleaningEnvironment
4
+
5
+ __all__ = ["CleaningEnvironment"]
server/app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI application for the DataSage Cleaning Environment.
3
+
4
+ Endpoints:
5
+ - POST /reset: Reset the environment
6
+ - POST /step: Execute an action
7
+ - GET /state: Get current environment state
8
+ - GET /schema: Get action/observation schemas
9
+ - WS /ws: WebSocket endpoint for persistent sessions
10
+
11
+ Usage:
12
+ uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
13
+ """
14
+
15
+ try:
16
+ from openenv.core.env_server.http_server import create_app
17
+ except Exception as e: # pragma: no cover
18
+ raise ImportError(
19
+ "openenv is required. Install with: uv sync"
20
+ ) from e
21
+
22
+ from models import CleaningAction, CleaningObservation
23
+ from .cleaning_environment import CleaningEnvironment
24
+
25
+
26
+ app = create_app(
27
+ CleaningEnvironment,
28
+ CleaningAction,
29
+ CleaningObservation,
30
+ env_name="datasage_cleaning",
31
+ max_concurrent_envs=4,
32
+ )
33
+
34
+
35
+ def main(host: str = "0.0.0.0", port: int = 8000):
36
+ """Entry point for direct execution."""
37
+ import uvicorn
38
+
39
+ uvicorn.run(app, host=host, port=port)
40
+
41
+
42
+ if __name__ == "__main__":
43
+ import argparse
44
+
45
+ parser = argparse.ArgumentParser()
46
+ parser.add_argument("--port", type=int, default=8000)
47
+ args = parser.parse_args()
48
+ main(port=args.port)
server/cleaning_environment.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DataSage Cleaning Environment Implementation.
3
+
4
+ An RL environment where the agent must clean corrupted enterprise data
5
+ across 4 domains (HR, Sales, PM, IT Ops) by applying cleaning operations
6
+ to maximise the data quality score.
7
+ """
8
+
9
+ import random
10
+ import sys
11
+ import os
12
+
13
+ from uuid import uuid4
14
+
15
+ import numpy as np
16
+ import pandas as pd
17
+
18
+ # Allow imports from the project root so shared modules are reachable.
19
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..'))
20
+
21
+ from environments.shared.domains import DOMAINS
22
+ from environments.shared.enterprise_data import (
23
+ load_domain_data,
24
+ inject_corruption,
25
+ compute_dq_score,
26
+ format_preview,
27
+ format_columns_info,
28
+ )
29
+ from environments.shared.reward_utils import cleaning_reward
30
+
31
+ from models import CleaningAction, CleaningObservation
32
+ from openenv.core.env_server.interfaces import Environment
33
+ from openenv.core.env_server.types import State
34
+
35
+
36
class CleaningEnvironment(Environment):
    """
    Cleaning environment: the agent fixes data quality issues in a
    50-row enterprise data batch.

    Supported operations:
        fill_null        - fill missing values (median / mode / explicit value)
        fix_type         - cast a column to numeric, coercing errors to NaN
        remove_duplicate - drop duplicate rows
        standardize      - strip whitespace and normalise case (lower / title)
        trim             - strip leading/trailing whitespace
        correct_typo     - replace a typo with a correct value

    An episode ends once the overall DQ score exceeds 0.95 or after
    ``max_steps`` operations.
    """

    # A single server process may host multiple independent sessions.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        """Initialise the cleaning environment."""
        self._state = State(episode_id=str(uuid4()), step_count=0)
        # Working batch being cleaned; populated by reset().
        self._df: pd.DataFrame = pd.DataFrame()
        # Active domain key and its configuration (set in reset()).
        self._domain_name: str = ""
        self._domain_config = None
        # Hard cap on cleaning operations per episode.
        self._max_steps: int = 15

    # ------------------------------------------------------------------
    # reset
    # ------------------------------------------------------------------
    def reset(self) -> CleaningObservation:
        """Pick a random domain, load a 50-row batch, inject corruption."""
        self._state = State(episode_id=str(uuid4()), step_count=0)

        self._domain_name = random.choice(list(DOMAINS.keys()))
        self._domain_config = DOMAINS[self._domain_name]

        # Load raw data and sample 50 rows, then corrupt ~15% of it so there
        # is something for the agent to fix.
        raw_df = load_domain_data(self._domain_name, sample_size=50)
        self._df = inject_corruption(raw_df, self._domain_config, rate=0.15)

        dq = compute_dq_score(self._df, self._domain_config)
        dq_report = (
            f"completeness={dq['completeness']:.4f} "
            f"consistency={dq['consistency']:.4f} "
            f"uniqueness={dq['uniqueness']:.4f}"
        )

        return CleaningObservation(
            domain=self._domain_name,
            data_preview=format_preview(self._df),
            dq_report=dq_report,
            dq_score=dq["overall"],
            columns_info=format_columns_info(self._df, self._domain_config),
            step_number=0,
            max_steps=self._max_steps,
            done=False,
            reward=0.0,
        )

    # ------------------------------------------------------------------
    # step
    # ------------------------------------------------------------------
    def step(self, action: CleaningAction) -> CleaningObservation:  # type: ignore[override]
        """Apply a single cleaning operation and return the updated state."""
        self._state.step_count += 1
        step = self._state.step_count

        op = action.operation
        col = action.column
        value = action.value

        # ---- apply operation ----
        try:
            if op == "fill_null":
                self._apply_fill_null(col, value)
            elif op == "fix_type":
                self._apply_fix_type(col)
            elif op == "remove_duplicate":
                self._apply_remove_duplicate()
            elif op == "standardize":
                self._apply_standardize(col, value)
            elif op == "trim":
                self._apply_trim(col)
            elif op == "correct_typo":
                self._apply_correct_typo(col, value, action.params)
            # unknown ops are silently ignored (no crash)
        except Exception:
            pass  # invalid column, etc. -> no-op

        # ---- compute DQ and reward ----
        dq = compute_dq_score(self._df, self._domain_config)
        reward = cleaning_reward(dq["overall"])
        # Episode ends on near-perfect data or when the step budget runs out.
        done = dq["overall"] > 0.95 or step >= self._max_steps

        dq_report = (
            f"completeness={dq['completeness']:.4f} "
            f"consistency={dq['consistency']:.4f} "
            f"uniqueness={dq['uniqueness']:.4f}"
        )

        return CleaningObservation(
            domain=self._domain_name,
            data_preview=format_preview(self._df),
            dq_report=dq_report,
            dq_score=dq["overall"],
            columns_info=format_columns_info(self._df, self._domain_config),
            step_number=step,
            max_steps=self._max_steps,
            done=done,
            reward=reward,
            metadata={
                "operation": op,
                "column": col,
                "step": step,
            },
        )

    # ------------------------------------------------------------------
    # state property
    # ------------------------------------------------------------------
    @property
    def state(self) -> State:
        """Return current environment state."""
        return self._state

    # ------------------------------------------------------------------
    # operation helpers
    # ------------------------------------------------------------------
    def _apply_fill_null(self, col: str, value: str | None) -> None:
        # Fill missing cells: "median"/"mode" compute an aggregate over the
        # column; any other value is used verbatim (None -> empty string).
        if col not in self._df.columns:
            return
        if value == "median":
            numeric = pd.to_numeric(self._df[col], errors="coerce")
            fill_val = numeric.median()
        elif value == "mode":
            mode_vals = self._df[col].mode()
            fill_val = mode_vals.iloc[0] if len(mode_vals) > 0 else ""
        else:
            fill_val = value if value is not None else ""
        self._df[col] = self._df[col].fillna(fill_val)

    def _apply_fix_type(self, col: str) -> None:
        # Cast the column to numeric; non-parsable entries become NaN.
        if col not in self._df.columns:
            return
        self._df[col] = pd.to_numeric(self._df[col], errors="coerce")

    def _apply_remove_duplicate(self) -> None:
        # Duplicates are detected on (at most) the first 5 expected columns
        # that are actually present in the batch.
        available = [c for c in self._domain_config.columns if c in self._df.columns]
        self._df = self._df.drop_duplicates(
            subset=available[:5], keep="first"
        ).reset_index(drop=True)

    def _apply_standardize(self, col: str, value: str | None) -> None:
        # NOTE(review): astype(str) turns NaN into the literal string "nan",
        # which would then read as non-null — confirm this is intended
        # against compute_dq_score's completeness calculation.
        if col not in self._df.columns:
            return
        self._df[col] = self._df[col].astype(str).str.strip()
        if value == "lower":
            self._df[col] = self._df[col].str.lower()
        elif value == "title":
            self._df[col] = self._df[col].str.title()

    def _apply_trim(self, col: str) -> None:
        # Strip leading/trailing whitespace. NOTE(review): same astype(str)
        # NaN -> "nan" caveat as _apply_standardize — confirm intended.
        if col not in self._df.columns:
            return
        self._df[col] = self._df[col].astype(str).str.strip()

    def _apply_correct_typo(self, col: str, value: str | None,
                            params: dict) -> None:
        # Replace params["wrong"] with `value`; without an explicit wrong
        # value, the column's least common value is assumed to be the typo.
        if col not in self._df.columns or value is None:
            return
        wrong = params.get("wrong")
        if wrong:
            self._df[col] = self._df[col].replace(wrong, value)
        else:
            # If no specific wrong value given, try to replace the most
            # uncommon value with the provided correct value.
            counts = self._df[col].value_counts()
            if len(counts) > 1:
                least_common = counts.index[-1]
                self._df[col] = self._df[col].replace(least_common, value)
server/requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ openenv[core]>=0.2.0
2
+ fastapi>=0.115.0
3
+ uvicorn>=0.24.0
4
+ pandas>=2.0
5
+ numpy>=1.24
uv.lock ADDED
The diff for this file is too large to render. See raw diff