# cicd-rl-agent / inference.py
# Last commit 6fc8d4c by Nikitasoni22: "training issue resolved"
import json
import re
from typing import Dict, Any, List
from cicd_debug_env.env import CICDDebugEnv
from cicd_debug_env.models import Action, Observation
# System prompt prepended verbatim to every prompt built by
# CICDAgent.build_prompt.
# The [STEP] line format defined below is what CICDAgent.parse_action reads.
# It extracts only the action=, params= (a JSON object) and confidence=
# fields; step/reward/done/error are ignored.
# Operating rule 4 (low confidence -> diagnostic action) is also enforced in
# code by parse_action.
# NOTE(review): CICDAgent.generate_action currently parses a hard-coded mock
# response, so this prompt is not yet sent to any model.
SYSTEM_PROMPT = """You are an autonomous AI agent operating in a CI/CD debugging environment.
Your objective is to iteratively diagnose and fix broken CI/CD pipelines using structured reasoning, tool-based actions, and feedback from the environment.
You function in a sequential decision loop: Observe → Reason → Decide → Act → Evaluate → Repeat
At each step you receive: pipeline configuration (YAML), execution logs, error messages, causal blame scores per step, memory bank hits from similar past failures, and available tools.
OPERATING RULES:
1. Always query memory bank before deciding — avoid actions that previously failed on similar errors
2. Assign blame scores to each pipeline step before acting — never edit unrelated steps
3. Output a confidence score with every action (0-1)
4. If confidence < 0.5: take a diagnostic action (read_logs, analyze_error) before attempting a fix
5. Before submit_solution: mentally verify the fix works on 3 slight variants
6. After fixing: generate a structured diff explanation
OUTPUT FORMAT (strictly enforced):
[START] task=<id> env=cicd_debug model=<name>
[STEP] step=<n> action=<type> params=<json> confidence=<float> reward=<float> done=<bool> error=<msg|null>
[END] success=<true|false> steps=<n> score=<float> rewards=<r1,r2,...>
"""
class CICDAgent:
    """Agent that drives a ``CICDDebugEnv`` episode from text prompts.

    Each step renders the current ``Observation`` into a prompt
    (``build_prompt``), obtains a ``[STEP] ...`` response line
    (``generate_action`` -- currently a hard-coded mock), converts it into an
    ``Action`` (``parse_action``) and passes it to ``env.step``.
    """

    # Field extractors for the "[STEP] ..." line format described in
    # SYSTEM_PROMPT. The leading \b stops e.g. "reaction=" from matching.
    _ACTION_RE = re.compile(r"\baction=(\w+)")
    # Only locates where the JSON object starts; the object itself is decoded
    # with json so nested braces and multi-line objects are handled correctly.
    _PARAMS_RE = re.compile(r"\bparams=\s*(?=\{)")
    # Matches exactly one well-formed number, so trailing punctuation
    # ("0.8.") or a stray "." can never make float() raise.
    _CONF_RE = re.compile(r"\bconfidence=([0-9]*\.?[0-9]+)")

    def __init__(self, model_name="unsloth/Qwen2.5-3B-Instruct", use_api=False):
        """Create the agent and its environment.

        Args:
            model_name: Identifier of the model to query. Stored only; the
                current ``generate_action`` does not read it.
            use_api: Whether to query the model through an API. Stored only;
                not read anywhere in this class.
        """
        self.model_name = model_name
        self.use_api = use_api
        self.env = CICDDebugEnv()

    @staticmethod
    def _action_name(entry) -> str:
        """Return ``entry["action"].action_type``, or ``"unknown"`` if absent."""
        action = entry.get("action")
        return action.action_type if hasattr(action, "action_type") else "unknown"

    def build_prompt(self, observation: Observation) -> str:
        """Render ``observation`` into the full text prompt for the model.

        The prompt is ``SYSTEM_PROMPT`` followed by three sections:

        * the current state: error message, logs, pipeline YAML, blame
          scores and available actions;
        * the episode history, one line per step;
        * the memory-bank hits.

        ``observation.logs`` and ``observation.available_actions`` must
        contain strings (they are ``str.join``-ed), and
        ``observation.step_blame_scores`` must be JSON-serializable.
        """
        parts = [
            f"{SYSTEM_PROMPT}\n\n",
            "--- CURRENT STATE ---\n",
            f"Error: {observation.error_message}\n",
            "Logs:\n" + "\n".join(observation.logs) + "\n",
            f"YAML:\n{observation.pipeline_yaml}\n",
            f"Causal Blame Scores: {json.dumps(observation.step_blame_scores)}\n",
            f"Available Actions: {', '.join(observation.available_actions)}\n",
            "\n--- EPISODE HISTORY ---\n",
        ]
        for number, entry in enumerate(observation.episode_history, start=1):
            parts.append(
                f"Step {number}: Action={self._action_name(entry)} "
                f"Reward={entry.get('reward', 0.0)}\n"
            )
        parts.append("\n--- MEMORY BANK HITS ---\n")
        for hit in observation.memory_hits:
            parts.append(
                f"Error Fingerprint: {hit.get('error_fingerprint', '')} | "
                f"Action: {self._action_name(hit)} | "
                f"Reward: {hit.get('reward', 0.0)}\n"
            )
        parts.append("\nGenerate your next action following the strictly enforced OUTPUT FORMAT.")
        return "".join(parts)

    def parse_action(self, response: str) -> Action:
        """Convert a model response into an ``Action``.

        Missing or malformed fields fall back to defaults:

        * action type ``"analyze_error"``;
        * empty parameters;
        * confidence ``0.5``.

        Confidence is clamped to ``[0, 1]``, the range SYSTEM_PROMPT asks for.
        Operating rule 4 of the prompt is enforced here: an action with
        confidence below 0.5 that is not already diagnostic (``read_logs`` or
        ``analyze_error``) is replaced by ``analyze_error``. Its parsed
        parameters are kept unchanged.

        Args:
            response: Raw model output, expected to contain
                ``action=<type> params=<json> confidence=<float>``.

        Returns:
            The parsed ``Action``, with the full ``response`` stored as its
            ``reasoning``.
        """
        action_match = self._ACTION_RE.search(response)
        action_type = action_match.group(1) if action_match else "analyze_error"

        params: Dict[str, Any] = {}
        params_match = self._PARAMS_RE.search(response)
        if params_match:
            try:
                # raw_decode parses exactly one JSON value (nested braces
                # included) and ignores the trailing "confidence=..." text.
                decoded, _ = json.JSONDecoder().raw_decode(response[params_match.end():])
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                decoded = None
            if isinstance(decoded, dict):
                params = decoded

        conf_match = self._CONF_RE.search(response)
        confidence = float(conf_match.group(1)) if conf_match else 0.5
        confidence = min(max(confidence, 0.0), 1.0)

        if confidence < 0.5 and action_type not in ("read_logs", "analyze_error"):
            action_type = "analyze_error"

        return Action(
            action_type=action_type,
            parameters=params,
            confidence=confidence,
            reasoning=response,
        )

    def generate_action(self, observation: Observation) -> Action:
        """Return the next action for ``observation``.

        NOTE: no model is queried yet. The prompt is built but not used, and
        a fixed mock response is parsed instead. Because that response has
        confidence 0.4, the result is always an ``analyze_error`` action.
        ``self.model_name`` and ``self.use_api`` are not consulted.
        """
        prompt = self.build_prompt(observation)  # noqa: F841 -- reserved for the model call
        mock_response = '[STEP] step=1 action=analyze_error params={} confidence=0.4 reward=0.0 done=false error=null'
        return self.parse_action(mock_response)

    def run_episode(self, task_id=None) -> dict:
        """Run one episode until the env reports done or ``env.max_steps`` is hit.

        Args:
            task_id: Passed through to ``env.reset``.

        Returns:
            A dict with these keys:

            * ``task_id``: the current task's ``"id"``, or ``None``.
            * ``success`` and ``success_yaml_match``: whether the final
              pipeline YAML equals the task's ``correct_yaml`` after
              stripping surrounding whitespace. ``False`` when the task has
              no reference YAML.
            * ``steps``: the number of steps taken.
            * ``score``: the sum of step rewards.
            * ``mean_step_reward`` and ``last_step_reward``.
            * ``history``: a shallow copy of ``env.episode_history``.
        """
        obs = self.env.reset(task_id)
        done = False
        step = 0
        total_reward = 0.0
        last_reward = 0.0
        while not done and step < self.env.max_steps:
            action = self.generate_action(obs)
            obs, reward, done, _info = self.env.step(action)
            last_reward = reward
            total_reward += reward
            step += 1

        task = self.env.current_task or {}
        # `or ""` guards against keys/attributes explicitly set to None.
        correct = (task.get("correct_yaml") or "").strip()
        current_yaml = (self.env.current_observation.pipeline_yaml or "").strip()
        # Ground-truth success: fixed YAML matches reference (not sum of step rewards).
        yaml_fixed = bool(correct) and current_yaml == correct
        return {
            "task_id": task.get("id"),
            "success": yaml_fixed,
            "success_yaml_match": yaml_fixed,
            "steps": step,
            "score": total_reward,
            "mean_step_reward": total_reward / max(step, 1),
            "last_step_reward": last_reward,
            # Copy so the result is unaffected if the env later mutates its
            # history list in place (e.g. on the next reset).
            "history": list(self.env.episode_history or []),
        }

    def counterfactual_replay(self, episode: dict) -> list[dict]:
        """Pair each recorded step of ``episode`` with an alternate action.

        NOTE: placeholder. Nothing is re-executed in the environment. The
        alternate is always an ``analyze_error`` action with confidence 0.9.

        Args:
            episode: A dict with a ``"history"`` iterable, such as the result
                of ``run_episode``.

        Returns:
            One ``{"original": step, "alternate": Action}`` dict per step.
        """
        # Keyword arguments match parse_action and do not depend on Action's
        # field order or on it accepting positional arguments.
        return [
            {
                "original": step,
                "alternate": Action(
                    action_type="analyze_error",
                    parameters={},
                    confidence=0.9,
                    reasoning="Replay test",
                ),
            }
            for step in episode["history"]
        ]

    def generate_diff_explanation(self, before_yaml: str, after_yaml: str, error_msg: str) -> str:
        """Return a short, human-readable explanation of a YAML fix.

        NOTE: placeholder. ``before_yaml`` and ``after_yaml`` are currently
        unused. Only ``error_msg`` is interpolated; the "Change" and
        "Why it works" lines are fixed text.
        """
        explanation = f"Error was: {error_msg}\n"
        explanation += "Change: Edited YAML to fix the failure.\n"
        explanation += "Why it works: Addressed syntax error or missing dependency.\n"
        return explanation