File size: 9,120 Bytes
30bf68a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
from __future__ import annotations

import argparse
import asyncio
import os
from typing import Any

from openai import OpenAI

from env.environment import CICDDebuggerEnvironment, REQUIRED_TOOLS
from inference.metrics import EpisodeMetrics
from inference.model_wrapper import ModelWrapper, score_action_candidate
from inference.prompts import heuristic_action
from inference.visualize import save_metrics_json, save_reward_curve, save_success_rate_history


# --- API / model configuration (all overridable via environment variables) ---
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# Prefer an explicit OpenAI key; fall back to the Hugging Face token.
API_KEY = OPENAI_API_KEY or HF_TOKEN
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
DEFAULT_TASK_ID = os.getenv("MY_ENV_V4_TASK", "easy-command-typo")
DEFAULT_BENCHMARK = os.getenv("MY_ENV_V4_BENCHMARK", "cicd_debugger_env")

# --- Episode / sampling configuration ---
MAX_STEPS_DEFAULT = int(os.getenv("MAX_STEPS", "8"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.2"))
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "120"))
# "1" forces heuristic-only inference with no API client.
OFFLINE_INFERENCE = os.getenv("OFFLINE_INFERENCE", "0") == "1"
# Minimum clamped average reward for an episode to count as a success.
SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.1"))


def log_start(task: str, env_name: str, model: str) -> None:
    """Emit the structured [START] marker line for an episode."""
    fields = {"task": task, "env": env_name, "model": model}
    rendered = " ".join(f"{key}={_single_line(value)}" for key, value in fields.items())
    print(f"[START] {rendered}", flush=True)


def log_step(step: int, action: str, reward: float, done: bool, error: str | None) -> None:
    """Emit one structured [STEP] line describing a single environment step."""
    parts = [
        f"step={step}",
        f"action={_single_line(action)}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        # "null" keeps the field machine-parseable when there is no error.
        f"error={_single_line(error) if error else 'null'}",
    ]
    print("[STEP] " + " ".join(parts), flush=True)


def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
    """Emit the final structured [END] summary line for the episode."""
    formatted_rewards = ",".join(format(reward, ".2f") for reward in rewards)
    success_text = "true" if success else "false"
    print(
        f"[END] success={success_text} steps={steps} score={score:.3f} rewards={formatted_rewards}",
        flush=True,
    )


def _single_line(value: Any) -> str:
    return " ".join(str(value).replace("\n", " ").replace("\r", " ").split())


def _is_hacking_action(action_text: str) -> bool:
    value = (action_text or "").lower()
    patterns = (
        "if: false",
        "when: never",
        "echo \"tests passed\"",
        "echo 'tests passed'",
        "exit 0",
        "force success",
        "status: success",
    )
    return any(token in value for token in patterns)


def _extract_error(info: dict[str, Any] | None) -> str | None:
    if not info:
        return None
    error = info.get("error")
    return str(error) if error else None


def _extract_observation_fields(observation: dict[str, Any]) -> tuple[str, str, list[str]]:
    config_text = str(observation.get("config") or "")
    error_message = str(observation.get("error_message") or "")
    tools = [str(item) for item in (observation.get("available_tools") or REQUIRED_TOOLS)]
    return config_text, error_message, tools


def _tool_from_action(action_text: str) -> str:
    return str(action_text or "").split(":", 1)[0].strip().lower()


def _is_action_allowed(action_text: str, available_tools: list[str]) -> bool:
    return _tool_from_action(action_text) in {tool.lower() for tool in available_tools}


def _normalize_action(action_text: str, available_tools: list[str], fallback: str) -> str:
    """Canonicalize a model-proposed action; return *fallback* when invalid.

    Empty actions and actions whose tool is not available resolve to the
    fallback. Known shorthand tool names are rewritten to their canonical
    forms while preserving any argument text after the ':'.
    """
    candidate = str(action_text or "").strip()
    if not candidate:
        return fallback

    # Shorthand names the model tends to emit, mapped onto real tool names.
    alias_map = {
        "run_stage": "run_pipeline_stage",
        "validate": "validate_fix",
        "submit": "submit_solution",
        "submit_fix": "submit_solution",
    }
    original_tool = _tool_from_action(candidate)
    canonical_tool = alias_map.get(original_tool, original_tool)
    if canonical_tool != original_tool:
        # Keep the argument portion (after ':') when rewriting the tool name.
        payload = candidate.split(":", 1)[1].strip() if ":" in candidate else ""
        candidate = f"{canonical_tool}: {payload}" if payload else canonical_tool

    return candidate if _is_action_allowed(candidate, available_tools) else fallback


