File size: 3,795 Bytes
bc1b1a6
fcbfa5f
 
 
 
 
 
 
 
8003121
 
3ad52bb
 
 
8003121
b10aba7
8003121
3ad52bb
 
 
8003121
 
3ad52bb
 
 
 
7f5056e
 
8003121
 
 
3ad52bb
7f5056e
 
 
 
8003121
7f5056e
8743366
8003121
8743366
8003121
8743366
7f5056e
8003121
8743366
7f5056e
8743366
 
7f5056e
8743366
7f5056e
8743366
7f5056e
3ad52bb
 
 
 
 
8003121
114e5cf
8743366
04362f9
8003121
 
 
 
7f5056e
8003121
 
114e5cf
8003121
7f5056e
 
f90a861
8003121
 
7f5056e
3ad52bb
 
 
 
 
 
 
 
 
 
 
 
 
 
8743366
8003121
7f5056e
114e5cf
8743366
7f5056e
 
b10aba7
 
8003121
04362f9
 
 
 
 
 
7f5056e
 
8743366
8003121
114e5cf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import sys

# Route print() output to stderr by default so that stdout carries only the
# structured [START]/[STEP]/[END] records. The wrapper is installed on
# ``builtins`` because a module-level ``print`` alone only shadows the builtin
# inside this file and would leave imported modules writing to stdout.
import builtins

_original_print = builtins.print


def print(*args, **kwargs):
    """Forward to the real ``print``, defaulting ``file`` to ``sys.stderr``.

    An explicit ``file=`` argument supplied by the caller is respected.
    """
    kwargs.setdefault('file', sys.stderr)
    _original_print(*args, **kwargs)


# Must happen before the project modules are imported so their prints are
# redirected as well.
builtins.print = print

import os
import textwrap

# Endpoint, credentials and model name are read from the environment; the
# defaults are placeholders.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://dummy.api")
# HF_TOKEN wins when set to a non-empty value; otherwise API_KEY is used.
API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY", "dummy-key")
MODEL_NAME = os.environ.get("MODEL_NAME", "dummy-model")
MAX_STEPS = 5  # upper bound on steps taken per task
FALLBACK_ACTION = "skip"  # reply used when no model output is available

from environment import CodeReviewEnv
from models import Action

# System message sent with every model request: lists the reply formats that
# the reply parser understands.
SYSTEM_PROMPT = "\n".join(
    (
        "You are an AI code reviewer. Reply with one of:",
        "- write_comment: [comment]",
        "- ask_question: [question]",
        "- propose_fix: [code]",
        "- skip",
        "- done",
    )
)

def build_user_prompt(step, obs, history):
    """Assemble the user message for one review step.

    The message contains the step number, ``obs.code_snippet``,
    ``obs.comments`` and ``history``, each under its own heading, and ends
    with a ``Your response:`` cue. Values are interpolated with their default
    string formatting.
    """
    sections = (
        f"Step {step}",
        "Code:",
        f"{obs.code_snippet}",
        "Comments:",
        f"{obs.comments}",
        "History:",
        f"{history}",
        "Your response:",
    )
    return "\n".join(sections)

