File size: 4,407 Bytes
8097081
 
 
 
 
 
 
 
 
7c54da3
8097081
7c54da3
8097081
7c54da3
8097081
 
 
 
 
 
7c54da3
8097081
 
 
7c54da3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8097081
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c54da3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8097081
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# inference.py
import asyncio
import json
import os
from typing import List

from openai import OpenAI
import httpx

# Runtime configuration, read once from the process environment at import time.

# OpenAI-compatible endpoint and the model name sent with each completion request.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")

# Credential lookup order: HF_TOKEN, then API_KEY (each skipped when unset or
# empty), then OPENAI_API_KEY, falling back to "dummy" when that one is unset.
API_KEY = (
    os.getenv("HF_TOKEN")
    or os.getenv("API_KEY")
    or os.getenv("OPENAI_API_KEY", "dummy")
)

# Base URL of the environment server exposing the /reset and /step endpoints.
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")

# Comma-separated task ids to run, in order.
TASKS = os.getenv("TASKS", "easy,medium,hard")

# Upper bound on the number of model/environment steps per task.
MAX_STEPS = int(os.getenv("MAX_STEPS", "5"))

# Minimum clamped score for a task to be reported as a success.
SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.7"))

MAX_TOTAL_REWARD = float(os.getenv("MAX_TOTAL_REWARD", "1.0"))


def log_start(task, env, model):
    """Print the ``[START]`` record for a task run and flush stdout.

    Args:
        task: Identifier of the task being started.
        env: Name of the environment the task runs in.
        model: Name of the model driving the run.
    """
    record = "[START] task={} env={} model={}".format(task, env, model)
    print(record, flush=True)


def log_step(step, action, reward, done, error):
    """Print one single-line ``[STEP]`` record and flush stdout.

    Args:
        step: 1-based index of the step within the task.
        action: Action text to record (typically the raw model reply).
        reward: Numeric reward for the step; printed with two decimals.
        done: Whether the episode ended on this step.
        error: Error description, or ``None`` when the step succeeded
            (printed as ``null``).
    """
    # Model replies and exception messages are frequently multi-line; collapse
    # every whitespace run (including newlines) to one space so that each
    # record stays on exactly one line for line-oriented log consumers.
    action_str = " ".join(str(action).split())
    err = "null" if error is None else " ".join(str(error).split())
    done_str = "true" if done else "false"
    print(
        f"[STEP] step={step} action={action_str} reward={reward:.2f} done={done_str} error={err}",
        flush=True,
    )


def log_end(success, steps, rewards):
    """Print the ``[END]`` record summarising a task run and flush stdout.

    Args:
        success: Truthy when the task is considered solved.
        steps: Number of steps that were taken.
        rewards: Iterable of per-step rewards; each is printed with two
            decimals and the values are joined with commas.
    """
    outcome = "true" if success else "false"
    reward_fields = [format(value, ".2f") for value in rewards]
    record = "[END] success={} steps={} rewards={}".format(
        outcome, steps, ",".join(reward_fields)
    )
    print(record, flush=True)


def get_model_message(client: OpenAI, observation: dict, history: List[str]) -> str:
    """Ask the chat model for the next debugging action.

    Builds a single user prompt containing the expected JSON schema, the
    allowed action and bug types, the JSON-serialised observation and the
    step history, then sends it to ``MODEL_NAME`` through ``client``.

    Args:
        client: OpenAI-compatible client used to issue the request.
        observation: Current environment observation; must be JSON-serialisable.
        history: Summaries of earlier steps, embedded in the prompt as-is.

    Returns:
        The first choice's message content with surrounding whitespace
        stripped, or an empty string when the content is empty/``None``.
    """
    # Cap the serialised observation at 8000 characters to bound prompt size.
    observation_json = json.dumps(observation)[:8000]

    user_prompt = f"""
You are debugging a PyTorch training job. Respond ONLY with valid JSON matching this exact schema:
{{
  "current_hypothesis": {{"bug_type": "<string>", "affected_file": "<string>", "confidence": <0.0-1.0>}},
  "investigation_action": {{"action": "reveal_file", "target": "<filename>"}},
  "commit_diagnosis": false,
  "final_diagnosis": null
}}

Valid action types: reveal_file, extend_loss_curve, extend_gpu_profile, reveal_log_chunk, run_diagnostic
Valid bug types: missing_zero_grad, data_leakage, memory_leak, learning_rate_too_high, gradient_explosion

Observation:
{observation_json}
History: {history}
"""
    chat_messages = [{"role": "user", "content": user_prompt}]
    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=chat_messages,
        temperature=0,
        max_tokens=500,
    )
    reply = response.choices[0].message.content
    if not reply:
        reply = ""
    return reply.strip()


