Spaces:
Sleeping
Sleeping
File size: 4,609 Bytes
3a26e23 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 | import os
import json
import textwrap
from typing import List
from openai import OpenAI
from environment import CodeReviewEnv
from models import Action, Observation
API_BASE_URL = os.getenv("API_BASE_URL")
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
MODEL_NAME = os.getenv("MODEL_NAME")
MAX_STEPS = 5
TEMPERATURE = 0.2
MAX_TOKENS = 200
FALLBACK_ACTION = "skip"
SYSTEM_PROMPT = textwrap.dedent(
"""
You are an AI code reviewer. Your task is to provide helpful comments on pull requests.
You will see a code snippet and existing comments.
Reply with ONE of the following:
- "write_comment: [your comment]" - to provide a helpful code review comment
- "skip" - if you cannot provide a helpful comment
- "done" - if the code is already perfect
Be constructive, specific, and focus on improving code quality.
"""
).strip()
def build_user_prompt(step: int, obs: Observation, history: List[str]) -> str:
newline = "\n"
comments_str = newline.join(obs.comments) if obs.comments else "No existing comments"
history_str = newline.join(history[-3:]) if history else "None"
prompt = textwrap.dedent(
f"""
Step: {step}
Code to review:
{obs.pr_code}
Existing comments on this PR:
{comments_str}
Previous actions:
{history_str}
Please provide your response (write_comment, skip, or done):
"""
).strip()
return prompt
def parse_model_action(response_text: str) -> Action:
if not response_text:
return Action(action_type=FALLBACK_ACTION)
raw_text = response_text.strip()
lower_text = raw_text.lower()
if lower_text.startswith("skip"):
return Action(action_type="skip")
if lower_text.startswith("done"):
return Action(action_type="done")
if lower_text.startswith("write_comment"):
if ":" in raw_text:
comment = raw_text.split(":", 1)[1].strip()
else:
comment = raw_text[len("write_comment"):].strip()
if not comment:
return Action(action_type="skip")
return Action(action_type="write_comment", comment_text=comment)
# default: treat as a comment
return Action(action_type="write_comment", comment_text=raw_text)
def main() -> None:
if not API_BASE_URL or not API_KEY or not MODEL_NAME:
print("Error: API_BASE_URL, HF_TOKEN/API_KEY, and MODEL_NAME must be set.")
return
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
env = CodeReviewEnv()
tasks = ["easy", "medium", "hard"]
scores = {}
print("=" * 50)
print("Code Review Environment - Baseline Inference")
print(f"API Base URL: {API_BASE_URL}")
print(f"Model: {MODEL_NAME}")
print("=" * 50)
for task in tasks:
print(f"\nRunning task: {task.upper()}")
env.set_task(task) # set task before reset
obs = env.reset()
history: List[str] = []
done = False
step = 0
final_reward = 0.0
while not done and step < MAX_STEPS:
step += 1
user_prompt = build_user_prompt(step, obs, history)
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user_prompt},
]
try:
completion = client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
temperature=TEMPERATURE,
max_tokens=MAX_TOKENS,
)
response_text = completion.choices[0].message.content or ""
except Exception as exc:
print(f" Request failed: {exc}. Using fallback.")
response_text = FALLBACK_ACTION
action = parse_model_action(response_text)
obs, reward, done, info = env.step(action)
final_reward = reward.value
history.append(f"Step {step}: {action.action_type}")
print(f" Step {step} | Action: {action.action_type} | Reward: {reward.value:.2f}")
scores[task] = final_reward
print(f"{task.upper()} completed. Final Score: {final_reward:.2f}")
print("\n" + "="*50)
print("FINAL RESULTS")
print(json.dumps(scores, indent=2))
print("="*50)
if __name__ == "__main__":
main() |