def parse_model_action(text):
    """Translate a raw model reply into an ``Action``.

    Recognised replies (case-insensitive, surrounding whitespace ignored):

    * ``skip`` / ``done`` -- actions without a payload.
    * ``write_comment``, ``ask_question``, ``propose_fix`` -- the keyword,
      an optional ``:`` separator, then the payload.

    Any other reply is treated as a free-form comment. An empty, ``None`` or
    whitespace-only reply yields the ``FALLBACK_ACTION``.

    Args:
        text: The model's reply, or ``None``/empty when nothing came back.

    Returns:
        The ``Action`` corresponding to the reply.
    """
    if not text or not text.strip():
        return Action(action_type=FALLBACK_ACTION)

    reply = text.strip()
    lowered = reply.lower()

    if lowered.startswith("skip"):
        return Action(action_type="skip")
    if lowered.startswith("done"):
        return Action(action_type="done")

    # Keyword -> name of the Action field that carries its payload.
    payload_fields = (
        ("write_comment", "comment_text"),
        ("ask_question", "question"),
        ("propose_fix", "fix_code"),
    )
    for keyword, field in payload_fields:
        if not lowered.startswith(keyword):
            continue
        # Slice the stripped reply by the keyword's real length so leading
        # whitespace in the raw text cannot shift the offset.
        payload = reply[len(keyword):].lstrip()
        # Only a colon directly after the keyword is a separator; colons
        # inside the payload (e.g. ``def f(): ...``) belong to the payload.
        if payload.startswith(":"):
            payload = payload[1:]
        return Action(action_type=keyword, **{field: payload.strip()})

    # Unrecognised reply: keep the model's text verbatim as a comment.
    return Action(action_type="write_comment", comment_text=text)

def _create_client():
    """Return an OpenAI client, or ``None`` if none is configured or setup fails."""
    if API_BASE_URL == "https://dummy.api":
        # Placeholder endpoint: run offline with the fallback action.
        return None
    try:
        from openai import OpenAI
        return OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    except Exception as exc:
        # Best effort: keep running without a model, but say why on stderr.
        print(f"[WARN] could not create OpenAI client: {exc!r}", file=sys.stderr)
        return None


def _clamp_score(value, eps):
    """Clamp *value* into the closed interval ``[eps, 1 - eps]``."""
    return min(max(value, eps), 1.0 - eps)


def main():
    """Run every task through the review loop and report results on stdout.

    For each task the environment is reset and stepped up to ``MAX_STEPS``
    times (or until it reports ``done``). Each step asks the model for an
    action when a client is available and otherwise uses ``FALLBACK_ACTION``.
    Progress is written to stdout as ``[START]``, ``[STEP]`` and ``[END]``
    lines; diagnostics go to stderr.
    """
    client = _create_client()

    env = CodeReviewEnv()
    tasks = ["easy", "medium", "hard", "harder", "hardest"]
    EPS = 0.001

    for task in tasks:
        env.set_task(task)
        obs = env.reset()
        history = []
        done = False
        step = 0
        final_reward = 0.0

        sys.stdout.write(f"[START] task={task}\n")
        sys.stdout.flush()

        while not done and step < MAX_STEPS:
            step += 1
            prompt = build_user_prompt(step, obs, history)
            messages = [{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": prompt}]

            response_text = FALLBACK_ACTION
            if client is not None:
                try:
                    resp = client.chat.completions.create(
                        model=MODEL_NAME,
                        messages=messages,
                        temperature=0.2,
                        max_tokens=200,
                    )
                    response_text = resp.choices[0].message.content or FALLBACK_ACTION
                except Exception as exc:
                    # Best effort: a failed request falls back to the default
                    # action, but the failure is reported instead of hidden.
                    print(
                        f"[WARN] model request failed at step {step} of task {task}: {exc!r}",
                        file=sys.stderr,
                    )

            action = parse_model_action(response_text)
            obs, reward, done, _ = env.step(action)
            final_reward = reward.value

            sys.stdout.write(f"[STEP] step={step} reward={final_reward:.3f}\n")
            sys.stdout.flush()

            history.append(f"Step {step}: {action.action_type}")

        # Keep the reported score strictly between 0 and 1 *as printed*:
        # clamping into [EPS, 1 - EPS] stops values such as 0.0004 or 0.9996
        # from rounding to 0.000 / 1.000 under the three-decimal format.
        final_reward = _clamp_score(final_reward, EPS)

        sys.stdout.write(f"[END] task={task} score={final_reward:.3f} steps={step}\n")
        sys.stdout.flush()

# Run the evaluation loop only when executed as a script, not on import.
if __name__ == "__main__":
    main()