def _parse_action(text):
    """Decode the model's reply into a JSON value.

    The reply is parsed as-is first, so any reply that is already valid JSON
    is handled exactly as before. If that fails, the span from the first
    ``{`` to the last ``}`` is parsed instead, which recovers replies wrapped
    in Markdown code fences or surrounded by explanatory prose.

    Raises:
        json.JSONDecodeError: If neither attempt yields valid JSON.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end < start:
            raise
        return json.loads(text[start:end + 1])


async def main():
    """Run every configured task against the environment server.

    For each task id in ``TASKS`` the environment is reset via ``/reset``,
    then for up to ``MAX_STEPS`` steps the model is asked for an action which
    is posted to ``/step``. A ``[START]`` record, one ``[STEP]`` record per
    step and an ``[END]`` record are printed per task. Success is decided by
    clamping the last step's reward to ``[0, 1]`` and comparing it with
    ``SUCCESS_SCORE_THRESHOLD``.

    A failure inside a step (model call, reply parsing, HTTP error, malformed
    response) is recorded on that step with reward 0 and ends the episode.
    A failure while resetting the environment still propagates to the caller,
    but only after the task's ``[END]`` record has been printed.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    tasks = [task.strip() for task in TASKS.split(",") if task.strip()]

    for task in tasks:
        rewards = []
        history = []
        steps_taken = 0

        log_start(task=task, env="pytorch-debug-env", model=MODEL_NAME)

        try:
            async with httpx.AsyncClient(timeout=60.0) as session:
                reset_resp = await session.post(f"{ENV_URL}/reset", params={"task_id": task})
                reset_resp.raise_for_status()
                result = reset_resp.json()
                session_id = result.get("session_id")
                observation = result["observation"]

                for step in range(1, MAX_STEPS + 1):
                    if result.get("done"):
                        break

                    # Placeholder logged as the action when the model call
                    # itself fails and no reply text exists.
                    action_text = "null"
                    try:
                        # The model call lives inside the try so that an API
                        # failure is reported as a failed step instead of
                        # aborting the whole run.
                        action_text = get_model_message(client, observation, history)
                        action_json = _parse_action(action_text)
                        step_resp = await session.post(
                            f"{ENV_URL}/step",
                            params={"session_id": session_id},
                            json=action_json,
                        )
                        step_resp.raise_for_status()
                        result = step_resp.json()
                        # The server may send an explicit null reward; treat
                        # it as 0.0 so numeric formatting below cannot fail.
                        reward = float(result.get("reward") or 0.0)
                        done = result.get("done", False)
                        error = None
                        observation = result["observation"]
                    except Exception as exc:
                        reward = 0.0
                        done = True
                        error = str(exc)

                    rewards.append(reward)
                    steps_taken = step
                    log_step(step=step, action=action_text, reward=reward, done=done, error=error)
                    history.append(f"step={step} reward={reward:.3f}")

                    if done:
                        break
        finally:
            # Always close the task's log with an [END] record, even when the
            # reset request or the connection failed before any step ran.
            score = min(max(rewards[-1] if rewards else 0.0, 0.0), 1.0)
            success = score >= SUCCESS_SCORE_THRESHOLD
            log_end(success=success, steps=steps_taken, rewards=rewards)


# Script entry point: start the asyncio event loop and run main() only when
# this file is executed directly, not when it is imported.
if __name__ == "__main__":
    asyncio.run(main())