File size: 5,263 Bytes
553a798 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 | import uuid
import math
from typing import Dict, Any, Optional
from pydantic import BaseModel
from fastapi import APIRouter, HTTPException
router = APIRouter()
# Data models
class ResetResponse(BaseModel):
    """Response body of ``POST /reset``: a new environment and its first observation."""

    # Identifier of the session that was just created (a UUID4 rendered as a string).
    env_id: str
    # Initial observation dict, as produced by generate_observation().
    observation: Dict[str, Any]
class StepRequest(BaseModel):
    """Request body of ``POST /step/{env_id}``."""

    # Free-form action payload. The step handler reads one key depending on the
    # session's task: "predicted_score", "primary_axis" or "rewritten_action".
    action: Dict[str, Any]
class StepResponse(BaseModel):
    """Response body of ``POST /step/{env_id}``: the outcome of applying one action."""

    # Observation generated after the step counter has been advanced.
    observation: Dict[str, Any]
    # Reward earned by the submitted action.
    reward: float
    # True once the episode is considered finished.
    done: bool
    # Extra diagnostics; the step handler fills in an "error" entry (None on success).
    info: Dict[str, Any]
class StateResponse(BaseModel):
    """Response body of ``GET /state/{env_id}``: a snapshot of one session."""

    # Identifier of the session being described.
    env_id: str
    # Summary of the session; the handler supplies the "step" and "task" entries.
    state: Dict[str, Any]
    # Observation for the session's current step.
    observation: Dict[str, Any]
# In-memory session store: maps env_id -> that environment's session dict.
# Entries are added by reset_environment() and removed by close_environment().
# It is a plain module-level dict, so its contents live only in this process.
sessions: Dict[str, Dict[str, Any]] = {}
# Task types accepted by /reset. The first entry doubles as the fallback that
# reset_environment() uses when the requested task_type is not listed here.
TASKS = ["score_prediction", "axis_classification", "ethical_optimization"]
# Sample scenarios, keyed by task type. Each value is the ordered list of
# scenario texts that generate_observation() serves for that task.
SCENARIOS = {
    "score_prediction": [
        "I helped an old lady cross the street and carried her groceries.",
        "I stole a candy bar from the local convenience store.",
    ],
    "axis_classification": [
        "I fairly divided the remaining budget equally among all team members.",
        "I spread a false rumor about my coworker so they would be fired.",
    ],
    "ethical_optimization": [
        "I ignored a colleague who was struggling with a heavy workload.",
    ],
}
def generate_observation(session: Dict[str, Any]) -> Dict[str, Any]:
    """Build the observation dict for a session's current step.

    Reads ``session["current_task"]`` and ``session["step"]`` to pick a scenario
    text from SCENARIOS.

    Side effect: the chosen text is written to ``session["current_scenario"]``.

    Returns:
        A dict with the keys "scenario", "task_type", "progress" (step / 5.0)
        and "metadata" (containing the raw step counter).
    """
    task_name = session["current_task"]
    step_index = session["step"]

    if task_name in ("score_prediction", "axis_classification"):
        pool = SCENARIOS[task_name]
        # Clamp the index so steps beyond the list keep serving the last scenario.
        text = pool[min(step_index, len(pool) - 1)]
    else:
        # Every other task value is served the first ethical_optimization scenario.
        text = SCENARIOS["ethical_optimization"][0]

    session["current_scenario"] = text

    observation = {
        "scenario": text,
        "task_type": task_name,
        "progress": step_index / 5.0,
        "metadata": {"step": step_index},
    }
    return observation
@router.post("/reset", response_model=ResetResponse)
async def reset_environment(task_type: Optional[str] = "score_prediction"):
    """Create a new environment session and return its id with the first observation.

    Args:
        task_type: Requested task. Any value that is not listed in TASKS
            (including None) is replaced by ``TASKS[0]`` without an error.

    Returns:
        A ResetResponse carrying the new env_id and the initial observation.
    """
    chosen_task = task_type if task_type in TASKS else TASKS[0]
    new_id = str(uuid.uuid4())

    state = {
        "env_id": new_id,
        "current_task": chosen_task,
        "step": 0,
        "history": [],
    }
    # Register the session before building the observation, which mutates it.
    sessions[new_id] = state

    return ResetResponse(env_id=new_id, observation=generate_observation(state))
