File size: 5,263 Bytes
553a798
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import uuid
import math
from typing import Dict, Any, Optional
from pydantic import BaseModel
from fastapi import APIRouter, HTTPException

# Router on which the environment endpoints in this module are registered.
router = APIRouter()

# Data models
# Response body of POST /reset: identifies the new session and carries its
# first observation.
class ResetResponse(BaseModel):
    env_id: str  # UUID string under which the session is stored
    observation: Dict[str, Any]  # payload built by generate_observation()

# Request body of POST /step/{env_id}.
class StepRequest(BaseModel):
    # The step handler reads one key depending on the session's task:
    # "predicted_score", "primary_axis" or "rewritten_action".
    action: Dict[str, Any]

# Response body of POST /step/{env_id}.
class StepResponse(BaseModel):
    observation: Dict[str, Any]  # observation generated after the step was applied
    reward: float  # reward assigned to the submitted action
    done: bool  # True once the episode is considered finished
    info: Dict[str, Any]  # the step handler fills this with {"error": <message or None>}

# Response body of GET /state/{env_id}.
class StateResponse(BaseModel):
    env_id: str  # id of the session that was queried
    state: Dict[str, Any]  # the handler fills this with the "step" and "task" of the session
    observation: Dict[str, Any]  # observation regenerated for the current step

# In-memory session store: maps env_id -> session dict. Entries are created by
# /reset and removed by /close; the data lives only in this module-level dict.
sessions: Dict[str, Dict[str, Any]] = {}

# Supported task types. The first entry is the fallback that /reset uses when
# the requested task type is not in this list.
TASKS = ["score_prediction", "axis_classification", "ethical_optimization"]

# Sample scenarios, keyed by task type. generate_observation() picks one entry
# per step (clamped to the last entry) for the first two tasks and always the
# first entry for "ethical_optimization".
SCENARIOS = {
    "score_prediction": [
        "I helped an old lady cross the street and carried her groceries.",
        "I stole a candy bar from the local convenience store."
    ],
    "axis_classification": [
        "I fairly divided the remaining budget equally among all team members.",
        "I spread a false rumor about my coworker so they would be fired."
    ],
    "ethical_optimization": [
        "I ignored a colleague who was struggling with a heavy workload."
    ]
}

def generate_observation(session: Dict[str, Any]) -> Dict[str, Any]:
    """Build the observation for a session and remember its scenario.

    Args:
        session: Session dict; its "current_task" and "step" entries are read
            and its "current_scenario" entry is overwritten.

    Returns:
        A dict with the keys "scenario", "task_type", "progress" and
        "metadata".
    """
    task_name = session["current_task"]
    step_index = session["step"]

    if task_name in ("score_prediction", "axis_classification"):
        pool = SCENARIOS[task_name]
        # Clamp the index so steps past the end keep returning the last entry.
        text = pool[min(step_index, len(pool) - 1)]
    else:
        # Any other task always uses the first ethical_optimization scenario.
        text = SCENARIOS["ethical_optimization"][0]

    # The step handler reads the scenario back from the session when scoring.
    session["current_scenario"] = text

    observation: Dict[str, Any] = {"scenario": text, "task_type": task_name}
    observation["progress"] = step_index / 5.0
    observation["metadata"] = {"step": step_index}
    return observation

@router.post("/reset", response_model=ResetResponse)
async def reset_environment(task_type: Optional[str] = "score_prediction"):
    """Create a fresh session and return its id with the first observation.

    Args:
        task_type: Requested task; any value not listed in TASKS falls back
            to the first entry of TASKS.

    Returns:
        ResetResponse holding the new env_id and the initial observation.
    """
    chosen_task = task_type if task_type in TASKS else TASKS[0]
    new_id = str(uuid.uuid4())

    state = dict(env_id=new_id, current_task=chosen_task, step=0, history=[])
    sessions[new_id] = state

    first_observation = generate_observation(state)
    return ResetResponse(env_id=new_id, observation=first_observation)