def _select_action(
    model_wrapper: ModelWrapper,
    step: int,
    config_text: str,
    error_message: str,
    history: list[str],
    available_actions: list[str],
    policy_mode: str,
    trajectories: int,
) -> str:
    """Pick the next action according to the configured policy mode.

    Modes:
      * ``"sft"``    -- heuristic policy only, no model call.
      * ``"direct"`` -- a single model generation.
      * anything else (``"imp"``) -- best-of-N candidates scored locally.

    The heuristic action always serves as the fallback when the model output
    is empty, disallowed, or no candidates are produced.
    """
    chosen_mode = (policy_mode or "imp").lower()
    fallback = heuristic_action(config_text, error_message, available_actions, history)

    if chosen_mode == "sft":
        return _normalize_action(fallback, available_actions, fallback)

    if chosen_mode == "direct":
        direct_action = model_wrapper.generate_action(
            step=step,
            config_text=config_text,
            error_message=error_message,
            history=history,
            available_actions=available_actions,
        )
        return _normalize_action(direct_action, available_actions, fallback)

    # Implicit best-of-N ("imp") mode: sample several candidates and keep
    # the one scoring highest; fall back when sampling yields nothing.
    candidates = model_wrapper.generate_candidates(
        step=step,
        config_text=config_text,
        error_message=error_message,
        history=history,
        count=max(1, int(trajectories)),
        available_actions=available_actions,
    )
    if not candidates:
        return _normalize_action(fallback, available_actions, fallback)

    observation_text = f"{config_text}\n{error_message}"

    def _score(candidate: str) -> float:
        # Scoring also penalizes reward-hacking patterns via _is_hacking_action.
        return score_action_candidate(observation_text, candidate, _is_hacking_action)

    best_candidate = max(candidates, key=_score)
    return _normalize_action(best_candidate, available_actions, fallback)


def parse_args() -> argparse.Namespace:
    """Define and parse the command-line options for the inference loop."""
    cli = argparse.ArgumentParser(description="Run CI/CD debugger inference loop")
    # Episode shape and task selection.
    cli.add_argument("--max-steps", type=int, default=MAX_STEPS_DEFAULT)
    cli.add_argument("--task", default=DEFAULT_TASK_ID)
    cli.add_argument("--benchmark", default=DEFAULT_BENCHMARK)
    cli.add_argument("--difficulty", choices=["easy", "medium", "hard"], default=None)
    # Execution-environment toggles.
    cli.add_argument("--offline", action="store_true", default=OFFLINE_INFERENCE)
    cli.add_argument("--force-local-env", action="store_true", default=False)
    # Policy configuration.
    cli.add_argument("--policy-mode", choices=["sft", "imp", "direct"], default="imp")
    cli.add_argument("--trajectories", type=int, default=3)
    return cli.parse_args()


async def run_episode(args: argparse.Namespace) -> int:
    """Run one full inference episode against the CI/CD debugger environment.

    Builds the environment and model wrapper, loops up to ``args.max_steps``
    selecting and executing actions, then saves metrics artifacts and emits
    structured [START]/[STEP]/[END] log lines. Always returns 0; failure is
    reflected in the logged ``success`` flag rather than the exit code.
    """
    history: list[str] = []
    steps_taken = 0
    success = False
    episode_completed_cleanly = False
    metrics = EpisodeMetrics()

    env = CICDDebuggerEnvironment(max_steps=max(1, int(args.max_steps)))

    # Run offline (heuristic-only) when requested or when no API key exists;
    # a failed client construction also degrades to offline mode.
    offline_mode = bool(args.offline or not API_KEY)
    client: OpenAI | None = None
    if not offline_mode:
        try:
            client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
        except Exception:
            client = None
            offline_mode = True

    model_wrapper = ModelWrapper(
        client=client,
        model_name=MODEL_NAME,
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        offline=offline_mode,
    )

    log_start(task=str(args.task), env_name=str(args.benchmark), model=MODEL_NAME)

    try:
        observation = await env.reset(task_id=str(args.task), difficulty=args.difficulty)

        # Steps are 1-indexed in logs and history entries.
        for step in range(1, max(1, int(args.max_steps)) + 1):
            config_text, error_message, available_tools = _extract_observation_fields(observation)

            action_text = _select_action(
                model_wrapper=model_wrapper,
                step=step,
                config_text=config_text,
                error_message=error_message,
                history=history,
                available_actions=available_tools,
                policy_mode=str(args.policy_mode),
                trajectories=max(1, int(args.trajectories)),
            )

            observation, reward, done, info = await env.step(action_text)
            step_error = _extract_error(info)

            metrics.add_step(action=action_text, reward=float(reward), error=step_error, done=bool(done))
            steps_taken = step

            log_step(step=step, action=action_text, reward=float(reward), done=bool(done), error=step_error)
            history.append(f"step={step} action={_single_line(action_text)} reward={float(reward):.2f}")

            if done:
                # A clean finish requires no step error and that the final
                # action was not a reward-hacking shortcut.
                episode_completed_cleanly = step_error is None and not _is_hacking_action(action_text)
                break

    except Exception as exc:
        # Best-effort: never let an episode crash the runner. Record a
        # synthetic step so metrics are never empty on total failure.
        # NOTE(review): `success` is recomputed in `finally`, so this
        # assignment is defensive only.
        success = False
        if not metrics.rewards:
            metrics.add_step(action="system_error", reward=0.0, error=str(exc), done=True)
    finally:
        # Score is the average reward clamped to [0, 1]; success additionally
        # requires the episode to have ended cleanly.
        score = max(0.0, min(1.0, float(metrics.average_reward)))
        success = episode_completed_cleanly and score >= SUCCESS_SCORE_THRESHOLD

        # Artifact saving is best-effort and must never mask the episode result.
        try:
            save_reward_curve(metrics.rewards)
            save_metrics_json(metrics.summary())
            save_success_rate_history([success])
        except Exception:
            pass

        # Environment teardown is likewise best-effort.
        try:
            await env.close()
        except Exception:
            pass

        log_end(success=success, steps=steps_taken, score=score, rewards=metrics.rewards)

    return 0


def main() -> int:
    """Synchronous entry point: parse CLI options and run one episode."""
    return asyncio.run(run_episode(parse_args()))


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())