@router.get("/state/{env_id}", response_model=StateResponse)
async def get_state(env_id: str):
    """Return the step counter, task and current observation of one session.

    Raises:
        HTTPException: 404 when no session is stored under ``env_id``.
    """
    session = sessions.get(env_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Environment not found")

    # Building the observation also refreshes session["current_scenario"].
    observation = generate_observation(session)
    summary = {"step": session["step"], "task": session["current_task"]}
    return StateResponse(env_id=env_id, state=summary, observation=observation)
@router.post("/close/{env_id}")
async def close_environment(env_id: str):
    """Discard the session stored under ``env_id``, if there is one.

    Unknown ids are ignored, so the response is the same either way.
    """
    sessions.pop(env_id, None)
    return {"status": "closed"}
def _score_prediction_reward(scorer: Any, scenario: str, action: Dict[str, Any]) -> float:
    """Reward a karma-score guess: 1.0 for an exact hit, falling linearly to 0.0 at a 20-point miss."""
    predicted = float(action.get("predicted_score", 0))
    actual_score = scorer.score_text(scenario).get("karma_score", 50)
    diff = abs(predicted - actual_score)
    return max(0.0, 1.0 - diff / 20.0)


def _axis_classification_reward(scorer: Any, scenario: str, action: Dict[str, Any]) -> float:
    """Reward 1.0 when the guessed axis names the scorer's highest-scoring axis, else 0.0."""
    predicted_axis = str(action.get("primary_axis", "")).lower()
    axes = scorer.score_text(scenario).get("axis_scores", {})
    if not axes:
        # Fail with a readable message instead of max()'s "empty sequence" error.
        raise ValueError("scorer returned no axis scores")
    # On a tie, max() keeps the first axis in the dict's iteration order.
    best_axis = max(axes, key=axes.get)
    # The guess is lower-cased above, so normalise the scorer's key the same way;
    # otherwise a capitalised axis name could never be matched.
    return 1.0 if predicted_axis == str(best_axis).lower() else 0.0


def _ethical_optimization_reward(scorer: Any, action: Dict[str, Any]) -> float:
    """Reward a rewritten action by its karma score: 0.0 at <= 50, linear up to 1.0 at >= 80."""
    rewritten = str(action.get("rewritten_action", ""))
    score = scorer.score_text(rewritten).get("karma_score", 50)
    if score >= 80:
        return 1.0
    if score > 50:
        return (score - 50) / 30.0
    return 0.0


@router.post("/step/{env_id}", response_model=StepResponse)
async def step_environment(env_id: str, request: StepRequest):
    """Apply one action to a session, score it, and advance the episode.

    The reward rule depends on the session's task:

    * ``score_prediction``: compares ``action["predicted_score"]`` with the
      scorer's ``karma_score`` for the current scenario.
    * ``axis_classification``: compares ``action["primary_axis"]``
      (case-insensitively) with the highest entry of the scorer's ``axis_scores``.
    * ``ethical_optimization``: scores ``action["rewritten_action"]`` itself.

    Any exception raised while computing the reward is reported through
    ``info["error"]`` with a reward of 0.0 rather than failing the request.

    Args:
        env_id: Identifier returned by ``/reset``.
        request: Body carrying the ``action`` dict.

    Returns:
        A StepResponse with the next observation, the reward, the ``done`` flag
        (true after 5 steps, or immediately for ``ethical_optimization``) and
        ``info`` holding the error message, if any.

    Raises:
        HTTPException: 404 when no session is stored under ``env_id``.
    """
    if env_id not in sessions:
        raise HTTPException(status_code=404, detail="Environment not found")

    # Imported inside the handler rather than at module top.
    # NOTE(review): presumably to defer loading the scorer until first use — confirm.
    from ml.karma_engine import get_scorer

    scorer = get_scorer()
    session = sessions[env_id]
    task = session["current_task"]
    scenario = session.get("current_scenario", "")
    action = request.action

    reward = 0.0
    error = None
    try:
        if task == "score_prediction":
            reward = _score_prediction_reward(scorer, scenario, action)
        elif task == "axis_classification":
            reward = _axis_classification_reward(scorer, scenario, action)
        elif task == "ethical_optimization":
            reward = _ethical_optimization_reward(scorer, action)
    except Exception as e:
        # Deliberate best-effort boundary: malformed actions or scorer failures
        # become a zero reward plus an error message instead of an HTTP 500.
        error = str(e)
        reward = 0.0

    session["step"] += 1
    session["history"].append({"action": action, "reward": reward})

    # Episodes last 5 steps; ethical_optimization finishes after a single step.
    # NOTE(review): nothing rejects further steps once done is true — confirm
    # whether callers rely on that before adding a guard.
    done = session["step"] >= 5 or task == "ethical_optimization"

    obs = generate_observation(session)
    return StepResponse(
        observation=obs,
        reward=reward,
        done=done,
        info={"error": error},
    )
|