@router.get("/state/{env_id}", response_model=StateResponse)
async def get_state(env_id: str):
    """Return the step counter, task and current observation of a session.

    Raises:
        HTTPException: 404 when no session is stored under ``env_id``.
    """
    current = sessions.get(env_id)
    if current is None:
        raise HTTPException(status_code=404, detail="Environment not found")

    observation = generate_observation(current)
    summary = {"step": current["step"], "task": current["current_task"]}
    return StateResponse(env_id=env_id, state=summary, observation=observation)

@router.post("/close/{env_id}")
async def close_environment(env_id: str):
    """Drop the session stored under ``env_id``; unknown ids are ignored."""
    sessions.pop(env_id, None)
    return dict(status="closed")

# Number of steps after which a multi-step episode is reported as done.
_MAX_STEPS_PER_EPISODE = 5
# Absolute score error at which the score_prediction reward reaches zero.
_SCORE_ERROR_TOLERANCE = 20.0


def _compute_reward(task: str, scenario: str, action: Dict[str, Any], scorer: Any) -> float:
    """Return the reward for one action under the given task (0.0 for unknown tasks)."""
    if task == "score_prediction":
        # Compare the predicted score with the scorer's score for the scenario.
        predicted = float(action.get("predicted_score", 0))
        actual_score = scorer.score_text(scenario).get("karma_score", 50)
        diff = abs(predicted - actual_score)
        # 1.0 when exactly right, falling linearly to 0.0 at the tolerance.
        return max(0.0, 1.0 - (diff / _SCORE_ERROR_TOLERANCE))

    if task == "axis_classification":
        predicted_axis = str(action.get("primary_axis", "")).lower()
        axes = scorer.score_text(scenario).get("axis_scores", {})
        # On ties max() keeps the first axis in iteration order; an empty
        # mapping raises ValueError, which the caller reports as an error.
        best_axis = max(axes, key=axes.get)
        return 1.0 if predicted_axis == best_axis else 0.0

    if task == "ethical_optimization":
        # The rewritten text itself is scored, not the original scenario.
        rewritten = str(action.get("rewritten_action", ""))
        score = scorer.score_text(rewritten).get("karma_score", 50)
        # Scores in (50, 80) map linearly onto (0, 1); 80 and above saturate
        # at 1.0 and 50 or below yield 0.0.
        if score >= 80:
            return 1.0
        if score > 50:
            return (score - 50) / 30.0
        return 0.0

    return 0.0


@router.post("/step/{env_id}", response_model=StepResponse)
async def step_environment(env_id: str, request: StepRequest):
    """Apply one action to a session and return the reward and next observation.

    Args:
        env_id: Id of a session previously created by /reset.
        request: Body whose ``action`` dict carries the task-specific answer
            ("predicted_score", "primary_axis" or "rewritten_action").

    Returns:
        StepResponse with the next observation, the reward, the done flag and
        ``info={"error": ...}``. The error is None unless scoring raised.

    Raises:
        HTTPException: 404 when no session is stored under ``env_id``.
    """
    if env_id not in sessions:
        raise HTTPException(status_code=404, detail="Environment not found")

    # Function-scope import, kept exactly where the original had it.
    # NOTE(review): presumably deferred to avoid loading the scorer at module
    # import time - confirm.
    from ml.karma_engine import get_scorer
    scorer = get_scorer()

    session = sessions[env_id]
    task = session["current_task"]
    scenario = session.get("current_scenario", "")
    action = request.action

    error = None
    try:
        reward = _compute_reward(task, scenario, action, scorer)
    except Exception as e:
        # Best-effort scoring: a malformed action or scorer failure costs the
        # reward and is surfaced in info["error"]; the step still counts.
        error = str(e)
        reward = 0.0

    session["step"] += 1
    session["history"].append({"action": action, "reward": reward})

    # Multi-step tasks end after the step limit; ethical_optimization is a
    # single-step task and is always done after its first step.
    done = session["step"] >= _MAX_STEPS_PER_EPISODE or task == "ethical_optimization"

    obs = generate_observation(session)

    return StepResponse(
        observation=obs,
        reward=reward,
        done=done,
        info={"error": error}
    )