Spaces:
Running
Running
Upload folder using huggingface_hub
Browse files- Dockerfile +51 -0
- README.md +29 -5
- __init__.py +1 -0
- client.py +62 -0
- environments/__init__.py +1 -0
- environments/shared/__init__.py +11 -0
- environments/shared/domains.py +110 -0
- environments/shared/enrichment_sources.py +363 -0
- environments/shared/enterprise_data.py +242 -0
- environments/shared/personas.py +105 -0
- environments/shared/reward_utils.py +55 -0
- models.py +37 -0
- openenv.yaml +6 -0
- pyproject.toml +28 -0
- server/__init__.py +5 -0
- server/app.py +48 -0
- server/cleaning_environment.py +213 -0
- server/requirements.txt +5 -0
- uv.lock +0 -0
Dockerfile
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Multi-stage build: resolve dependencies with uv in a builder stage, then
# ship only the resulting virtualenv plus the source tree in the runtime image.
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
FROM ${BASE_IMAGE} AS builder

WORKDIR /app

RUN apt-get update && \
    apt-get install -y --no-install-recommends git && \
    rm -rf /var/lib/apt/lists/*

ARG BUILD_MODE=in-repo
ARG ENV_NAME=datasage_cleaning

COPY . /app/env

WORKDIR /app/env

# Install uv only if the base image does not already ship it.
RUN if ! command -v uv >/dev/null 2>&1; then \
        curl -LsSf https://astral.sh/uv/install.sh | sh && \
        mv /root/.local/bin/uv /usr/local/bin/uv && \
        mv /root/.local/bin/uvx /usr/local/bin/uvx; \
    fi

# Layer 1: third-party dependencies only (cacheable across source changes).
RUN --mount=type=cache,target=/root/.cache/uv \
    if [ -f uv.lock ]; then \
        uv sync --frozen --no-install-project --no-editable; \
    else \
        uv sync --no-install-project --no-editable; \
    fi

# Layer 2: install the project itself on top of the dependency layer.
RUN --mount=type=cache,target=/root/.cache/uv \
    if [ -f uv.lock ]; then \
        uv sync --frozen --no-editable; \
    else \
        uv sync --no-editable; \
    fi

FROM ${BASE_IMAGE}

WORKDIR /app

COPY --from=builder /app/env/.venv /app/.venv
COPY --from=builder /app/env /app/env

ENV PATH="/app/.venv/bin:$PATH"
ENV PYTHONPATH="/app/env:$PYTHONPATH"

HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

ENV ENABLE_WEB_INTERFACE=true
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
README.md
CHANGED
|
@@ -1,10 +1,34 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: DataSage Cleaning Environment
|
| 3 |
+
emoji: 🧹
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: blue
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
app_port: 8000
|
| 9 |
+
base_path: /web
|
| 10 |
+
tags:
|
| 11 |
+
- openenv
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# DataSage Cleaning Environment
|
| 15 |
+
|
| 16 |
+
An RL environment for training agents to clean enterprise data across four domains (HR, Sales, Project Management, IT Operations).
|
| 17 |
+
|
| 18 |
+
The agent receives a corrupted 50-row data batch and must apply cleaning operations (fill nulls, fix types, remove duplicates, standardize values, trim whitespace, correct typos) to maximise a composite data quality score. Episodes end when DQ > 0.95 or after 15 steps.
|
| 19 |
+
|
| 20 |
+
## Quick Start
|
| 21 |
+
|
| 22 |
+
```python
|
| 23 |
+
from environments.cleaning.models import CleaningAction
|
| 24 |
+
from environments.cleaning.client import CleaningEnv
|
| 25 |
+
|
| 26 |
+
with CleaningEnv(base_url="http://localhost:8000") as env:
|
| 27 |
+
result = env.reset()
|
| 28 |
+
print(f"Domain: {result.observation.domain}, DQ: {result.observation.dq_score}")
|
| 29 |
+
|
| 30 |
+
result = env.step(CleaningAction(
|
| 31 |
+
operation="fill_null", column="Age", value="median"
|
| 32 |
+
))
|
| 33 |
+
print(f"DQ after step: {result.observation.dq_score}")
|
| 34 |
+
```
|
__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""DataSage Cleaning Environment."""
|
client.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DataSage Cleaning Environment Client."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict
|
| 4 |
+
|
| 5 |
+
from openenv.core import EnvClient
|
| 6 |
+
from openenv.core.client_types import StepResult
|
| 7 |
+
from openenv.core.env_server.types import State
|
| 8 |
+
|
| 9 |
+
from .models import CleaningAction, CleaningObservation
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class CleaningEnv(EnvClient[CleaningAction, CleaningObservation, State]):
    """
    Client for the DataSage Cleaning Environment.

    Example:
        >>> with CleaningEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset()
        ...     print(result.observation.dq_score)
        ...     result = client.step(CleaningAction(
        ...         operation="fill_null", column="Age", value="median"
        ...     ))
        ...     print(result.observation.dq_score)
    """

    def _step_payload(self, action: CleaningAction) -> Dict:
        """Serialize a CleaningAction into the JSON body sent to the server."""
        return {
            "operation": action.operation,
            "column": action.column,
            "value": action.value,
            "params": action.params,
        }

    def _parse_result(self, payload: Dict) -> StepResult[CleaningObservation]:
        """Parse a server response into StepResult[CleaningObservation].

        Missing fields fall back to neutral defaults so a partial payload
        still yields a well-formed observation.
        """
        obs_data = payload.get("observation", {})
        reward = payload.get("reward")
        done = payload.get("done", False)

        # reward/done are mirrored onto the observation for convenience.
        observation = CleaningObservation(
            domain=obs_data.get("domain", ""),
            data_preview=obs_data.get("data_preview", ""),
            dq_report=obs_data.get("dq_report", ""),
            dq_score=obs_data.get("dq_score", 0.0),
            columns_info=obs_data.get("columns_info", ""),
            step_number=obs_data.get("step_number", 0),
            max_steps=obs_data.get("max_steps", 15),
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata", {}),
        )

        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict) -> State:
        """Parse a server response into a State object."""
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )
|
environments/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# DataSage environments package
|
environments/shared/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared utilities for DataSage multi-domain enterprise environments."""
|
| 2 |
+
|
| 3 |
+
from .domains import DOMAINS, DomainConfig
|
| 4 |
+
from .personas import PERSONAS, Persona, score_persona_alignment
|
| 5 |
+
from .reward_utils import cleaning_reward, enrichment_reward, answering_reward
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"DOMAINS", "DomainConfig",
|
| 9 |
+
"PERSONAS", "Persona", "score_persona_alignment",
|
| 10 |
+
"cleaning_reward", "enrichment_reward", "answering_reward",
|
| 11 |
+
]
|
environments/shared/domains.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Domain registry for the 4 enterprise data domains."""
|
| 2 |
+
|
| 3 |
+
from pydantic import BaseModel
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class DomainConfig(BaseModel):
|
| 7 |
+
name: str
|
| 8 |
+
display_name: str
|
| 9 |
+
dataset_key: str
|
| 10 |
+
columns: list[str]
|
| 11 |
+
numeric_columns: list[str]
|
| 12 |
+
categorical_columns: list[str]
|
| 13 |
+
possible_enrichments: list[str]
|
| 14 |
+
example_questions: list[str]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
DOMAINS = {
|
| 18 |
+
"hr": DomainConfig(
|
| 19 |
+
name="hr",
|
| 20 |
+
display_name="HR & People",
|
| 21 |
+
dataset_key="hr",
|
| 22 |
+
columns=[
|
| 23 |
+
"EmployeeID", "Age", "Department", "JobRole", "MonthlyIncome",
|
| 24 |
+
"YearsAtCompany", "Attrition", "JobSatisfaction", "OverTime",
|
| 25 |
+
"DistanceFromHome", "Education", "PerformanceRating",
|
| 26 |
+
],
|
| 27 |
+
numeric_columns=["Age", "MonthlyIncome", "YearsAtCompany", "DistanceFromHome"],
|
| 28 |
+
categorical_columns=["Department", "JobRole", "Attrition", "OverTime"],
|
| 29 |
+
possible_enrichments=[
|
| 30 |
+
"salary_band", "tenure_risk", "satisfaction_index",
|
| 31 |
+
"industry_benchmark", "flight_risk_score",
|
| 32 |
+
],
|
| 33 |
+
example_questions=[
|
| 34 |
+
"Which departments have the highest attrition rates?",
|
| 35 |
+
"What factors correlate most with employee turnover?",
|
| 36 |
+
"How does overtime affect job satisfaction?",
|
| 37 |
+
"What is the salary distribution across job roles?",
|
| 38 |
+
"Which employees are at highest flight risk?",
|
| 39 |
+
],
|
| 40 |
+
),
|
| 41 |
+
"sales": DomainConfig(
|
| 42 |
+
name="sales",
|
| 43 |
+
display_name="Sales & Revenue",
|
| 44 |
+
dataset_key="sales",
|
| 45 |
+
columns=[
|
| 46 |
+
"DealID", "AccountName", "Stage", "Amount", "CloseDate",
|
| 47 |
+
"Rep", "Product", "Region", "LeadSource", "DaysInStage",
|
| 48 |
+
"Probability", "ForecastCategory",
|
| 49 |
+
],
|
| 50 |
+
numeric_columns=["Amount", "DaysInStage", "Probability"],
|
| 51 |
+
categorical_columns=["Stage", "Region", "Product", "ForecastCategory"],
|
| 52 |
+
possible_enrichments=[
|
| 53 |
+
"deal_size_category", "velocity_score", "win_probability_model",
|
| 54 |
+
"industry_code", "competitive_risk",
|
| 55 |
+
],
|
| 56 |
+
example_questions=[
|
| 57 |
+
"What's our pipeline health for this quarter?",
|
| 58 |
+
"Which deals are at risk of slipping?",
|
| 59 |
+
"What's the average deal velocity by region?",
|
| 60 |
+
"Which reps are below quota?",
|
| 61 |
+
"What's the conversion rate by lead source?",
|
| 62 |
+
],
|
| 63 |
+
),
|
| 64 |
+
"pm": DomainConfig(
|
| 65 |
+
name="pm",
|
| 66 |
+
display_name="Project Management",
|
| 67 |
+
dataset_key="pm",
|
| 68 |
+
columns=[
|
| 69 |
+
"TaskID", "ProjectName", "Assignee", "Status", "Priority",
|
| 70 |
+
"DueDate", "EstimatedHours", "ActualHours", "Dependencies",
|
| 71 |
+
"Milestone", "RiskFlag", "CompletionPct",
|
| 72 |
+
],
|
| 73 |
+
numeric_columns=["EstimatedHours", "ActualHours", "CompletionPct"],
|
| 74 |
+
categorical_columns=["Status", "Priority", "RiskFlag"],
|
| 75 |
+
possible_enrichments=[
|
| 76 |
+
"schedule_risk_score", "resource_utilization",
|
| 77 |
+
"dependency_chain_depth", "burndown_rate", "delay_probability",
|
| 78 |
+
],
|
| 79 |
+
example_questions=[
|
| 80 |
+
"Which projects are at risk of missing deadlines?",
|
| 81 |
+
"How is resource utilization across teams?",
|
| 82 |
+
"What's the burndown rate for the current sprint?",
|
| 83 |
+
"Which tasks are blocking the most downstream work?",
|
| 84 |
+
"What's our on-time delivery rate?",
|
| 85 |
+
],
|
| 86 |
+
),
|
| 87 |
+
"it_ops": DomainConfig(
|
| 88 |
+
name="it_ops",
|
| 89 |
+
display_name="IT Operations",
|
| 90 |
+
dataset_key="it_ops",
|
| 91 |
+
columns=[
|
| 92 |
+
"TicketID", "Category", "Priority", "Status", "Assignee",
|
| 93 |
+
"CreatedDate", "ResolvedDate", "SLATarget", "EscalationLevel",
|
| 94 |
+
"AffectedSystem", "ResolutionType", "CustomerImpact",
|
| 95 |
+
],
|
| 96 |
+
numeric_columns=["SLATarget", "EscalationLevel"],
|
| 97 |
+
categorical_columns=["Category", "Priority", "Status", "ResolutionType"],
|
| 98 |
+
possible_enrichments=[
|
| 99 |
+
"sla_compliance_flag", "mttr_band", "escalation_path",
|
| 100 |
+
"incident_severity_score", "recurring_pattern_flag",
|
| 101 |
+
],
|
| 102 |
+
example_questions=[
|
| 103 |
+
"What's our SLA compliance rate this month?",
|
| 104 |
+
"Which systems have the most incidents?",
|
| 105 |
+
"What's the mean time to resolution trend?",
|
| 106 |
+
"How many tickets are breaching SLA?",
|
| 107 |
+
"What are the most common root causes?",
|
| 108 |
+
],
|
| 109 |
+
),
|
| 110 |
+
}
|
environments/shared/enrichment_sources.py
ADDED
|
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Static enrichment lookup tables per domain (no API calls)."""
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
# Enrichment registry: domain -> source -> lookup function or static data
|
| 6 |
+
ENRICHMENT_REGISTRY: dict[str, dict[str, dict]] = {
|
| 7 |
+
"hr": {
|
| 8 |
+
"salary_band": {
|
| 9 |
+
"description": "BLS salary band classification based on monthly income",
|
| 10 |
+
"type": "derived",
|
| 11 |
+
"logic": "classify_salary_band",
|
| 12 |
+
},
|
| 13 |
+
"tenure_risk": {
|
| 14 |
+
"description": "Tenure-based flight risk score",
|
| 15 |
+
"type": "derived",
|
| 16 |
+
"logic": "compute_tenure_risk",
|
| 17 |
+
},
|
| 18 |
+
"satisfaction_index": {
|
| 19 |
+
"description": "Composite satisfaction index from multiple factors",
|
| 20 |
+
"type": "derived",
|
| 21 |
+
"logic": "compute_satisfaction_index",
|
| 22 |
+
},
|
| 23 |
+
"industry_benchmark": {
|
| 24 |
+
"description": "Industry benchmark salary percentile",
|
| 25 |
+
"type": "lookup",
|
| 26 |
+
"data": {
|
| 27 |
+
"Sales Executive": 65000, "Research Scientist": 72000,
|
| 28 |
+
"Manager": 85000, "Lab Technician": 45000,
|
| 29 |
+
"Manufacturing Director": 95000, "Healthcare Representative": 55000,
|
| 30 |
+
"Human Resources": 60000,
|
| 31 |
+
},
|
| 32 |
+
},
|
| 33 |
+
"flight_risk_score": {
|
| 34 |
+
"description": "Combined flight risk from satisfaction, tenure, overtime",
|
| 35 |
+
"type": "derived",
|
| 36 |
+
"logic": "compute_flight_risk",
|
| 37 |
+
},
|
| 38 |
+
},
|
| 39 |
+
"sales": {
|
| 40 |
+
"deal_size_category": {
|
| 41 |
+
"description": "Categorize deal by amount: Small/Medium/Large/Enterprise",
|
| 42 |
+
"type": "derived",
|
| 43 |
+
"logic": "classify_deal_size",
|
| 44 |
+
},
|
| 45 |
+
"velocity_score": {
|
| 46 |
+
"description": "Deal velocity based on days in stage vs benchmark",
|
| 47 |
+
"type": "derived",
|
| 48 |
+
"logic": "compute_velocity_score",
|
| 49 |
+
},
|
| 50 |
+
"win_probability_model": {
|
| 51 |
+
"description": "Heuristic win probability based on stage + days",
|
| 52 |
+
"type": "derived",
|
| 53 |
+
"logic": "compute_win_probability",
|
| 54 |
+
},
|
| 55 |
+
"industry_code": {
|
| 56 |
+
"description": "Industry classification code from account name patterns",
|
| 57 |
+
"type": "lookup",
|
| 58 |
+
"data": {
|
| 59 |
+
"Tech": "SIC-7372", "Healthcare": "SIC-8011",
|
| 60 |
+
"Finance": "SIC-6020", "Retail": "SIC-5311",
|
| 61 |
+
"Manufacturing": "SIC-3559", "default": "SIC-9999",
|
| 62 |
+
},
|
| 63 |
+
},
|
| 64 |
+
"competitive_risk": {
|
| 65 |
+
"description": "Competitive risk score based on deal stage and velocity",
|
| 66 |
+
"type": "derived",
|
| 67 |
+
"logic": "compute_competitive_risk",
|
| 68 |
+
},
|
| 69 |
+
},
|
| 70 |
+
"pm": {
|
| 71 |
+
"schedule_risk_score": {
|
| 72 |
+
"description": "Risk of schedule slippage based on progress vs due date",
|
| 73 |
+
"type": "derived",
|
| 74 |
+
"logic": "compute_schedule_risk",
|
| 75 |
+
},
|
| 76 |
+
"resource_utilization": {
|
| 77 |
+
"description": "Resource utilization ratio: actual/estimated hours",
|
| 78 |
+
"type": "derived",
|
| 79 |
+
"logic": "compute_resource_utilization",
|
| 80 |
+
},
|
| 81 |
+
"dependency_chain_depth": {
|
| 82 |
+
"description": "Depth of dependency chain for task",
|
| 83 |
+
"type": "derived",
|
| 84 |
+
"logic": "compute_dependency_depth",
|
| 85 |
+
},
|
| 86 |
+
"burndown_rate": {
|
| 87 |
+
"description": "Task completion rate relative to plan",
|
| 88 |
+
"type": "derived",
|
| 89 |
+
"logic": "compute_burndown_rate",
|
| 90 |
+
},
|
| 91 |
+
"delay_probability": {
|
| 92 |
+
"description": "Probability of delay based on current trajectory",
|
| 93 |
+
"type": "derived",
|
| 94 |
+
"logic": "compute_delay_probability",
|
| 95 |
+
},
|
| 96 |
+
},
|
| 97 |
+
"it_ops": {
|
| 98 |
+
"sla_compliance_flag": {
|
| 99 |
+
"description": "Whether ticket meets SLA target",
|
| 100 |
+
"type": "derived",
|
| 101 |
+
"logic": "compute_sla_compliance",
|
| 102 |
+
},
|
| 103 |
+
"mttr_band": {
|
| 104 |
+
"description": "Mean time to resolution band: Fast/Normal/Slow/Critical",
|
| 105 |
+
"type": "derived",
|
| 106 |
+
"logic": "classify_mttr",
|
| 107 |
+
},
|
| 108 |
+
"escalation_path": {
|
| 109 |
+
"description": "Recommended escalation path based on category and priority",
|
| 110 |
+
"type": "lookup",
|
| 111 |
+
"data": {
|
| 112 |
+
"P1-Critical": "L3 -> Manager -> VP",
|
| 113 |
+
"P2-High": "L2 -> L3 -> Manager",
|
| 114 |
+
"P3-Medium": "L1 -> L2",
|
| 115 |
+
"P4-Low": "L1",
|
| 116 |
+
},
|
| 117 |
+
},
|
| 118 |
+
"incident_severity_score": {
|
| 119 |
+
"description": "Computed severity score from priority and customer impact",
|
| 120 |
+
"type": "derived",
|
| 121 |
+
"logic": "compute_severity_score",
|
| 122 |
+
},
|
| 123 |
+
"recurring_pattern_flag": {
|
| 124 |
+
"description": "Flag indicating likely recurring issue",
|
| 125 |
+
"type": "derived",
|
| 126 |
+
"logic": "detect_recurring_pattern",
|
| 127 |
+
},
|
| 128 |
+
},
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def lookup(domain: str, source: str, row: dict) -> object:
    """Resolve an enrichment value for a single data row.

    For "lookup" sources, each cell of the row is tried (stringified)
    against the static table and the first hit wins, falling back to the
    table's "default" entry if present.  For "derived" sources, the
    registered compute function is invoked on the row.  Unknown
    domain/source combinations yield None.
    """
    source_config = ENRICHMENT_REGISTRY.get(domain, {}).get(source)
    if not source_config:
        return None

    if source_config["type"] == "lookup":
        table = source_config["data"]
        # Column order is arbitrary: the first cell whose string form
        # appears in the table determines the result.
        for column in row:
            cell = str(row.get(column, ""))
            if cell in table:
                return table[cell]
        return table.get("default")

    # "derived" source: dispatch on the registered logic name.
    compute_fn = _COMPUTE_FUNCTIONS.get(source_config["logic"])
    return compute_fn(row) if compute_fn else None
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# --- Computation functions ---
|
| 158 |
+
|
| 159 |
+
def _classify_salary_band(row: dict) -> str:
|
| 160 |
+
try:
|
| 161 |
+
income = float(row.get("MonthlyIncome", 0))
|
| 162 |
+
except (ValueError, TypeError):
|
| 163 |
+
return "Unknown"
|
| 164 |
+
if income < 3000:
|
| 165 |
+
return "Entry"
|
| 166 |
+
elif income < 6000:
|
| 167 |
+
return "Mid"
|
| 168 |
+
elif income < 10000:
|
| 169 |
+
return "Senior"
|
| 170 |
+
return "Executive"
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def _compute_tenure_risk(row: dict) -> float:
|
| 174 |
+
try:
|
| 175 |
+
years = float(row.get("YearsAtCompany", 0))
|
| 176 |
+
except (ValueError, TypeError):
|
| 177 |
+
return 0.5
|
| 178 |
+
# Short tenure = higher risk, very long = moderate risk
|
| 179 |
+
if years < 2:
|
| 180 |
+
return 0.8
|
| 181 |
+
elif years < 5:
|
| 182 |
+
return 0.4
|
| 183 |
+
elif years < 10:
|
| 184 |
+
return 0.2
|
| 185 |
+
return 0.3
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def _compute_satisfaction_index(row: dict) -> float:
|
| 189 |
+
try:
|
| 190 |
+
satisfaction = float(row.get("JobSatisfaction", 3))
|
| 191 |
+
except (ValueError, TypeError):
|
| 192 |
+
satisfaction = 3
|
| 193 |
+
return round(satisfaction / 4.0, 2)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def _compute_flight_risk(row: dict) -> float:
|
| 197 |
+
tenure_risk = _compute_tenure_risk(row)
|
| 198 |
+
sat_index = _compute_satisfaction_index(row)
|
| 199 |
+
overtime = 0.3 if str(row.get("OverTime", "No")).lower() == "yes" else 0.0
|
| 200 |
+
return round(0.4 * tenure_risk + 0.4 * (1 - sat_index) + 0.2 * overtime, 2)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def _classify_deal_size(row: dict) -> str:
|
| 204 |
+
try:
|
| 205 |
+
amount = float(row.get("Amount", 0))
|
| 206 |
+
except (ValueError, TypeError):
|
| 207 |
+
return "Unknown"
|
| 208 |
+
if amount < 5000:
|
| 209 |
+
return "Small"
|
| 210 |
+
elif amount < 25000:
|
| 211 |
+
return "Medium"
|
| 212 |
+
elif amount < 100000:
|
| 213 |
+
return "Large"
|
| 214 |
+
return "Enterprise"
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def _compute_velocity_score(row: dict) -> float:
|
| 218 |
+
try:
|
| 219 |
+
days = float(row.get("DaysInStage", 0))
|
| 220 |
+
except (ValueError, TypeError):
|
| 221 |
+
return 0.5
|
| 222 |
+
# Benchmark: 30 days per stage
|
| 223 |
+
if days < 15:
|
| 224 |
+
return 1.0
|
| 225 |
+
elif days < 30:
|
| 226 |
+
return 0.7
|
| 227 |
+
elif days < 60:
|
| 228 |
+
return 0.4
|
| 229 |
+
return 0.1
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def _compute_win_probability(row: dict) -> float:
|
| 233 |
+
stage_probs = {
|
| 234 |
+
"Prospecting": 0.10, "Qualification": 0.25, "Proposal": 0.50,
|
| 235 |
+
"Negotiation": 0.75, "Won": 1.0, "Lost": 0.0,
|
| 236 |
+
}
|
| 237 |
+
stage = str(row.get("Stage", ""))
|
| 238 |
+
base_prob = stage_probs.get(stage, 0.3)
|
| 239 |
+
velocity = _compute_velocity_score(row)
|
| 240 |
+
return round(0.7 * base_prob + 0.3 * velocity, 2)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def _compute_competitive_risk(row: dict) -> float:
|
| 244 |
+
velocity = _compute_velocity_score(row)
|
| 245 |
+
return round(1.0 - velocity, 2)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def _compute_schedule_risk(row: dict) -> float:
|
| 249 |
+
try:
|
| 250 |
+
pct = float(row.get("CompletionPct", 0))
|
| 251 |
+
except (ValueError, TypeError):
|
| 252 |
+
pct = 0
|
| 253 |
+
# Simple: lower completion = higher risk
|
| 254 |
+
return round(1.0 - (pct / 100.0), 2)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def _compute_resource_utilization(row: dict) -> float:
|
| 258 |
+
try:
|
| 259 |
+
estimated = float(row.get("EstimatedHours", 1))
|
| 260 |
+
actual = float(row.get("ActualHours", 0))
|
| 261 |
+
except (ValueError, TypeError):
|
| 262 |
+
return 0.0
|
| 263 |
+
if estimated == 0:
|
| 264 |
+
return 0.0
|
| 265 |
+
return round(actual / estimated, 2)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def _compute_dependency_depth(row: dict) -> int:
|
| 269 |
+
deps = row.get("Dependencies", "")
|
| 270 |
+
if not deps or str(deps) in ("nan", "None", ""):
|
| 271 |
+
return 0
|
| 272 |
+
return len(str(deps).split(","))
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def _compute_burndown_rate(row: dict) -> float:
|
| 276 |
+
try:
|
| 277 |
+
pct = float(row.get("CompletionPct", 0))
|
| 278 |
+
estimated = float(row.get("EstimatedHours", 1))
|
| 279 |
+
actual = float(row.get("ActualHours", 0))
|
| 280 |
+
except (ValueError, TypeError):
|
| 281 |
+
return 0.5
|
| 282 |
+
if actual == 0:
|
| 283 |
+
return 0.0
|
| 284 |
+
expected_rate = pct / 100.0
|
| 285 |
+
time_rate = actual / max(estimated, 1)
|
| 286 |
+
return round(expected_rate / max(time_rate, 0.01), 2)
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def _compute_delay_probability(row: dict) -> float:
|
| 290 |
+
schedule_risk = _compute_schedule_risk(row)
|
| 291 |
+
burndown = _compute_burndown_rate(row)
|
| 292 |
+
return round(schedule_risk * (1.0 / max(burndown, 0.1)), 2)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def _compute_sla_compliance(row: dict) -> str:
|
| 296 |
+
try:
|
| 297 |
+
sla = float(row.get("SLATarget", 24))
|
| 298 |
+
escalation = float(row.get("EscalationLevel", 0))
|
| 299 |
+
except (ValueError, TypeError):
|
| 300 |
+
return "Unknown"
|
| 301 |
+
if escalation > 2:
|
| 302 |
+
return "Breached"
|
| 303 |
+
return "Compliant"
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def _classify_mttr(row: dict) -> str:
|
| 307 |
+
try:
|
| 308 |
+
escalation = float(row.get("EscalationLevel", 0))
|
| 309 |
+
except (ValueError, TypeError):
|
| 310 |
+
return "Normal"
|
| 311 |
+
if escalation == 0:
|
| 312 |
+
return "Fast"
|
| 313 |
+
elif escalation <= 1:
|
| 314 |
+
return "Normal"
|
| 315 |
+
elif escalation <= 3:
|
| 316 |
+
return "Slow"
|
| 317 |
+
return "Critical"
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def _compute_severity_score(row: dict) -> float:
|
| 321 |
+
priority_scores = {"P1-Critical": 1.0, "P2-High": 0.7, "P3-Medium": 0.4, "P4-Low": 0.1}
|
| 322 |
+
priority = str(row.get("Priority", "P3-Medium"))
|
| 323 |
+
return priority_scores.get(priority, 0.4)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def _detect_recurring_pattern(row: dict) -> bool:
|
| 327 |
+
category = str(row.get("Category", ""))
|
| 328 |
+
# Simple heuristic: certain categories tend to recur
|
| 329 |
+
recurring_cats = {"Network", "Email", "Access"}
|
| 330 |
+
return category in recurring_cats
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
# Dispatch table mapping the "logic" names used in ENRICHMENT_REGISTRY to
# their implementations.  Every key equals the function's name without its
# leading underscore (e.g. "_classify_mttr" -> "classify_mttr").
_COMPUTE_FUNCTIONS = {
    fn.__name__[1:]: fn
    for fn in (
        _classify_salary_band,
        _compute_tenure_risk,
        _compute_satisfaction_index,
        _compute_flight_risk,
        _classify_deal_size,
        _compute_velocity_score,
        _compute_win_probability,
        _compute_competitive_risk,
        _compute_schedule_risk,
        _compute_resource_utilization,
        _compute_dependency_depth,
        _compute_burndown_rate,
        _compute_delay_probability,
        _compute_sla_compliance,
        _classify_mttr,
        _compute_severity_score,
        _detect_recurring_pattern,
    )
}
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def get_available_enrichments(domain: str) -> list[str]:
    """Return the enrichment source names registered for *domain* (empty list if unknown)."""
    return list(ENRICHMENT_REGISTRY.get(domain, {}))


def get_enrichment_description(domain: str, source: str) -> str:
    """Return the human-readable description of an enrichment source."""
    entry = ENRICHMENT_REGISTRY.get(domain, {}).get(source, {})
    return entry.get("description", "Unknown enrichment source")
|
environments/shared/enterprise_data.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Multi-domain dataset loading, corruption injection, and DQ scoring."""
|
| 2 |
+
|
| 3 |
+
import random
|
| 4 |
+
import string
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
|
| 10 |
+
from .domains import DOMAINS, DomainConfig
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def load_domain_data(domain: str, sample_size: Optional[int] = None) -> pd.DataFrame:
    """Load a domain's table from the HF dataset, with a synthetic fallback.

    Any failure while fetching (missing `datasets` package, network error,
    unknown config) silently switches to the deterministic synthetic
    generator.  When *sample_size* is given and the frame is larger, a
    fixed-seed sample of that many rows is returned instead.
    """
    try:
        from datasets import load_dataset

        df = load_dataset(
            "ricalanis/datasage-enterprise-raw", domain, split="train"
        ).to_pandas()
    except Exception:
        # Offline / dependency-free fallback.
        df = _generate_synthetic(domain)

    if sample_size and len(df) > sample_size:
        df = df.sample(n=sample_size, random_state=42).reset_index(drop=True)
    return df
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _generate_synthetic(domain: str, n: int = 200) -> pd.DataFrame:
    """Build a deterministic synthetic frame for *domain* (HF-dataset fallback).

    Each configured column is filled according to its role: numeric columns
    get N(50, 20) draws, categorical columns get domain vocabulary values,
    and name-based heuristics cover IDs, dates, and person names; anything
    else gets an opaque placeholder string.
    """
    config = DOMAINS[domain]
    rng = np.random.default_rng(42)  # fixed seed keeps fallbacks reproducible
    data = {}

    for name in config.columns:
        if name in config.numeric_columns:
            data[name] = rng.normal(50, 20, n).round(2)
        elif name in config.categorical_columns:
            data[name] = rng.choice(_get_categories(domain, name), n).tolist()
        elif "ID" in name:
            # e.g. "EMP-0000", "EMP-0001", ... from the column prefix.
            data[name] = [f"{name[:3].upper()}-{i:04d}" for i in range(n)]
        elif "Date" in name:
            start = pd.Timestamp("2024-01-01")
            offsets = rng.integers(0, 365, n)
            data[name] = [
                (start + pd.Timedelta(days=int(off))).strftime("%Y-%m-%d")
                for off in offsets
            ]
        elif "Name" in name or "Assignee" in name or "Rep" in name:
            people = ["Alice", "Bob", "Carol", "Dan", "Eve", "Frank", "Grace", "Hank"]
            data[name] = rng.choice(people, n).tolist()
        else:
            data[name] = [f"{name}_val_{i}" for i in range(n)]

    return pd.DataFrame(data)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _get_categories(domain: str, col: str) -> list[str]:
|
| 55 |
+
"""Return realistic category values per domain and column."""
|
| 56 |
+
cat_map = {
|
| 57 |
+
"hr": {
|
| 58 |
+
"Department": ["Sales", "Research & Development", "Human Resources"],
|
| 59 |
+
"JobRole": ["Sales Executive", "Research Scientist", "Manager", "Lab Technician",
|
| 60 |
+
"Manufacturing Director", "Healthcare Representative"],
|
| 61 |
+
"Attrition": ["Yes", "No"],
|
| 62 |
+
"OverTime": ["Yes", "No"],
|
| 63 |
+
},
|
| 64 |
+
"sales": {
|
| 65 |
+
"Stage": ["Prospecting", "Qualification", "Proposal", "Negotiation", "Won", "Lost"],
|
| 66 |
+
"Region": ["East", "West", "Central", "North", "South"],
|
| 67 |
+
"Product": ["GTX Pro", "GTX Basic", "GTX Plus", "MG Special", "MG Advanced"],
|
| 68 |
+
"ForecastCategory": ["Pipeline", "Best Case", "Commit", "Closed"],
|
| 69 |
+
},
|
| 70 |
+
"pm": {
|
| 71 |
+
"Status": ["Not Started", "In Progress", "Completed", "On Hold", "Cancelled"],
|
| 72 |
+
"Priority": ["Critical", "High", "Medium", "Low"],
|
| 73 |
+
"RiskFlag": ["High", "Medium", "Low", "None"],
|
| 74 |
+
},
|
| 75 |
+
"it_ops": {
|
| 76 |
+
"Category": ["Hardware", "Software", "Network", "Access", "Email"],
|
| 77 |
+
"Priority": ["P1-Critical", "P2-High", "P3-Medium", "P4-Low"],
|
| 78 |
+
"Status": ["Open", "In Progress", "Resolved", "Closed", "Pending"],
|
| 79 |
+
"ResolutionType": ["Fix Applied", "Workaround", "No Fix", "Duplicate", "User Error"],
|
| 80 |
+
},
|
| 81 |
+
}
|
| 82 |
+
return cat_map.get(domain, {}).get(col, ["A", "B", "C"])
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def inject_corruption(df: pd.DataFrame, domain_config: DomainConfig,
                      rate: float = 0.15) -> pd.DataFrame:
    """Inject realistic data quality issues into a DataFrame.

    Applies five corruption passes (nulls, type mismatches, typos,
    duplicates, whitespace) with a fixed rng seed, so the same input always
    produces the same corrupted output. *rate* is the base corruption rate;
    each pass scales it down by its own factor.

    Args:
        df: Clean source frame; not modified (a copy is corrupted).
        domain_config: Supplies numeric/categorical column lists.
        rate: Base fraction of rows to corrupt per pass.

    Returns:
        A new DataFrame with injected issues; duplicates make it slightly
        longer than the input.
    """
    corrupted = df.copy()
    n_rows = len(corrupted)
    rng = np.random.default_rng(42)  # fixed seed -> deterministic corruption

    # 1. Inject nulls into numeric columns
    for col in domain_config.numeric_columns:
        if col in corrupted.columns:
            # Independent Bernoulli(rate) mask per row.
            null_mask = rng.random(n_rows) < rate
            corrupted.loc[null_mask, col] = np.nan

    # 2. Inject type mismatches (strings in numeric columns)
    for col in domain_config.numeric_columns:
        if col in corrupted.columns:
            n_bad = max(1, int(n_rows * rate * 0.3))
            bad_idx = rng.choice(n_rows, n_bad, replace=False)
            # Cast to object so string sentinels can live next to numbers.
            corrupted[col] = corrupted[col].astype(object)
            for idx in bad_idx:
                corrupted.iloc[idx, corrupted.columns.get_loc(col)] = rng.choice(
                    ["N/A", "unknown", "#REF!", "TBD", "-"]
                )

    # 3. Inject typos in categorical columns
    for col in domain_config.categorical_columns:
        if col in corrupted.columns:
            n_typos = max(1, int(n_rows * rate * 0.2))
            typo_idx = rng.choice(n_rows, n_typos, replace=False)
            for idx in typo_idx:
                val = str(corrupted.iloc[idx, corrupted.columns.get_loc(col)])
                corrupted.iloc[idx, corrupted.columns.get_loc(col)] = _add_typo(val, rng)

    # 4. Inject duplicates (appended at the end of the frame)
    n_dupes = max(1, int(n_rows * rate * 0.1))
    dupe_idx = rng.choice(n_rows, n_dupes, replace=False)
    dupes = corrupted.iloc[dupe_idx].copy()
    corrupted = pd.concat([corrupted, dupes], ignore_index=True)

    # 5. Inject whitespace issues (first two categorical columns only)
    for col in domain_config.categorical_columns[:2]:
        if col in corrupted.columns:
            n_ws = max(1, int(n_rows * rate * 0.2))
            # Note: sampled over len(corrupted), which now includes the
            # appended duplicate rows from pass 4.
            ws_idx = rng.choice(len(corrupted), n_ws, replace=False)
            for idx in ws_idx:
                val = str(corrupted.iloc[idx, corrupted.columns.get_loc(col)])
                corrupted.iloc[idx, corrupted.columns.get_loc(col)] = f" {val} "

    return corrupted
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def _add_typo(text: str, rng: np.random.Generator) -> str:
|
| 137 |
+
"""Add a realistic typo to a string."""
|
| 138 |
+
if len(text) < 2:
|
| 139 |
+
return text
|
| 140 |
+
typo_type = rng.choice(["swap", "delete", "insert", "case"])
|
| 141 |
+
idx = rng.integers(0, len(text))
|
| 142 |
+
if typo_type == "swap" and idx < len(text) - 1:
|
| 143 |
+
return text[:idx] + text[idx + 1] + text[idx] + text[idx + 2:]
|
| 144 |
+
elif typo_type == "delete":
|
| 145 |
+
return text[:idx] + text[idx + 1:]
|
| 146 |
+
elif typo_type == "insert":
|
| 147 |
+
char = rng.choice(list(string.ascii_lowercase))
|
| 148 |
+
return text[:idx] + char + text[idx:]
|
| 149 |
+
else:
|
| 150 |
+
return text[:idx] + text[idx].swapcase() + text[idx + 1:]
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def compute_dq_score(df: pd.DataFrame, domain_config: "DomainConfig") -> dict:
    """Compute data quality metrics: completeness, consistency, uniqueness, overall.

    Fix vs. the previous version: an empty frame (or a frame sharing no
    columns with the domain config) no longer raises ZeroDivisionError for
    completeness or yields NaN consistency from ``.mean()`` on an empty
    column -- each metric now degrades gracefully to 1.0.

    Args:
        df: Frame to score (any subset of the domain's columns).
        domain_config: Supplies the expected column list and the numeric
            column list (duck-typed: only ``.columns`` / ``.numeric_columns``
            are read).

    Returns:
        Dict with keys "completeness", "consistency", "uniqueness" and
        "overall" (weighted 0.40 / 0.35 / 0.25), each rounded to 4 places.
    """

    def _numeric_ok(val) -> bool:
        # Nulls count as valid for consistency; anything float()-able does too.
        if pd.isna(val):
            return True
        try:
            float(val)
            return True
        except (ValueError, TypeError):
            return False

    available_cols = [c for c in domain_config.columns if c in df.columns]
    n_rows = len(df)

    # Completeness: fraction of non-null cells over the expected columns.
    if available_cols and n_rows > 0:
        total_cells = n_rows * len(available_cols)
        null_ratio = df[available_cols].isnull().sum().sum() / total_cells
        completeness = 1.0 - null_ratio
    else:
        completeness = 1.0

    # Consistency: fraction of values in numeric columns that parse as numbers.
    consistency_scores = []
    if n_rows > 0:
        for col in domain_config.numeric_columns:
            if col in df.columns:
                consistency_scores.append(df[col].apply(_numeric_ok).mean())
    consistency = float(np.mean(consistency_scores)) if consistency_scores else 1.0

    # Uniqueness: 1 - duplicate ratio, keyed on the first five expected columns
    # (mirrors the duplicate-injection and removal logic elsewhere).
    if n_rows > 0 and available_cols:
        n_dupes = df.duplicated(subset=available_cols[:5], keep="first").sum()
        uniqueness = 1.0 - (n_dupes / n_rows)
    else:
        uniqueness = 1.0

    overall = 0.40 * completeness + 0.35 * consistency + 0.25 * uniqueness

    return {
        "completeness": round(completeness, 4),
        "consistency": round(consistency, 4),
        "uniqueness": round(uniqueness, 4),
        "overall": round(overall, 4),
    }
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def _is_numeric(val) -> bool:
|
| 190 |
+
"""Check if a value is numeric (or null, which is valid)."""
|
| 191 |
+
if pd.isna(val):
|
| 192 |
+
return True
|
| 193 |
+
try:
|
| 194 |
+
float(val)
|
| 195 |
+
return True
|
| 196 |
+
except (ValueError, TypeError):
|
| 197 |
+
return False
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def compute_dq_score_with_lfs(df: pd.DataFrame, domain: str,
                              lfs: list) -> float:
    """Aggregate Snorkel-style labeling-function votes into a 0-1 DQ score.

    Each LF receives a row and returns GOOD (1), BAD (0) or ABSTAIN (-1);
    an LF that raises is treated as abstaining. A row's score is the share
    of GOOD votes among non-abstaining LFs (0.5 when everything abstained).
    With no LFs or no rows, falls back to the heuristic scorer for a known
    domain, else 0.5.
    """
    ABSTAIN, GOOD = -1, 1

    if not lfs or len(df) == 0:
        config = DOMAINS.get(domain)
        return compute_dq_score(df, config)["overall"] if config else 0.5

    def _score_row(row) -> float:
        votes = []
        for lf in lfs:
            try:
                vote = lf(row)
            except Exception:
                continue  # a crashing LF counts as an abstention
            if vote != ABSTAIN:
                votes.append(vote)
        if not votes:
            return 0.5
        return sum(v == GOOD for v in votes) / len(votes)

    return float(np.mean([_score_row(row) for _, row in df.iterrows()]))
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def format_preview(df: pd.DataFrame, n: int = 5) -> str:
    """Render the first *n* rows as a plain-text table (no index column)."""
    head = df.head(n)
    return head.to_string(index=False, max_colwidth=30)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def format_columns_info(df: pd.DataFrame, domain_config: DomainConfig) -> str:
    """One line per column: name, dtype, null count, expected/extra tag."""
    expected_cols = set(domain_config.columns)
    rows = []
    for name in df.columns:
        tag = "expected" if name in expected_cols else "extra"
        dtype_name = str(df[name].dtype)
        rows.append(f"{name}: {dtype_name}, nulls={df[name].isnull().sum()} ({tag})")
    return "\n".join(rows)
|
environments/shared/personas.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generalized personas for domain-independent question answering."""
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Persona(BaseModel):
    """A stakeholder archetype used to target and score answer style."""

    name: str                 # human-readable label, e.g. "Executive"
    role: str                 # stable lookup key (see PERSONA_MAP)
    focus_areas: list[str]    # topics this persona cares about
    language_style: str       # style tag consumed by _check_formality
    keywords: list[str]       # terms whose presence raises alignment
    anti_keywords: list[str]  # terms whose presence is penalised
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# The three stakeholder archetypes answers can be targeted at. Keywords and
# anti-keywords feed score_persona_alignment; language_style feeds
# _check_formality.
PERSONAS = [
    Persona(
        name="Executive",
        role="executive",
        focus_areas=["costs", "ROI", "strategic risk", "portfolio trends", "year-over-year"],
        language_style="strategic-financial",
        keywords=["revenue", "cost", "ROI", "risk", "trend", "quarter",
                  "year-over-year", "impact", "budget", "margin", "growth"],
        anti_keywords=["I think", "maybe", "um", "idk"],
    ),
    Persona(
        name="Manager",
        role="manager",
        focus_areas=["team performance", "operational health", "process bottlenecks", "capacity"],
        language_style="operational-actionable",
        keywords=["team", "performance", "bottleneck", "capacity", "SLA",
                  "process", "action", "priority", "escalation", "delivery"],
        anti_keywords=["shareholder", "valuation", "IPO"],
    ),
    Persona(
        name="Individual Contributor",
        role="ic",
        focus_areas=["personal tasks", "deadlines", "what to do next", "simple explanations"],
        language_style="plain-personal",
        keywords=["my", "I should", "next step", "deadline", "help",
                  "understand", "priority", "task", "assigned"],
        anti_keywords=["KPI", "ROI", "portfolio", "strategic", "EBITDA"],
    ),
]

# Lookup by stable role key: "executive" | "manager" | "ic".
PERSONA_MAP = {p.role: p for p in PERSONAS}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_persona(role: str) -> Persona:
    """Get a persona by role name.

    Raises:
        KeyError: if *role* is not a known role key ("executive",
            "manager", or "ic").
    """
    return PERSONA_MAP[role]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def score_persona_alignment(answer: str, persona: "Persona") -> float:
    """Score how well an answer aligns with a persona's communication style.

    Returns a float in [0, 1] combining:
      - keyword density: saturates once ~30% of the persona keywords appear
      - formality: whether the text matches the persona's language style
      - anti-keyword penalty: 0.15 per off-persona term, capped at 0.5

    raw = 0.50 * keywords + 0.20 * formality + 0.30 - penalty, clamped to [0, 1].

    Fix vs. the previous version: removed the dead ``words``/``word_count``
    locals (computed via regex but never read).
    """
    answer_lower = answer.lower()

    # Keyword scoring: fraction of persona keywords found, scaled so that
    # hitting ~30% of them already yields a full keyword score.
    keyword_hits = sum(1 for kw in persona.keywords if kw.lower() in answer_lower)
    keyword_score = min(keyword_hits / max(len(persona.keywords) * 0.3, 1), 1.0)

    # Anti-keyword penalty (capped so it cannot erase the whole score).
    anti_hits = sum(1 for akw in persona.anti_keywords if akw.lower() in answer_lower)
    anti_penalty = min(anti_hits * 0.15, 0.5)

    # Formality check against the persona's expected register.
    formality_score = _check_formality(answer, persona.language_style)

    # Combine: 50% keywords, 20% formality, 30% base (minus anti-penalty).
    raw_score = 0.50 * keyword_score + 0.20 * formality_score + 0.30 - anti_penalty
    return round(max(0.0, min(1.0, raw_score)), 4)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _check_formality(text: str, style: str) -> float:
|
| 84 |
+
"""Check if text formality matches the expected language style."""
|
| 85 |
+
text_lower = text.lower()
|
| 86 |
+
|
| 87 |
+
if style == "strategic-financial":
|
| 88 |
+
indicators = ["percent", "%", "million", "billion", "quarter", "fiscal",
|
| 89 |
+
"forecast", "benchmark", "metric"]
|
| 90 |
+
hits = sum(1 for ind in indicators if ind in text_lower)
|
| 91 |
+
return min(hits / 2.0, 1.0)
|
| 92 |
+
|
| 93 |
+
elif style == "operational-actionable":
|
| 94 |
+
indicators = ["action", "recommend", "should", "priority", "next steps",
|
| 95 |
+
"immediate", "plan", "schedule"]
|
| 96 |
+
hits = sum(1 for ind in indicators if ind in text_lower)
|
| 97 |
+
return min(hits / 2.0, 1.0)
|
| 98 |
+
|
| 99 |
+
elif style == "plain-personal":
|
| 100 |
+
# Plain style rewards shorter sentences and simple language
|
| 101 |
+
sentences = text.split(".")
|
| 102 |
+
avg_len = sum(len(s.split()) for s in sentences) / max(len(sentences), 1)
|
| 103 |
+
return 1.0 if avg_len < 20 else max(0.0, 1.0 - (avg_len - 20) / 30)
|
| 104 |
+
|
| 105 |
+
return 0.5
|
environments/shared/reward_utils.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Reward computation with cached downstream signals."""
|
| 2 |
+
|
| 3 |
+
# Cached downstream signals mapping quality buckets to historical scores.
|
| 4 |
+
# These represent how well downstream stages perform given upstream quality.
|
| 5 |
+
# Cached downstream signals mapping quality buckets to historical scores.
# These represent how well downstream stages perform given upstream quality;
# they act as a stand-in for live downstream evaluation when computing the
# reward mixes below. Bucket names come from _get_downstream_bucket.
DOWNSTREAM_CACHE: dict[str, float] = {
    "excellent": 0.95,  # DQ > 0.90 or coverage > 0.80
    "good": 0.75,       # DQ 0.70-0.90 or coverage 0.50-0.80
    "fair": 0.50,       # DQ 0.50-0.70 or coverage 0.30-0.50
    "poor": 0.20,       # DQ < 0.50 or coverage < 0.30
}
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _get_downstream_bucket(score: float) -> str:
|
| 14 |
+
"""Map a score to a downstream quality bucket."""
|
| 15 |
+
if score > 0.90:
|
| 16 |
+
return "excellent"
|
| 17 |
+
elif score > 0.70:
|
| 18 |
+
return "good"
|
| 19 |
+
elif score > 0.50:
|
| 20 |
+
return "fair"
|
| 21 |
+
return "poor"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def cleaning_reward(dq_score: float, downstream_bucket: str = "") -> float:
    """Cleaning-stage reward: 0.70 * dq_score + 0.30 * downstream signal.

    When no bucket is supplied, the DQ score itself selects the bucket;
    unknown bucket names fall back to a neutral 0.5 downstream signal.
    """
    bucket = downstream_bucket or _get_downstream_bucket(dq_score)
    downstream = DOWNSTREAM_CACHE.get(bucket, 0.5)
    return round(0.70 * dq_score + 0.30 * downstream, 4)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def enrichment_reward(coverage: float, downstream_bucket: str = "") -> float:
    """Enrichment-stage reward: 0.50 * coverage + 0.50 * downstream signal.

    When no bucket is supplied, coverage itself selects the bucket; unknown
    bucket names fall back to a neutral 0.5 downstream signal.
    """
    bucket = downstream_bucket or _get_downstream_bucket(coverage)
    downstream = DOWNSTREAM_CACHE.get(bucket, 0.5)
    return round(0.50 * coverage + 0.50 * downstream, 4)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def answering_reward(faithfulness: float, persona_relevance: float,
|
| 47 |
+
patronus_score: float | None = None) -> float:
|
| 48 |
+
"""Compute answering stage reward.
|
| 49 |
+
|
| 50 |
+
Without Patronus: 0.30 * faithfulness + 0.70 * persona_relevance
|
| 51 |
+
With Patronus: 0.40 * patronus_faithfulness + 0.60 * persona_relevance
|
| 52 |
+
"""
|
| 53 |
+
if patronus_score is not None:
|
| 54 |
+
return round(0.40 * patronus_score + 0.60 * persona_relevance, 4)
|
| 55 |
+
return round(0.30 * faithfulness + 0.70 * persona_relevance, 4)
|
models.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data models for the DataSage Cleaning Environment."""
|
| 2 |
+
|
| 3 |
+
from openenv.core.env_server.types import Action, Observation
|
| 4 |
+
from pydantic import Field
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class CleaningAction(Action):
    """Action for the Cleaning environment - a single data cleaning operation.

    One action applies one operation to (at most) one column of the current
    batch. The environment silently ignores unknown operations and invalid
    columns, so a malformed action is a no-op rather than an error.
    """

    # Which transformation to apply; semantics live in CleaningEnvironment.
    operation: str = Field(
        ...,
        description="Cleaning operation: fill_null|fix_type|remove_duplicate|standardize|trim|correct_typo",
    )
    # Target column (remove_duplicate operates on whole rows and ignores it).
    column: str = Field(..., description="Target column name")
    # Operation argument: "median"/"mode" or a literal for fill_null,
    # "lower"/"title" for standardize, the corrected value for correct_typo.
    value: Optional[str] = Field(
        None,
        description="Replacement value or rule (e.g., 'median', 'mode', a specific value)",
    )
    # Extra operation parameters, e.g. {"wrong": "<typo>"} for correct_typo.
    params: dict = Field(default_factory=dict)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CleaningObservation(Observation):
    """Observation from the Cleaning environment - the data quality state.

    Returned by both reset() and step(); text fields are pre-rendered on the
    server so agents can consume them directly as prompt context.
    """

    # Active domain for this episode.
    domain: str = Field(default="", description="Current domain: hr|sales|pm|it_ops")
    # Human-readable table of the first rows of the (possibly cleaned) batch.
    data_preview: str = Field(default="", description="First 5 rows as text table")
    # Per-metric breakdown matching dq_score's components.
    dq_report: str = Field(
        default="",
        description="Completeness, consistency, uniqueness breakdown",
    )
    # Weighted overall data-quality score.
    dq_score: float = Field(default=0.0, description="Overall DQ score 0-1")
    # One line per column: dtype, null count, expected/extra tag.
    columns_info: str = Field(
        default="", description="Column names, types, null counts"
    )
    # Progress within the episode (episode ends at max_steps or DQ > 0.95).
    step_number: int = Field(default=0)
    max_steps: int = Field(default=15)
|
openenv.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: datasage_cleaning
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
pyproject.toml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=45", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "openenv-datasage-cleaning"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "DataSage Cleaning environment for OpenEnv"
|
| 9 |
+
requires-python = ">=3.10"
|
| 10 |
+
dependencies = [
|
| 11 |
+
"openenv-core[core]>=0.2.1",
|
| 12 |
+
"pandas>=2.0",
|
| 13 |
+
"numpy>=1.24",
|
| 14 |
+
]
|
| 15 |
+
|
| 16 |
+
[project.optional-dependencies]
|
| 17 |
+
dev = [
|
| 18 |
+
"pytest>=8.0.0",
|
| 19 |
+
"pytest-cov>=4.0.0",
|
| 20 |
+
]
|
| 21 |
+
|
| 22 |
+
[project.scripts]
|
| 23 |
+
server = "datasage_cleaning.server.app:main"
|
| 24 |
+
|
| 25 |
+
[tool.setuptools]
|
| 26 |
+
include-package-data = true
|
| 27 |
+
packages = ["datasage_cleaning", "datasage_cleaning.server"]
|
| 28 |
+
package-dir = { "datasage_cleaning" = ".", "datasage_cleaning.server" = "server" }
|
server/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Cleaning environment server components."""
|
| 2 |
+
|
| 3 |
+
from .cleaning_environment import CleaningEnvironment
|
| 4 |
+
|
| 5 |
+
__all__ = ["CleaningEnvironment"]
|
server/app.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI application for the DataSage Cleaning Environment.
|
| 3 |
+
|
| 4 |
+
Endpoints:
|
| 5 |
+
- POST /reset: Reset the environment
|
| 6 |
+
- POST /step: Execute an action
|
| 7 |
+
- GET /state: Get current environment state
|
| 8 |
+
- GET /schema: Get action/observation schemas
|
| 9 |
+
- WS /ws: WebSocket endpoint for persistent sessions
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
# Fail fast with an actionable message when the OpenEnv runtime is missing.
try:
    from openenv.core.env_server.http_server import create_app
except Exception as e:  # pragma: no cover
    raise ImportError(
        "openenv is required. Install with: uv sync"
    ) from e

from models import CleaningAction, CleaningObservation

from .cleaning_environment import CleaningEnvironment


# create_app wires the standard OpenEnv HTTP/WebSocket endpoints listed in
# the module docstring around the environment, action, and observation types.
app = create_app(
    CleaningEnvironment,
    CleaningAction,
    CleaningObservation,
    env_name="datasage_cleaning",
    max_concurrent_envs=4,  # upper bound on simultaneous environment sessions
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def main(host: str = "0.0.0.0", port: int = 8000):
    """Entry point for direct execution.

    Serves the module-level FastAPI ``app`` with uvicorn on *host*:*port*.
    """
    import uvicorn

    uvicorn.run(app, host=host, port=port)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
if __name__ == "__main__":
|
| 43 |
+
import argparse
|
| 44 |
+
|
| 45 |
+
parser = argparse.ArgumentParser()
|
| 46 |
+
parser.add_argument("--port", type=int, default=8000)
|
| 47 |
+
args = parser.parse_args()
|
| 48 |
+
main(port=args.port)
|
server/cleaning_environment.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DataSage Cleaning Environment Implementation.
|
| 3 |
+
|
| 4 |
+
An RL environment where the agent must clean corrupted enterprise data
|
| 5 |
+
across 4 domains (HR, Sales, PM, IT Ops) by applying cleaning operations
|
| 6 |
+
to maximise the data quality score.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import random
|
| 10 |
+
import sys
|
| 11 |
+
import os
|
| 12 |
+
|
| 13 |
+
from uuid import uuid4
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import pandas as pd
|
| 17 |
+
|
| 18 |
+
# Allow imports from the project root so shared modules are reachable.
|
| 19 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..'))
|
| 20 |
+
|
| 21 |
+
from environments.shared.domains import DOMAINS
|
| 22 |
+
from environments.shared.enterprise_data import (
|
| 23 |
+
load_domain_data,
|
| 24 |
+
inject_corruption,
|
| 25 |
+
compute_dq_score,
|
| 26 |
+
format_preview,
|
| 27 |
+
format_columns_info,
|
| 28 |
+
)
|
| 29 |
+
from environments.shared.reward_utils import cleaning_reward
|
| 30 |
+
|
| 31 |
+
from models import CleaningAction, CleaningObservation
|
| 32 |
+
from openenv.core.env_server.interfaces import Environment
|
| 33 |
+
from openenv.core.env_server.types import State
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class CleaningEnvironment(Environment):
    """
    Cleaning environment: the agent fixes data quality issues in a
    50-row enterprise data batch.

    Each episode picks a random domain, loads a 50-row batch, and injects
    corruption; the agent then applies cleaning operations until the DQ
    score exceeds 0.95 or the step budget runs out.

    Supported operations:
        fill_null        - fill missing values (median / mode / explicit value)
        fix_type         - cast a column to numeric, coercing errors to NaN
        remove_duplicate - drop duplicate rows
        standardize      - strip whitespace and normalise case (lower / title)
        trim             - strip leading/trailing whitespace
        correct_typo     - replace a typo with a correct value
    """

    # Each session gets its own instance, so concurrent sessions are allowed.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        """Initialise the cleaning environment with an empty batch.

        reset() must be called before step(); until then the frame is empty
        and no domain is selected.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._df: pd.DataFrame = pd.DataFrame()   # current (corrupted) batch
        self._domain_name: str = ""               # "hr" | "sales" | "pm" | "it_ops"
        self._domain_config = None                # DomainConfig for the active domain
        self._max_steps: int = 15                 # per-episode step budget

    # ------------------------------------------------------------------
    # reset
    # ------------------------------------------------------------------
    def reset(self) -> CleaningObservation:
        """Pick a random domain, load a 50-row batch, inject corruption.

        Returns the initial observation (reward 0, done False) describing
        the corrupted batch the agent must clean.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)

        self._domain_name = random.choice(list(DOMAINS.keys()))
        self._domain_config = DOMAINS[self._domain_name]

        # Load raw data and sample 50 rows, then corrupt ~15% of it.
        raw_df = load_domain_data(self._domain_name, sample_size=50)
        self._df = inject_corruption(raw_df, self._domain_config, rate=0.15)

        dq = compute_dq_score(self._df, self._domain_config)
        dq_report = (
            f"completeness={dq['completeness']:.4f} "
            f"consistency={dq['consistency']:.4f} "
            f"uniqueness={dq['uniqueness']:.4f}"
        )

        return CleaningObservation(
            domain=self._domain_name,
            data_preview=format_preview(self._df),
            dq_report=dq_report,
            dq_score=dq["overall"],
            columns_info=format_columns_info(self._df, self._domain_config),
            step_number=0,
            max_steps=self._max_steps,
            done=False,
            reward=0.0,
        )

    # ------------------------------------------------------------------
    # step
    # ------------------------------------------------------------------
    def step(self, action: CleaningAction) -> CleaningObservation:  # type: ignore[override]
        """Apply a single cleaning operation and return the updated state.

        Invalid operations/columns are deliberately no-ops (the step still
        counts against the budget). The episode ends when the overall DQ
        score exceeds 0.95 or max_steps is reached.
        """
        self._state.step_count += 1
        step = self._state.step_count

        op = action.operation
        col = action.column
        value = action.value

        # ---- apply operation ----
        try:
            if op == "fill_null":
                self._apply_fill_null(col, value)
            elif op == "fix_type":
                self._apply_fix_type(col)
            elif op == "remove_duplicate":
                self._apply_remove_duplicate()
            elif op == "standardize":
                self._apply_standardize(col, value)
            elif op == "trim":
                self._apply_trim(col)
            elif op == "correct_typo":
                self._apply_correct_typo(col, value, action.params)
            # unknown ops are silently ignored (no crash)
        except Exception:
            pass  # invalid column, etc. -> no-op

        # ---- compute DQ and reward ----
        dq = compute_dq_score(self._df, self._domain_config)
        reward = cleaning_reward(dq["overall"])
        done = dq["overall"] > 0.95 or step >= self._max_steps

        dq_report = (
            f"completeness={dq['completeness']:.4f} "
            f"consistency={dq['consistency']:.4f} "
            f"uniqueness={dq['uniqueness']:.4f}"
        )

        return CleaningObservation(
            domain=self._domain_name,
            data_preview=format_preview(self._df),
            dq_report=dq_report,
            dq_score=dq["overall"],
            columns_info=format_columns_info(self._df, self._domain_config),
            step_number=step,
            max_steps=self._max_steps,
            done=done,
            reward=reward,
            metadata={
                "operation": op,
                "column": col,
                "step": step,
            },
        )

    # ------------------------------------------------------------------
    # state property
    # ------------------------------------------------------------------
    @property
    def state(self) -> State:
        """Return current environment state (episode id + step count)."""
        return self._state

    # ------------------------------------------------------------------
    # operation helpers
    # ------------------------------------------------------------------
    def _apply_fill_null(self, col: str, value: str | None) -> None:
        # Fill NaNs with the numeric median, the most frequent value, the
        # literal value given, or "" when no value was supplied.
        if col not in self._df.columns:
            return
        if value == "median":
            # Coerce first so string sentinels ("N/A", ...) don't break median.
            numeric = pd.to_numeric(self._df[col], errors="coerce")
            fill_val = numeric.median()
        elif value == "mode":
            mode_vals = self._df[col].mode()
            fill_val = mode_vals.iloc[0] if len(mode_vals) > 0 else ""
        else:
            fill_val = value if value is not None else ""
        self._df[col] = self._df[col].fillna(fill_val)

    def _apply_fix_type(self, col: str) -> None:
        # Cast to numeric; unparseable values become NaN (fixable by fill_null).
        if col not in self._df.columns:
            return
        self._df[col] = pd.to_numeric(self._df[col], errors="coerce")

    def _apply_remove_duplicate(self) -> None:
        # Key on the first five expected columns, matching compute_dq_score's
        # uniqueness metric, so this operation directly improves that score.
        available = [c for c in self._domain_config.columns if c in self._df.columns]
        self._df = self._df.drop_duplicates(
            subset=available[:5], keep="first"
        ).reset_index(drop=True)

    def _apply_standardize(self, col: str, value: str | None) -> None:
        # Always trims; additionally normalises case when value is
        # "lower" or "title". Note: casts the column to string dtype.
        if col not in self._df.columns:
            return
        self._df[col] = self._df[col].astype(str).str.strip()
        if value == "lower":
            self._df[col] = self._df[col].str.lower()
        elif value == "title":
            self._df[col] = self._df[col].str.title()

    def _apply_trim(self, col: str) -> None:
        # Strip leading/trailing whitespace (casts the column to string dtype).
        if col not in self._df.columns:
            return
        self._df[col] = self._df[col].astype(str).str.strip()

    def _apply_correct_typo(self, col: str, value: str | None,
                            params: dict) -> None:
        # Replace params["wrong"] with value; without a "wrong" hint, assume
        # the least common value in the column is the typo.
        if col not in self._df.columns or value is None:
            return
        wrong = params.get("wrong")
        if wrong:
            self._df[col] = self._df[col].replace(wrong, value)
        else:
            # If no specific wrong value given, try to replace the most
            # uncommon value with the provided correct value.
            counts = self._df[col].value_counts()
            if len(counts) > 1:
                least_common = counts.index[-1]
                self._df[col] = self._df[col].replace(least_common, value)
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv[core]>=0.2.0
|
| 2 |
+
fastapi>=0.115.0
|
| 3 |
+
uvicorn>=0.24.0
|
| 4 |
+
pandas>=2.0
|
| 5 |
+
numpy>=1.24
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|