ricalanis commited on
Commit
4f9aee9
·
verified ·
1 Parent(s): d2c74f9

Upload folder using huggingface_hub

Browse files
Dockerfile ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Multi-stage build using openenv-base
2
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
3
+ FROM ${BASE_IMAGE} AS builder
4
+
5
+ WORKDIR /app
6
+
7
+ # Ensure git is available (required for installing dependencies from VCS)
8
+ RUN apt-get update && \
9
+ apt-get install -y --no-install-recommends git && \
10
+ rm -rf /var/lib/apt/lists/*
11
+
12
+ # Build argument to control whether we're building standalone or in-repo
13
+ ARG BUILD_MODE=in-repo
14
+ ARG ENV_NAME=datasage_answering
15
+
16
+ # Copy environment code (always at root of build context)
17
+ COPY . /app/env
18
+
19
+ WORKDIR /app/env
20
+
21
+ # Ensure uv is available (for local builds where base image lacks it)
22
+ RUN if ! command -v uv >/dev/null 2>&1; then \
23
+ curl -LsSf https://astral.sh/uv/install.sh | sh && \
24
+ mv /root/.local/bin/uv /usr/local/bin/uv && \
25
+ mv /root/.local/bin/uvx /usr/local/bin/uvx; \
26
+ fi
27
+
28
+ # Install dependencies using uv sync
29
+ RUN --mount=type=cache,target=/root/.cache/uv \
30
+ if [ -f uv.lock ]; then \
31
+ uv sync --frozen --no-install-project --no-editable; \
32
+ else \
33
+ uv sync --no-install-project --no-editable; \
34
+ fi
35
+
36
+ RUN --mount=type=cache,target=/root/.cache/uv \
37
+ if [ -f uv.lock ]; then \
38
+ uv sync --frozen --no-editable; \
39
+ else \
40
+ uv sync --no-editable; \
41
+ fi
42
+
43
+ # Final runtime stage
44
+ FROM ${BASE_IMAGE}
45
+
46
+ WORKDIR /app
47
+
48
+ # Copy the virtual environment from builder
49
+ COPY --from=builder /app/env/.venv /app/.venv
50
+
51
+ # Copy the environment code
52
+ COPY --from=builder /app/env /app/env
53
+
54
+ # Set PATH to use the virtual environment
55
+ ENV PATH="/app/.venv/bin:$PATH"
56
+
57
+ # Set PYTHONPATH so imports work correctly
58
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
59
+
60
+ # Health check
61
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
62
+ CMD curl -f http://localhost:8000/health || exit 1
63
+
64
+ # Run the FastAPI server
65
+ ENV ENABLE_WEB_INTERFACE=true
66
+ CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
README.md CHANGED
@@ -1,10 +1,47 @@
1
  ---
2
- title: Datasage Answering
3
- emoji: 😻
4
- colorFrom: green
5
- colorTo: pink
6
  sdk: docker
7
  pinned: false
 
 
 
 
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: DataSage Answering Environment
3
+ emoji: "\U0001F4CA"
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: docker
7
  pinned: false
8
+ app_port: 8000
9
+ base_path: /web
10
+ tags:
11
+ - openenv
12
+ - datasage
13
  ---
14
 
15
+ # DataSage Answering Environment
16
+
17
+ A single-step RL environment where an agent answers enterprise data questions tailored to a specific persona (Executive, Manager, Individual Contributor) using enriched data context across 4 domains (HR, Sales, PM, IT Ops).
18
+
19
+ ## Quick Start
20
+
21
+ ```python
22
+ from environments.answering.models import AnsweringAction
23
+ from environments.answering.server.answering_environment import AnsweringEnvironment
24
+
25
+ env = AnsweringEnvironment()
26
+ obs = env.reset()
27
+ print(f"Domain: {obs.domain}, Persona: {obs.persona}")
28
+ print(f"Question: {obs.question}")
29
+
30
+ action = AnsweringAction(
31
+ answer="Based on the data, key trends show...",
32
+ cited_columns=obs.available_columns[:3],
33
+ reasoning="Analyzed available columns for patterns."
34
+ )
35
+ result = env.step(action)
36
+ print(f"Reward: {result.reward}, Done: {result.done}")
37
+ ```
38
+
39
+ ## Reward
40
+
41
+ The reward combines:
42
+ - **Faithfulness** (0-1): Are cited columns valid? Does the answer reference real data values?
43
+ - **Persona relevance** (0-1): Does the answer match the persona's language style and focus areas?
44
+ - **Patronus score** (optional): If `PATRONUS_API_KEY` is set, uses Patronus Lynx for hallucination detection.
45
+
46
+ Without Patronus: `0.30 * faithfulness + 0.70 * persona_relevance`
47
+ With Patronus: `0.40 * patronus_score + 0.60 * persona_relevance`
__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """DataSage Answering Environment."""
client.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DataSage Answering Environment Client."""
2
+
3
+ from typing import Dict
4
+
5
+ from openenv.core import EnvClient
6
+ from openenv.core.client_types import StepResult
7
+ from openenv.core.env_server.types import State
8
+
9
+ from .models import AnsweringAction, AnsweringObservation
10
+
11
+
12
class AnsweringEnv(
    EnvClient[AnsweringAction, AnsweringObservation, State]
):
    """
    Client for the DataSage Answering Environment.

    Holds a persistent WebSocket connection to the environment server, so
    repeated reset/step calls avoid per-request connection overhead. Each
    client instance owns a dedicated server-side session.

    Example:
        >>> with AnsweringEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset()
        ...     print(result.observation.question)
        ...     result = client.step(AnsweringAction(
        ...         answer="Based on the data...",
        ...         cited_columns=["col1"],
        ...         reasoning="Analysis reasoning",
        ...     ))
        ...     print(result.observation.reward)
    """

    def _step_payload(self, action: AnsweringAction) -> Dict:
        """Serialize an AnsweringAction into the JSON step payload."""
        return {
            "answer": action.answer,
            "cited_columns": action.cited_columns,
            "reasoning": action.reasoning,
        }

    def _parse_result(self, payload: Dict) -> StepResult[AnsweringObservation]:
        """Deserialize a server step/reset response into a StepResult."""
        obs = payload.get("observation", {})
        done = payload.get("done", False)
        reward = payload.get("reward")
        observation = AnsweringObservation(
            domain=obs.get("domain", ""),
            dataset_summary=obs.get("dataset_summary", ""),
            persona=obs.get("persona", ""),
            persona_description=obs.get("persona_description", ""),
            question=obs.get("question", ""),
            available_columns=obs.get("available_columns", []),
            column_stats=obs.get("column_stats", ""),
            done=done,
            reward=reward,
            metadata=obs.get("metadata", {}),
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict) -> State:
        """Deserialize a state response into a State (episode_id + step_count)."""
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )
environments/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # DataSage environments package
environments/shared/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared utilities for DataSage multi-domain enterprise environments."""
2
+
3
+ from .domains import DOMAINS, DomainConfig
4
+ from .personas import PERSONAS, Persona, score_persona_alignment
5
+ from .reward_utils import cleaning_reward, enrichment_reward, answering_reward
6
+
7
+ __all__ = [
8
+ "DOMAINS", "DomainConfig",
9
+ "PERSONAS", "Persona", "score_persona_alignment",
10
+ "cleaning_reward", "enrichment_reward", "answering_reward",
11
+ ]
environments/shared/domains.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Domain registry for the 4 enterprise data domains."""
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
class DomainConfig(BaseModel):
    """Static description of one enterprise data domain."""

    name: str                        # machine-friendly key, e.g. "hr"
    display_name: str                # human-readable label
    dataset_key: str                 # dataset identifier for this domain
    columns: list[str]               # full expected column set
    numeric_columns: list[str]       # subset expected to parse as numbers
    categorical_columns: list[str]   # subset with closed value sets
    possible_enrichments: list[str]  # enrichment source names for this domain
    example_questions: list[str]     # seed questions for the answering env
15
+
16
+
17
+ DOMAINS = {
18
+ "hr": DomainConfig(
19
+ name="hr",
20
+ display_name="HR & People",
21
+ dataset_key="hr",
22
+ columns=[
23
+ "EmployeeID", "Age", "Department", "JobRole", "MonthlyIncome",
24
+ "YearsAtCompany", "Attrition", "JobSatisfaction", "OverTime",
25
+ "DistanceFromHome", "Education", "PerformanceRating",
26
+ ],
27
+ numeric_columns=["Age", "MonthlyIncome", "YearsAtCompany", "DistanceFromHome"],
28
+ categorical_columns=["Department", "JobRole", "Attrition", "OverTime"],
29
+ possible_enrichments=[
30
+ "salary_band", "tenure_risk", "satisfaction_index",
31
+ "industry_benchmark", "flight_risk_score",
32
+ ],
33
+ example_questions=[
34
+ "Which departments have the highest attrition rates?",
35
+ "What factors correlate most with employee turnover?",
36
+ "How does overtime affect job satisfaction?",
37
+ "What is the salary distribution across job roles?",
38
+ "Which employees are at highest flight risk?",
39
+ ],
40
+ ),
41
+ "sales": DomainConfig(
42
+ name="sales",
43
+ display_name="Sales & Revenue",
44
+ dataset_key="sales",
45
+ columns=[
46
+ "DealID", "AccountName", "Stage", "Amount", "CloseDate",
47
+ "Rep", "Product", "Region", "LeadSource", "DaysInStage",
48
+ "Probability", "ForecastCategory",
49
+ ],
50
+ numeric_columns=["Amount", "DaysInStage", "Probability"],
51
+ categorical_columns=["Stage", "Region", "Product", "ForecastCategory"],
52
+ possible_enrichments=[
53
+ "deal_size_category", "velocity_score", "win_probability_model",
54
+ "industry_code", "competitive_risk",
55
+ ],
56
+ example_questions=[
57
+ "What's our pipeline health for this quarter?",
58
+ "Which deals are at risk of slipping?",
59
+ "What's the average deal velocity by region?",
60
+ "Which reps are below quota?",
61
+ "What's the conversion rate by lead source?",
62
+ ],
63
+ ),
64
+ "pm": DomainConfig(
65
+ name="pm",
66
+ display_name="Project Management",
67
+ dataset_key="pm",
68
+ columns=[
69
+ "TaskID", "ProjectName", "Assignee", "Status", "Priority",
70
+ "DueDate", "EstimatedHours", "ActualHours", "Dependencies",
71
+ "Milestone", "RiskFlag", "CompletionPct",
72
+ ],
73
+ numeric_columns=["EstimatedHours", "ActualHours", "CompletionPct"],
74
+ categorical_columns=["Status", "Priority", "RiskFlag"],
75
+ possible_enrichments=[
76
+ "schedule_risk_score", "resource_utilization",
77
+ "dependency_chain_depth", "burndown_rate", "delay_probability",
78
+ ],
79
+ example_questions=[
80
+ "Which projects are at risk of missing deadlines?",
81
+ "How is resource utilization across teams?",
82
+ "What's the burndown rate for the current sprint?",
83
+ "Which tasks are blocking the most downstream work?",
84
+ "What's our on-time delivery rate?",
85
+ ],
86
+ ),
87
+ "it_ops": DomainConfig(
88
+ name="it_ops",
89
+ display_name="IT Operations",
90
+ dataset_key="it_ops",
91
+ columns=[
92
+ "TicketID", "Category", "Priority", "Status", "Assignee",
93
+ "CreatedDate", "ResolvedDate", "SLATarget", "EscalationLevel",
94
+ "AffectedSystem", "ResolutionType", "CustomerImpact",
95
+ ],
96
+ numeric_columns=["SLATarget", "EscalationLevel"],
97
+ categorical_columns=["Category", "Priority", "Status", "ResolutionType"],
98
+ possible_enrichments=[
99
+ "sla_compliance_flag", "mttr_band", "escalation_path",
100
+ "incident_severity_score", "recurring_pattern_flag",
101
+ ],
102
+ example_questions=[
103
+ "What's our SLA compliance rate this month?",
104
+ "Which systems have the most incidents?",
105
+ "What's the mean time to resolution trend?",
106
+ "How many tickets are breaching SLA?",
107
+ "What are the most common root causes?",
108
+ ],
109
+ ),
110
+ }
environments/shared/enrichment_sources.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Static enrichment lookup tables per domain (no API calls)."""
2
+
3
+ import numpy as np
4
+
5
+ # Enrichment registry: domain -> source -> lookup function or static data
6
+ ENRICHMENT_REGISTRY: dict[str, dict[str, dict]] = {
7
+ "hr": {
8
+ "salary_band": {
9
+ "description": "BLS salary band classification based on monthly income",
10
+ "type": "derived",
11
+ "logic": "classify_salary_band",
12
+ },
13
+ "tenure_risk": {
14
+ "description": "Tenure-based flight risk score",
15
+ "type": "derived",
16
+ "logic": "compute_tenure_risk",
17
+ },
18
+ "satisfaction_index": {
19
+ "description": "Composite satisfaction index from multiple factors",
20
+ "type": "derived",
21
+ "logic": "compute_satisfaction_index",
22
+ },
23
+ "industry_benchmark": {
24
+ "description": "Industry benchmark salary percentile",
25
+ "type": "lookup",
26
+ "data": {
27
+ "Sales Executive": 65000, "Research Scientist": 72000,
28
+ "Manager": 85000, "Lab Technician": 45000,
29
+ "Manufacturing Director": 95000, "Healthcare Representative": 55000,
30
+ "Human Resources": 60000,
31
+ },
32
+ },
33
+ "flight_risk_score": {
34
+ "description": "Combined flight risk from satisfaction, tenure, overtime",
35
+ "type": "derived",
36
+ "logic": "compute_flight_risk",
37
+ },
38
+ },
39
+ "sales": {
40
+ "deal_size_category": {
41
+ "description": "Categorize deal by amount: Small/Medium/Large/Enterprise",
42
+ "type": "derived",
43
+ "logic": "classify_deal_size",
44
+ },
45
+ "velocity_score": {
46
+ "description": "Deal velocity based on days in stage vs benchmark",
47
+ "type": "derived",
48
+ "logic": "compute_velocity_score",
49
+ },
50
+ "win_probability_model": {
51
+ "description": "Heuristic win probability based on stage + days",
52
+ "type": "derived",
53
+ "logic": "compute_win_probability",
54
+ },
55
+ "industry_code": {
56
+ "description": "Industry classification code from account name patterns",
57
+ "type": "lookup",
58
+ "data": {
59
+ "Tech": "SIC-7372", "Healthcare": "SIC-8011",
60
+ "Finance": "SIC-6020", "Retail": "SIC-5311",
61
+ "Manufacturing": "SIC-3559", "default": "SIC-9999",
62
+ },
63
+ },
64
+ "competitive_risk": {
65
+ "description": "Competitive risk score based on deal stage and velocity",
66
+ "type": "derived",
67
+ "logic": "compute_competitive_risk",
68
+ },
69
+ },
70
+ "pm": {
71
+ "schedule_risk_score": {
72
+ "description": "Risk of schedule slippage based on progress vs due date",
73
+ "type": "derived",
74
+ "logic": "compute_schedule_risk",
75
+ },
76
+ "resource_utilization": {
77
+ "description": "Resource utilization ratio: actual/estimated hours",
78
+ "type": "derived",
79
+ "logic": "compute_resource_utilization",
80
+ },
81
+ "dependency_chain_depth": {
82
+ "description": "Depth of dependency chain for task",
83
+ "type": "derived",
84
+ "logic": "compute_dependency_depth",
85
+ },
86
+ "burndown_rate": {
87
+ "description": "Task completion rate relative to plan",
88
+ "type": "derived",
89
+ "logic": "compute_burndown_rate",
90
+ },
91
+ "delay_probability": {
92
+ "description": "Probability of delay based on current trajectory",
93
+ "type": "derived",
94
+ "logic": "compute_delay_probability",
95
+ },
96
+ },
97
+ "it_ops": {
98
+ "sla_compliance_flag": {
99
+ "description": "Whether ticket meets SLA target",
100
+ "type": "derived",
101
+ "logic": "compute_sla_compliance",
102
+ },
103
+ "mttr_band": {
104
+ "description": "Mean time to resolution band: Fast/Normal/Slow/Critical",
105
+ "type": "derived",
106
+ "logic": "classify_mttr",
107
+ },
108
+ "escalation_path": {
109
+ "description": "Recommended escalation path based on category and priority",
110
+ "type": "lookup",
111
+ "data": {
112
+ "P1-Critical": "L3 -> Manager -> VP",
113
+ "P2-High": "L2 -> L3 -> Manager",
114
+ "P3-Medium": "L1 -> L2",
115
+ "P4-Low": "L1",
116
+ },
117
+ },
118
+ "incident_severity_score": {
119
+ "description": "Computed severity score from priority and customer impact",
120
+ "type": "derived",
121
+ "logic": "compute_severity_score",
122
+ },
123
+ "recurring_pattern_flag": {
124
+ "description": "Flag indicating likely recurring issue",
125
+ "type": "derived",
126
+ "logic": "detect_recurring_pattern",
127
+ },
128
+ },
129
+ }
130
+
131
+
132
def lookup(domain: str, source: str, row: dict) -> object:
    """Resolve an enrichment value for *row*.

    Static "lookup" sources match row values against a fixed table;
    "derived" sources dispatch to a registered computation. Returns None
    for unknown domain/source or an unregistered computation.
    """
    entry = ENRICHMENT_REGISTRY.get(domain, {}).get(source)
    if not entry:
        return None

    if entry["type"] == "lookup":
        table = entry["data"]
        # Try each of the row's values (in key order) against the table.
        for column in row:
            candidate = str(row.get(column, ""))
            if candidate in table:
                return table[candidate]
        return table.get("default")

    # "derived" source: dispatch by the registered logic name.
    compute = _COMPUTE_FUNCTIONS.get(entry["logic"])
    return compute(row) if compute else None
155
+
156
+
157
+ # --- Computation functions ---
158
+
159
+ def _classify_salary_band(row: dict) -> str:
160
+ try:
161
+ income = float(row.get("MonthlyIncome", 0))
162
+ except (ValueError, TypeError):
163
+ return "Unknown"
164
+ if income < 3000:
165
+ return "Entry"
166
+ elif income < 6000:
167
+ return "Mid"
168
+ elif income < 10000:
169
+ return "Senior"
170
+ return "Executive"
171
+
172
+
173
+ def _compute_tenure_risk(row: dict) -> float:
174
+ try:
175
+ years = float(row.get("YearsAtCompany", 0))
176
+ except (ValueError, TypeError):
177
+ return 0.5
178
+ # Short tenure = higher risk, very long = moderate risk
179
+ if years < 2:
180
+ return 0.8
181
+ elif years < 5:
182
+ return 0.4
183
+ elif years < 10:
184
+ return 0.2
185
+ return 0.3
186
+
187
+
188
+ def _compute_satisfaction_index(row: dict) -> float:
189
+ try:
190
+ satisfaction = float(row.get("JobSatisfaction", 3))
191
+ except (ValueError, TypeError):
192
+ satisfaction = 3
193
+ return round(satisfaction / 4.0, 2)
194
+
195
+
196
def _compute_flight_risk(row: dict) -> float:
    """Blend tenure risk (40%), dissatisfaction (40%), and overtime (20%) into one score."""
    risk_from_tenure = _compute_tenure_risk(row)
    dissatisfaction = 1 - _compute_satisfaction_index(row)
    overtime_penalty = 0.3 if str(row.get("OverTime", "No")).lower() == "yes" else 0.0
    return round(0.4 * risk_from_tenure + 0.4 * dissatisfaction + 0.2 * overtime_penalty, 2)
201
+
202
+
203
+ def _classify_deal_size(row: dict) -> str:
204
+ try:
205
+ amount = float(row.get("Amount", 0))
206
+ except (ValueError, TypeError):
207
+ return "Unknown"
208
+ if amount < 5000:
209
+ return "Small"
210
+ elif amount < 25000:
211
+ return "Medium"
212
+ elif amount < 100000:
213
+ return "Large"
214
+ return "Enterprise"
215
+
216
+
217
+ def _compute_velocity_score(row: dict) -> float:
218
+ try:
219
+ days = float(row.get("DaysInStage", 0))
220
+ except (ValueError, TypeError):
221
+ return 0.5
222
+ # Benchmark: 30 days per stage
223
+ if days < 15:
224
+ return 1.0
225
+ elif days < 30:
226
+ return 0.7
227
+ elif days < 60:
228
+ return 0.4
229
+ return 0.1
230
+
231
+
232
def _compute_win_probability(row: dict) -> float:
    """Blend a stage-based base probability (70%) with deal velocity (30%)."""
    base_by_stage = {
        "Prospecting": 0.10,
        "Qualification": 0.25,
        "Proposal": 0.50,
        "Negotiation": 0.75,
        "Won": 1.0,
        "Lost": 0.0,
    }
    # Unrecognized stages fall back to a neutral 0.3 base.
    base = base_by_stage.get(str(row.get("Stage", "")), 0.3)
    return round(0.7 * base + 0.3 * _compute_velocity_score(row), 2)
241
+
242
+
243
def _compute_competitive_risk(row: dict) -> float:
    """Competitive risk is the complement of deal velocity: slow deals are exposed."""
    return round(1.0 - _compute_velocity_score(row), 2)
246
+
247
+
248
+ def _compute_schedule_risk(row: dict) -> float:
249
+ try:
250
+ pct = float(row.get("CompletionPct", 0))
251
+ except (ValueError, TypeError):
252
+ pct = 0
253
+ # Simple: lower completion = higher risk
254
+ return round(1.0 - (pct / 100.0), 2)
255
+
256
+
257
+ def _compute_resource_utilization(row: dict) -> float:
258
+ try:
259
+ estimated = float(row.get("EstimatedHours", 1))
260
+ actual = float(row.get("ActualHours", 0))
261
+ except (ValueError, TypeError):
262
+ return 0.0
263
+ if estimated == 0:
264
+ return 0.0
265
+ return round(actual / estimated, 2)
266
+
267
+
268
+ def _compute_dependency_depth(row: dict) -> int:
269
+ deps = row.get("Dependencies", "")
270
+ if not deps or str(deps) in ("nan", "None", ""):
271
+ return 0
272
+ return len(str(deps).split(","))
273
+
274
+
275
+ def _compute_burndown_rate(row: dict) -> float:
276
+ try:
277
+ pct = float(row.get("CompletionPct", 0))
278
+ estimated = float(row.get("EstimatedHours", 1))
279
+ actual = float(row.get("ActualHours", 0))
280
+ except (ValueError, TypeError):
281
+ return 0.5
282
+ if actual == 0:
283
+ return 0.0
284
+ expected_rate = pct / 100.0
285
+ time_rate = actual / max(estimated, 1)
286
+ return round(expected_rate / max(time_rate, 0.01), 2)
287
+
288
+
289
def _compute_delay_probability(row: dict) -> float:
    """Estimate the probability that a task slips, clamped to [0, 1].

    Combines schedule risk with the inverse burndown rate. The raw product
    can fall outside [0, 1] (e.g. schedule_risk=1.0 with burndown=0.1 gives
    10.0, and CompletionPct > 100 makes it negative), which is invalid for
    a value described as a probability — so the result is clamped.
    """
    schedule_risk = _compute_schedule_risk(row)
    burndown = _compute_burndown_rate(row)
    raw = schedule_risk * (1.0 / max(burndown, 0.1))
    return round(min(max(raw, 0.0), 1.0), 2)
293
+
294
+
295
+ def _compute_sla_compliance(row: dict) -> str:
296
+ try:
297
+ sla = float(row.get("SLATarget", 24))
298
+ escalation = float(row.get("EscalationLevel", 0))
299
+ except (ValueError, TypeError):
300
+ return "Unknown"
301
+ if escalation > 2:
302
+ return "Breached"
303
+ return "Compliant"
304
+
305
+
306
+ def _classify_mttr(row: dict) -> str:
307
+ try:
308
+ escalation = float(row.get("EscalationLevel", 0))
309
+ except (ValueError, TypeError):
310
+ return "Normal"
311
+ if escalation == 0:
312
+ return "Fast"
313
+ elif escalation <= 1:
314
+ return "Normal"
315
+ elif escalation <= 3:
316
+ return "Slow"
317
+ return "Critical"
318
+
319
+
320
+ def _compute_severity_score(row: dict) -> float:
321
+ priority_scores = {"P1-Critical": 1.0, "P2-High": 0.7, "P3-Medium": 0.4, "P4-Low": 0.1}
322
+ priority = str(row.get("Priority", "P3-Medium"))
323
+ return priority_scores.get(priority, 0.4)
324
+
325
+
326
+ def _detect_recurring_pattern(row: dict) -> bool:
327
+ category = str(row.get("Category", ""))
328
+ # Simple heuristic: certain categories tend to recur
329
+ recurring_cats = {"Network", "Email", "Access"}
330
+ return category in recurring_cats
331
+
332
+
333
+ _COMPUTE_FUNCTIONS = {
334
+ "classify_salary_band": _classify_salary_band,
335
+ "compute_tenure_risk": _compute_tenure_risk,
336
+ "compute_satisfaction_index": _compute_satisfaction_index,
337
+ "compute_flight_risk": _compute_flight_risk,
338
+ "classify_deal_size": _classify_deal_size,
339
+ "compute_velocity_score": _compute_velocity_score,
340
+ "compute_win_probability": _compute_win_probability,
341
+ "compute_competitive_risk": _compute_competitive_risk,
342
+ "compute_schedule_risk": _compute_schedule_risk,
343
+ "compute_resource_utilization": _compute_resource_utilization,
344
+ "compute_dependency_depth": _compute_dependency_depth,
345
+ "compute_burndown_rate": _compute_burndown_rate,
346
+ "compute_delay_probability": _compute_delay_probability,
347
+ "compute_sla_compliance": _compute_sla_compliance,
348
+ "classify_mttr": _classify_mttr,
349
+ "compute_severity_score": _compute_severity_score,
350
+ "detect_recurring_pattern": _detect_recurring_pattern,
351
+ }
352
+
353
+
354
def get_available_enrichments(domain: str) -> list[str]:
    """Names of enrichment sources registered for *domain* (empty list if unknown)."""
    return [name for name in ENRICHMENT_REGISTRY.get(domain, {})]
357
+
358
+
359
def get_enrichment_description(domain: str, source: str) -> str:
    """Human-readable description of an enrichment source, with a safe fallback."""
    entry = ENRICHMENT_REGISTRY.get(domain, {}).get(source, {})
    return entry.get("description", "Unknown enrichment source")
environments/shared/enterprise_data.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Multi-domain dataset loading, corruption injection, and DQ scoring."""
2
+
3
+ import random
4
+ import string
5
+ from typing import Optional
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
+ from .domains import DOMAINS, DomainConfig
11
+
12
+
13
def load_domain_data(domain: str, sample_size: Optional[int] = None) -> pd.DataFrame:
    """Load a domain's dataset from the HF hub, falling back to synthetic data.

    Any failure (missing `datasets`, offline, unknown config) triggers the
    deterministic synthetic fallback. When *sample_size* is set and smaller
    than the frame, a seeded sample is returned.
    """
    try:
        from datasets import load_dataset
        df = load_dataset(
            "ricalanis/datasage-enterprise-raw", domain, split="train"
        ).to_pandas()
    except Exception:
        # Best-effort fallback keeps the environment usable offline.
        df = _generate_synthetic(domain)

    if sample_size and len(df) > sample_size:
        df = df.sample(n=sample_size, random_state=42).reset_index(drop=True)
    return df
25
+
26
+
27
def _generate_synthetic(domain: str, n: int = 200) -> pd.DataFrame:
    """Deterministic synthetic fallback frame (seed 42) shaped by the domain schema.

    Column kinds are inferred in priority order: numeric, categorical,
    ID-like, date-like, person-like, then a generic string filler.
    RNG calls happen in column order, keeping output reproducible.
    """
    config = DOMAINS[domain]
    rng = np.random.default_rng(42)
    synthetic: dict = {}

    for col in config.columns:
        if col in config.numeric_columns:
            synthetic[col] = rng.normal(50, 20, n).round(2)
        elif col in config.categorical_columns:
            synthetic[col] = rng.choice(_get_categories(domain, col), n).tolist()
        elif "ID" in col:
            synthetic[col] = [f"{col[:3].upper()}-{i:04d}" for i in range(n)]
        elif "Date" in col:
            start = pd.Timestamp("2024-01-01")
            synthetic[col] = [
                (start + pd.Timedelta(days=int(offset))).strftime("%Y-%m-%d")
                for offset in rng.integers(0, 365, n)
            ]
        elif "Name" in col or "Assignee" in col or "Rep" in col:
            people = ["Alice", "Bob", "Carol", "Dan", "Eve", "Frank", "Grace", "Hank"]
            synthetic[col] = rng.choice(people, n).tolist()
        else:
            synthetic[col] = [f"{col}_val_{i}" for i in range(n)]

    return pd.DataFrame(synthetic)
52
+
53
+
54
+ def _get_categories(domain: str, col: str) -> list[str]:
55
+ """Return realistic category values per domain and column."""
56
+ cat_map = {
57
+ "hr": {
58
+ "Department": ["Sales", "Research & Development", "Human Resources"],
59
+ "JobRole": ["Sales Executive", "Research Scientist", "Manager", "Lab Technician",
60
+ "Manufacturing Director", "Healthcare Representative"],
61
+ "Attrition": ["Yes", "No"],
62
+ "OverTime": ["Yes", "No"],
63
+ },
64
+ "sales": {
65
+ "Stage": ["Prospecting", "Qualification", "Proposal", "Negotiation", "Won", "Lost"],
66
+ "Region": ["East", "West", "Central", "North", "South"],
67
+ "Product": ["GTX Pro", "GTX Basic", "GTX Plus", "MG Special", "MG Advanced"],
68
+ "ForecastCategory": ["Pipeline", "Best Case", "Commit", "Closed"],
69
+ },
70
+ "pm": {
71
+ "Status": ["Not Started", "In Progress", "Completed", "On Hold", "Cancelled"],
72
+ "Priority": ["Critical", "High", "Medium", "Low"],
73
+ "RiskFlag": ["High", "Medium", "Low", "None"],
74
+ },
75
+ "it_ops": {
76
+ "Category": ["Hardware", "Software", "Network", "Access", "Email"],
77
+ "Priority": ["P1-Critical", "P2-High", "P3-Medium", "P4-Low"],
78
+ "Status": ["Open", "In Progress", "Resolved", "Closed", "Pending"],
79
+ "ResolutionType": ["Fix Applied", "Workaround", "No Fix", "Duplicate", "User Error"],
80
+ },
81
+ }
82
+ return cat_map.get(domain, {}).get(col, ["A", "B", "C"])
83
+
84
+
85
def inject_corruption(df: pd.DataFrame, domain_config: DomainConfig,
                      rate: float = 0.15) -> pd.DataFrame:
    """Return a copy of *df* with realistic data-quality defects injected (seed 42).

    Defects applied in order: NaNs and junk strings in numeric columns,
    typos in categoricals, a handful of duplicated rows, and stray
    whitespace in the first two categorical columns. RNG calls are made in
    the same order as the defect steps, so output is reproducible.
    """
    out = df.copy()
    n_rows = len(out)
    rng = np.random.default_rng(42)

    # Step 1: null out a fraction of each numeric column.
    for col in domain_config.numeric_columns:
        if col in out.columns:
            out.loc[rng.random(n_rows) < rate, col] = np.nan

    # Step 2: junk strings in numeric columns (type mismatches).
    for col in domain_config.numeric_columns:
        if col in out.columns:
            n_bad = max(1, int(n_rows * rate * 0.3))
            bad_rows = rng.choice(n_rows, n_bad, replace=False)
            out[col] = out[col].astype(object)
            col_pos = out.columns.get_loc(col)
            for r in bad_rows:
                out.iloc[r, col_pos] = rng.choice(["N/A", "unknown", "#REF!", "TBD", "-"])

    # Step 3: typos in categorical columns.
    for col in domain_config.categorical_columns:
        if col in out.columns:
            n_typos = max(1, int(n_rows * rate * 0.2))
            typo_rows = rng.choice(n_rows, n_typos, replace=False)
            col_pos = out.columns.get_loc(col)
            for r in typo_rows:
                out.iloc[r, col_pos] = _add_typo(str(out.iloc[r, col_pos]), rng)

    # Step 4: append duplicated rows.
    n_dupes = max(1, int(n_rows * rate * 0.1))
    dupe_rows = rng.choice(n_rows, n_dupes, replace=False)
    out = pd.concat([out, out.iloc[dupe_rows].copy()], ignore_index=True)

    # Step 5: leading/trailing whitespace in the first two categorical columns
    # (indices drawn over the post-duplication length).
    for col in domain_config.categorical_columns[:2]:
        if col in out.columns:
            n_ws = max(1, int(n_rows * rate * 0.2))
            ws_rows = rng.choice(len(out), n_ws, replace=False)
            col_pos = out.columns.get_loc(col)
            for r in ws_rows:
                out.iloc[r, col_pos] = f" {str(out.iloc[r, col_pos])} "

    return out
134
+
135
+
136
+ def _add_typo(text: str, rng: np.random.Generator) -> str:
137
+ """Add a realistic typo to a string."""
138
+ if len(text) < 2:
139
+ return text
140
+ typo_type = rng.choice(["swap", "delete", "insert", "case"])
141
+ idx = rng.integers(0, len(text))
142
+ if typo_type == "swap" and idx < len(text) - 1:
143
+ return text[:idx] + text[idx + 1] + text[idx] + text[idx + 2:]
144
+ elif typo_type == "delete":
145
+ return text[:idx] + text[idx + 1:]
146
+ elif typo_type == "insert":
147
+ char = rng.choice(list(string.ascii_lowercase))
148
+ return text[:idx] + char + text[idx:]
149
+ else:
150
+ return text[:idx] + text[idx].swapcase() + text[idx + 1:]
151
+
152
+
153
def compute_dq_score(df: pd.DataFrame, domain_config: DomainConfig) -> dict:
    """Score data quality: completeness, consistency, uniqueness, and a weighted overall.

    Completeness is the non-null cell share over the expected columns;
    consistency is the share of numeric-column values that parse as
    numbers; uniqueness is the non-duplicate row share keyed on the first
    five expected columns. All metrics are rounded to 4 decimals.
    """
    present = [c for c in domain_config.columns if c in df.columns]

    # Completeness: 1 - null cell ratio over expected columns.
    if present:
        null_ratio = df[present].isnull().sum().sum() / (len(df) * len(present))
        completeness = 1.0 - null_ratio
    else:
        completeness = 1.0

    # Consistency: mean fraction of parseable values per numeric column.
    per_col = [
        df[c].apply(_is_numeric).mean()
        for c in domain_config.numeric_columns
        if c in df.columns
    ]
    consistency = float(np.mean(per_col)) if per_col else 1.0

    # Uniqueness: 1 - duplicate row ratio.
    if len(df) > 0:
        dupes = df.duplicated(subset=present[:5], keep='first').sum()
        uniqueness = 1.0 - (dupes / len(df))
    else:
        uniqueness = 1.0

    overall = 0.40 * completeness + 0.35 * consistency + 0.25 * uniqueness
    return {
        "completeness": round(completeness, 4),
        "consistency": round(consistency, 4),
        "uniqueness": round(uniqueness, 4),
        "overall": round(overall, 4),
    }
187
+
188
+
189
+ def _is_numeric(val) -> bool:
190
+ """Check if a value is numeric (or null, which is valid)."""
191
+ if pd.isna(val):
192
+ return True
193
+ try:
194
+ float(val)
195
+ return True
196
+ except (ValueError, TypeError):
197
+ return False
198
+
199
+
200
def compute_dq_score_with_lfs(df: pd.DataFrame, domain: str,
                              lfs: list) -> float:
    """Majority-vote DQ score from Snorkel-style labeling functions.

    Each LF votes per row (-1 abstain, 0 bad, 1 good); a row's score is
    the fraction of GOOD among non-abstain votes (0.5 when nothing voted
    or an LF raised). With no LFs or an empty frame, falls back to the
    heuristic score (0.5 for an unknown domain).
    """
    if not lfs or len(df) == 0:
        cfg = DOMAINS.get(domain)
        return compute_dq_score(df, cfg)["overall"] if cfg else 0.5

    ABSTAIN, GOOD = -1, 1
    per_row = []
    for _, row in df.iterrows():
        ballots = []
        for lf in lfs:
            try:
                verdict = lf(row)
            except Exception:
                # A failing LF simply doesn't vote on this row.
                continue
            if verdict != ABSTAIN:
                ballots.append(verdict)
        per_row.append(sum(b == GOOD for b in ballots) / len(ballots) if ballots else 0.5)

    return float(np.mean(per_row))
227
+
228
+
229
def format_preview(df: pd.DataFrame, n: int = 5) -> str:
    """Render the first *n* rows as plain text (no index, cells capped at 30 chars)."""
    head = df.head(n)
    return head.to_string(index=False, max_colwidth=30)
232
+
233
+
234
def format_columns_info(df: pd.DataFrame, domain_config: DomainConfig) -> str:
    """One line per column: dtype, null count, and whether the schema expects it."""
    rows = [
        f"{col}: {df[col].dtype}, nulls={df[col].isnull().sum()} "
        f"({'expected' if col in domain_config.columns else 'extra'})"
        for col in df.columns
    ]
    return "\n".join(rows)
environments/shared/personas.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generalized personas for domain-independent question answering."""
2
+
3
+ import re
4
+
5
+ from pydantic import BaseModel
6
+
7
+
8
class Persona(BaseModel):
    """A target-audience profile used to tailor questions and score answers."""

    # Display name, e.g. "Executive".
    name: str
    # Short lookup key (see PERSONA_MAP / get_persona), e.g. "executive".
    role: str
    # Topics this persona cares about; surfaced in persona descriptions.
    focus_areas: list[str]
    # Style key consumed by _check_formality in this module.
    language_style: str
    # Terms whose presence in an answer raises the alignment score.
    keywords: list[str]
    # Terms that penalise the alignment score when present.
    anti_keywords: list[str]
15
+
16
+
17
# Built-in personas covering three organisational altitudes: strategic
# (Executive), operational (Manager), and task-level (Individual Contributor).
# keywords/anti_keywords drive score_persona_alignment below.
PERSONAS = [
    Persona(
        name="Executive",
        role="executive",
        focus_areas=["costs", "ROI", "strategic risk", "portfolio trends", "year-over-year"],
        language_style="strategic-financial",
        keywords=["revenue", "cost", "ROI", "risk", "trend", "quarter",
                  "year-over-year", "impact", "budget", "margin", "growth"],
        anti_keywords=["I think", "maybe", "um", "idk"],
    ),
    Persona(
        name="Manager",
        role="manager",
        focus_areas=["team performance", "operational health", "process bottlenecks", "capacity"],
        language_style="operational-actionable",
        keywords=["team", "performance", "bottleneck", "capacity", "SLA",
                  "process", "action", "priority", "escalation", "delivery"],
        anti_keywords=["shareholder", "valuation", "IPO"],
    ),
    Persona(
        name="Individual Contributor",
        role="ic",
        focus_areas=["personal tasks", "deadlines", "what to do next", "simple explanations"],
        language_style="plain-personal",
        keywords=["my", "I should", "next step", "deadline", "help",
                  "understand", "priority", "task", "assigned"],
        anti_keywords=["KPI", "ROI", "portfolio", "strategic", "EBITDA"],
    ),
]

# Index by role key for O(1) lookup in get_persona.
PERSONA_MAP = {p.role: p for p in PERSONAS}
48
+
49
+
50
def get_persona(role: str) -> Persona:
    """Look up a persona by its role key.

    Args:
        role: Role identifier, e.g. "executive", "manager", or "ic".

    Returns:
        The matching Persona.

    Raises:
        KeyError: If *role* is unknown; the message lists valid keys so the
            caller can see what was expected (a bare KeyError gave no context).
    """
    try:
        return PERSONA_MAP[role]
    except KeyError:
        raise KeyError(
            f"Unknown persona role {role!r}; valid roles: {sorted(PERSONA_MAP)}"
        ) from None
53
+
54
+
55
def score_persona_alignment(answer: str, persona: Persona) -> float:
    """Score how well an answer aligns with a persona's communication style.

    Returns a float 0-1 based on:
    - Keyword density (presence of expected terms)
    - Anti-keyword penalty (presence of terms that don't fit)
    - Formality check (matches language style)

    Note: the previous version tokenised the answer with re.findall and
    computed a word count that was never used; that dead code is removed.
    """
    answer_lower = answer.lower()

    # Keyword scoring: saturates once roughly 30% of persona keywords appear.
    keyword_hits = sum(1 for kw in persona.keywords if kw.lower() in answer_lower)
    keyword_score = min(keyword_hits / max(len(persona.keywords) * 0.3, 1), 1.0)

    # Anti-keyword penalty: 0.15 per off-style term, capped at 0.5.
    anti_hits = sum(1 for akw in persona.anti_keywords if akw.lower() in answer_lower)
    anti_penalty = min(anti_hits * 0.15, 0.5)

    # Formality check against the persona's expected language style.
    formality_score = _check_formality(answer, persona.language_style)

    # Combine: 50% keywords, 20% formality, 30% base (minus anti-penalty)
    raw_score = 0.50 * keyword_score + 0.20 * formality_score + 0.30 - anti_penalty
    return round(max(0.0, min(1.0, raw_score)), 4)
81
+
82
+
83
+ def _check_formality(text: str, style: str) -> float:
84
+ """Check if text formality matches the expected language style."""
85
+ text_lower = text.lower()
86
+
87
+ if style == "strategic-financial":
88
+ indicators = ["percent", "%", "million", "billion", "quarter", "fiscal",
89
+ "forecast", "benchmark", "metric"]
90
+ hits = sum(1 for ind in indicators if ind in text_lower)
91
+ return min(hits / 2.0, 1.0)
92
+
93
+ elif style == "operational-actionable":
94
+ indicators = ["action", "recommend", "should", "priority", "next steps",
95
+ "immediate", "plan", "schedule"]
96
+ hits = sum(1 for ind in indicators if ind in text_lower)
97
+ return min(hits / 2.0, 1.0)
98
+
99
+ elif style == "plain-personal":
100
+ # Plain style rewards shorter sentences and simple language
101
+ sentences = text.split(".")
102
+ avg_len = sum(len(s.split()) for s in sentences) / max(len(sentences), 1)
103
+ return 1.0 if avg_len < 20 else max(0.0, 1.0 - (avg_len - 20) / 30)
104
+
105
+ return 0.5
environments/shared/reward_utils.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Reward computation with cached downstream signals."""
2
+
3
# Cached downstream signals mapping quality buckets to historical scores.
# These represent how well downstream stages perform given upstream quality.
# Keys are the bucket names produced by _get_downstream_bucket below.
# NOTE(review): values look hand-calibrated; confirm against the historical
# runs they were derived from before tuning.
DOWNSTREAM_CACHE: dict[str, float] = {
    "excellent": 0.95,  # DQ > 0.90 or coverage > 0.80
    "good": 0.75,  # DQ 0.70-0.90 or coverage 0.50-0.80
    "fair": 0.50,  # DQ 0.50-0.70 or coverage 0.30-0.50
    "poor": 0.20,  # DQ < 0.50 or coverage < 0.30
}
11
+
12
+
13
+ def _get_downstream_bucket(score: float) -> str:
14
+ """Map a score to a downstream quality bucket."""
15
+ if score > 0.90:
16
+ return "excellent"
17
+ elif score > 0.70:
18
+ return "good"
19
+ elif score > 0.50:
20
+ return "fair"
21
+ return "poor"
22
+
23
+
24
+ def cleaning_reward(dq_score: float, downstream_bucket: str = "") -> float:
25
+ """Compute cleaning stage reward.
26
+
27
+ 0.70 * dq_score + 0.30 * downstream_signal
28
+ """
29
+ if not downstream_bucket:
30
+ downstream_bucket = _get_downstream_bucket(dq_score)
31
+ downstream = DOWNSTREAM_CACHE.get(downstream_bucket, 0.5)
32
+ return round(0.70 * dq_score + 0.30 * downstream, 4)
33
+
34
+
35
+ def enrichment_reward(coverage: float, downstream_bucket: str = "") -> float:
36
+ """Compute enrichment stage reward.
37
+
38
+ 0.50 * coverage + 0.50 * downstream_signal
39
+ """
40
+ if not downstream_bucket:
41
+ downstream_bucket = _get_downstream_bucket(coverage)
42
+ downstream = DOWNSTREAM_CACHE.get(downstream_bucket, 0.5)
43
+ return round(0.50 * coverage + 0.50 * downstream, 4)
44
+
45
+
46
+ def answering_reward(faithfulness: float, persona_relevance: float,
47
+ patronus_score: float | None = None) -> float:
48
+ """Compute answering stage reward.
49
+
50
+ Without Patronus: 0.30 * faithfulness + 0.70 * persona_relevance
51
+ With Patronus: 0.40 * patronus_faithfulness + 0.60 * persona_relevance
52
+ """
53
+ if patronus_score is not None:
54
+ return round(0.40 * patronus_score + 0.60 * persona_relevance, 4)
55
+ return round(0.30 * faithfulness + 0.70 * persona_relevance, 4)
models.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Data models for the DataSage Answering Environment."""
2
+
3
+ from openenv.core.env_server.types import Action, Observation
4
+ from pydantic import Field
5
+
6
+
7
class AnsweringAction(Action):
    """Action for the Answering environment - a generated answer."""

    # Free-text answer; scored for persona alignment and faithfulness.
    answer: str = Field(..., description="The generated answer text")
    # Column names the answer cites; checked against the loaded frame's
    # columns when computing faithfulness.
    cited_columns: list[str] = Field(default_factory=list, description="Data columns cited")
    # Optional reasoning trace; not referenced by the reward computation.
    reasoning: str = Field(default="", description="Chain-of-thought reasoning")
13
+
14
+
15
class AnsweringObservation(Observation):
    """Observation from the Answering environment - context for generating an answer."""

    # Active domain key (one of the registered DOMAINS names).
    domain: str = Field(default="")
    dataset_summary: str = Field(default="", description="Statistical summary of enriched data")
    persona: str = Field(default="", description="Executive|Manager|Individual Contributor")
    persona_description: str = Field(default="", description="What this persona cares about")
    question: str = Field(default="", description="The question to answer")
    # Columns of the loaded dataframe; the agent should cite from these.
    available_columns: list[str] = Field(default_factory=list)
    column_stats: str = Field(default="", description="Relevant column statistics")
openenv.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ spec_version: 1
2
+ name: datasage_answering
3
+ type: space
4
+ runtime: fastapi
5
+ app: server.app:app
6
+ port: 8000
pyproject.toml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=45", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "openenv-datasage-answering"
7
+ version = "0.1.0"
8
+ description = "DataSage Answering environment for OpenEnv - persona-aligned enterprise data Q&A"
9
+ requires-python = ">=3.10"
10
+ dependencies = [
11
+ "openenv-core[core]>=0.2.1",
12
+ "pandas>=2.0",
13
+ "numpy>=1.24",
14
+ ]
15
+
16
+ [project.optional-dependencies]
17
+ dev = [
18
+ "pytest>=8.0.0",
19
+ "pytest-cov>=4.0.0",
20
+ ]
21
+ patronus = [
22
+ "patronus>=0.1.0",
23
+ ]
24
+
25
+ [project.scripts]
26
+ server = "datasage_answering.server.app:main"
27
+
28
+ [tool.setuptools]
29
+ include-package-data = true
30
+ packages = ["datasage_answering", "datasage_answering.server"]
31
+ package-dir = { "datasage_answering" = ".", "datasage_answering.server" = "server" }
server/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Answering environment server components."""
2
+
3
+ from .answering_environment import AnsweringEnvironment
4
+
5
+ __all__ = ["AnsweringEnvironment"]
server/answering_environment.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DataSage Answering Environment Implementation.
3
+
4
+ A single-step RL environment where the agent must answer enterprise data
5
+ questions tailored to a specific persona (Executive, Manager, IC) using
6
+ enriched data context across 4 domains (HR, Sales, PM, IT Ops).
7
+ """
8
+
9
+ import os
10
+ import random
11
+ import sys
12
+ from uuid import uuid4
13
+
14
+ import numpy as np
15
+ import pandas as pd
16
+
17
+ # Allow imports from the project root so shared modules are reachable.
18
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..'))
19
+
20
+ from environments.shared.domains import DOMAINS
21
+ from environments.shared.enterprise_data import load_domain_data, format_preview
22
+ from environments.shared.personas import PERSONAS, score_persona_alignment
23
+ from environments.shared.reward_utils import answering_reward
24
+
25
+ from models import AnsweringAction, AnsweringObservation
26
+ from openenv.core.env_server.interfaces import Environment
27
+ from openenv.core.env_server.types import State
28
+
29
+
30
class AnsweringEnvironment(Environment):
    """
    Answering environment: the agent receives a data context, a persona,
    and a question, then must produce a faithful, persona-aligned answer.

    This is a single-step episode: the agent submits one answer and
    receives a terminal reward.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        """Initialise the answering environment.

        All episode state is populated by reset(); until then the frame is
        empty and persona/config are None.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._df: pd.DataFrame = pd.DataFrame()  # enriched data for the episode
        self._domain_name: str = ""  # key into DOMAINS
        self._domain_config = None  # DomainConfig for the sampled domain
        self._persona = None  # Persona sampled from PERSONAS
        self._question: str = ""  # question drawn from the domain's examples

    # ------------------------------------------------------------------
    # reset
    # ------------------------------------------------------------------
    def reset(self) -> AnsweringObservation:
        """Pick a random domain, persona, and question; load enriched data summary.

        Returns a non-terminal observation (done=False, reward=0.0) carrying
        the context the agent needs to answer: dataset summary, persona
        description, question, and column statistics.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)

        # Pick random domain
        self._domain_name = random.choice(list(DOMAINS.keys()))
        self._domain_config = DOMAINS[self._domain_name]

        # Pick random persona
        self._persona = random.choice(PERSONAS)

        # Pick a domain-appropriate question
        self._question = random.choice(self._domain_config.example_questions)

        # Load data (simulate enriched data by loading raw)
        self._df = load_domain_data(self._domain_name, sample_size=100)

        # Compute dataset summary (basic stats for numeric columns)
        # NOTE(review): describe() yields NaN std for <2 rows, which would
        # print as "nan" here — confirm sample_size is always >= 2.
        summary_parts = []
        for col in self._domain_config.numeric_columns:
            if col in self._df.columns:
                stats = self._df[col].describe()
                summary_parts.append(
                    f"{col}: mean={stats['mean']:.1f}, std={stats['std']:.1f}, "
                    f"min={stats['min']:.1f}, max={stats['max']:.1f}"
                )
        dataset_summary = "\n".join(summary_parts) if summary_parts else "No numeric summary available."

        # Compute column stats (first 12 columns); numeric columns get
        # describe() output, others the top-5 value counts.
        col_stats = []
        for col in self._df.columns[:12]:
            if self._df[col].dtype in ['int64', 'float64']:
                col_stats.append(f"{col}: {self._df[col].describe().to_dict()}")
            else:
                col_stats.append(f"{col}: {self._df[col].value_counts().head(5).to_dict()}")
        column_stats_str = "\n".join(col_stats)

        # Persona description (role, focus areas, language style)
        persona_desc = (
            f"Role: {self._persona.role}. "
            f"Focus areas: {', '.join(self._persona.focus_areas)}. "
            f"Language style: {self._persona.language_style}."
        )

        return AnsweringObservation(
            domain=self._domain_name,
            dataset_summary=dataset_summary,
            persona=self._persona.name,
            persona_description=persona_desc,
            question=self._question,
            available_columns=list(self._df.columns),
            column_stats=column_stats_str,
            done=False,
            reward=0.0,
        )

    # ------------------------------------------------------------------
    # step
    # ------------------------------------------------------------------
    def step(self, action: AnsweringAction) -> AnsweringObservation:  # type: ignore[override]
        """Evaluate the answer and return a terminal observation with reward.

        Single-step episode: every call returns done=True. Must be called
        after reset() — persona and data are populated there.
        """
        self._state.step_count += 1

        # Compute faithfulness (cited columns + value references)
        faithfulness = self._compute_faithfulness(action, self._df)

        # Compute persona relevance (keyword/formality alignment)
        persona_relevance = score_persona_alignment(action.answer, self._persona)

        # Optional: Patronus hallucination check — context passed to the judge
        context = (
            f"Domain: {self._domain_name}\n"
            f"Question: {self._question}\n"
            f"Available columns: {list(self._df.columns)}\n"
            f"Data sample:\n{format_preview(self._df)}"
        )
        patronus_score = self._get_patronus_score(action, context)

        # Compute final reward (weights depend on Patronus availability)
        reward = answering_reward(faithfulness, persona_relevance, patronus_score)

        return AnsweringObservation(
            domain=self._domain_name,
            dataset_summary="",
            persona=self._persona.name,
            persona_description="",
            question=self._question,
            available_columns=list(self._df.columns),
            column_stats="",
            done=True,
            reward=reward,
            metadata={
                "faithfulness": faithfulness,
                "persona_relevance": persona_relevance,
                "patronus_score": patronus_score,
                "step": self._state.step_count,
            },
        )

    # ------------------------------------------------------------------
    # state property
    # ------------------------------------------------------------------
    @property
    def state(self) -> State:
        """Return current environment state."""
        return self._state

    # ------------------------------------------------------------------
    # scoring helpers
    # ------------------------------------------------------------------
    def _compute_faithfulness(self, action: AnsweringAction, df: pd.DataFrame) -> float:
        """Score faithfulness based on cited columns and value references.

        Up to 0.5 for the fraction of cited columns that exist in the frame,
        plus 0.15 per cited column (max 3) whose sampled values appear in the
        answer text. Capped at 1.0.
        """
        score = 0.0

        # Check cited columns exist
        valid_cols = [c for c in action.cited_columns if c in df.columns]
        if action.cited_columns:
            score += 0.5 * (len(valid_cols) / len(action.cited_columns))
        else:
            score += 0.1  # small base credit when nothing is cited — effectively a penalty vs. citing valid columns

        # Check answer mentions real values from the first few valid columns
        answer_lower = action.answer.lower()
        for col in valid_cols[:3]:
            sample_vals = df[col].dropna().astype(str).head(10).tolist()
            if any(str(v).lower() in answer_lower for v in sample_vals):
                score += 0.15

        return min(score, 1.0)

    def _get_patronus_score(self, action: AnsweringAction, context: str):
        """Optionally call Patronus Lynx for hallucination checking.

        Best-effort: returns None when PATRONUS_API_KEY is unset or when the
        SDK import/call fails for any reason, so the reward falls back to the
        heuristic faithfulness term.
        """
        api_key = os.environ.get("PATRONUS_API_KEY")
        if not api_key:
            return None
        try:
            from patronus import Client
            client = Client(api_key=api_key)
            result = client.evaluate(
                evaluator="lynx-small",
                criteria="patronus:hallucination",
                evaluated_model_output=action.answer,
                task_context=context,
            )
            return float(result.score)
        except Exception:
            # Deliberate swallow: judging is optional and must not fail steps.
            return None
server/app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI application for the DataSage Answering Environment.
3
+
4
+ This module creates an HTTP server that exposes the AnsweringEnvironment
5
+ over HTTP and WebSocket endpoints, compatible with EnvClient.
6
+
7
+ Endpoints:
8
+ - POST /reset: Reset the environment
9
+ - POST /step: Execute an action
10
+ - GET /state: Get current environment state
11
+ - GET /schema: Get action/observation schemas
12
+ - WS /ws: WebSocket endpoint for persistent sessions
13
+
14
+ Usage:
15
+ # Development (with auto-reload):
16
+ uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
17
+
18
+ # Production:
19
+ uvicorn server.app:app --host 0.0.0.0 --port 8000 --workers 4
20
+
21
+ # Or run directly:
22
+ python -m server.app
23
+ """
24
+
25
# Guarded import: give an actionable message when openenv is missing.
try:
    from openenv.core.env_server.http_server import create_app
except Exception as e:  # pragma: no cover
    # The original message embedded literal newlines inside the quoted
    # command ('\n uv sync\n'), rendering badly; keep it on one line.
    raise ImportError(
        "openenv is required for the web interface. "
        "Install dependencies with 'uv sync'."
    ) from e

# Import from local models.py (PYTHONPATH includes /app/env in Docker)
from models import AnsweringAction, AnsweringObservation

from .answering_environment import AnsweringEnvironment


# Create the app with web interface (reset/step/state/schema + /ws endpoints)
app = create_app(
    AnsweringEnvironment,
    AnsweringAction,
    AnsweringObservation,
    env_name="datasage_answering",
    max_concurrent_envs=4,
)
46
+
47
+
48
+ def main(host: str = "0.0.0.0", port: int = 8000):
49
+ """
50
+ Entry point for direct execution via uv run or python -m.
51
+
52
+ Args:
53
+ host: Host address to bind to (default: "0.0.0.0")
54
+ port: Port number to listen on (default: 8000)
55
+ """
56
+ import uvicorn
57
+
58
+ uvicorn.run(app, host=host, port=port)
59
+
60
+
61
+ if __name__ == "__main__":
62
+ import argparse
63
+
64
+ parser = argparse.ArgumentParser()
65
+ parser.add_argument("--port", type=int, default=8000)
66
+ args = parser.parse_args()
67
+ main(port=args.port)
server/requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ openenv-core[core]>=0.2.1
2
+ fastapi>=0.115.0
3
+ uvicorn>=0.24.0
4
+ pandas>=2.0
5
+ numpy>=1.24
uv.lock ADDED
The diff for this file is too large to render. See raw diff