code-review-environment / inference.py
Shardul Dhekane
Normalize provider-agnostic API config and add Space server endpoints
edaad73
#!/usr/bin/env python3
from dotenv import load_dotenv
load_dotenv()
import os
import json
import argparse
import sys
from typing import Dict, Any
from openai import OpenAI
def resolve_api_key() -> str:
    """Return the auth token from the environment.

    ``API_KEY`` is the canonical variable; ``HF_TOKEN`` and
    ``OPENAI_API_KEY`` are accepted as backward-compatible aliases,
    checked in that order. Returns "" when none is set (or all are
    blank/whitespace).
    """
    for var in ("API_KEY", "HF_TOKEN", "OPENAI_API_KEY"):
        token = os.environ.get(var, "").strip()
        if token:
            return token
    return ""
# Provider-agnostic endpoint configuration, read once at import time.
# Values may come from the process environment or a .env file (loaded by
# load_dotenv() at the top of the module).
API_BASE_URL = os.environ.get("API_BASE_URL", "")  # OpenAI-compatible base URL
MODEL_NAME = os.environ.get("MODEL_NAME", "")  # model identifier at that endpoint
API_KEY = resolve_api_key()  # canonical API_KEY, with HF_TOKEN / OPENAI_API_KEY aliases
# Generation and transport knobs. NOTE(review): a non-numeric value here
# raises ValueError at import time with no friendly message — acceptable
# for a CLI script, but worth knowing.
TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.7"))
MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "2000"))
REQUEST_TIMEOUT = int(os.environ.get("REQUEST_TIMEOUT", "60"))  # seconds, passed to the OpenAI client
# Fail fast with actionable setup instructions when configuration is
# incomplete. The missing-endpoint case gets a full help banner with
# per-provider examples; the other two cases get one-line errors.
if not API_BASE_URL:
    print("=" * 60)
    print("API Configuration Required")
    print("=" * 60)
    print("\nPlease set the following environment variables:\n")
    print(" API_BASE_URL - OpenAI-compatible API endpoint")
    print(" MODEL_NAME - Model identifier")
    print(" API_KEY - API key (canonical)\n")
    print("Supported auth aliases (backward compatibility):")
    print(" HF_TOKEN")
    print(" OPENAI_API_KEY\n")
    print("Examples:\n")
    print(" OpenAI:")
    print(" export API_BASE_URL=https://api.openai.com/v1")
    print(" export MODEL_NAME=gpt-4o-mini")
    print(" export API_KEY=sk-xxxxx\n")
    print(" Groq:")
    print(" export API_BASE_URL=https://api.groq.com/openai/v1")
    print(" export MODEL_NAME=llama-3.3-70b-versatile")
    print(" export API_KEY=gsk_xxxxx\n")
    print(" Local Ollama:")
    print(" export API_BASE_URL=http://localhost:11434/v1")
    print(" export MODEL_NAME=llama3")
    print(" export API_KEY=not-needed\n")
    print("=" * 60)
    sys.exit(1)
if not MODEL_NAME:
    print("ERROR: MODEL_NAME environment variable is required")
    sys.exit(1)
if not API_KEY:
    print("ERROR: Missing auth token. Set API_KEY (preferred), or HF_TOKEN/OPENAI_API_KEY.")
    sys.exit(1)
# Safe default action returned whenever the LLM reply cannot be parsed or
# the request fails: a bare "request changes" with no comments/suggestions,
# so the episode can still advance instead of crashing.
FALLBACK_ACTION = json.dumps(
    {
        "action_type": "request_changes",
        "comments": [],
        "suggestions": [],
        "final_decision": "changes_requested",
    }
)
def add_line_numbers(code: str) -> str:
    """Prefix every line of *code* with its 1-based line number ("N: line").

    Used so the LLM can reference exact line numbers from the diff.
    """
    numbered = []
    for number, text in enumerate(code.split("\n"), start=1):
        numbered.append(f"{number}: {text}")
    return "\n".join(numbered)
class LLMClient:
    """Minimal chat-completion wrapper for any OpenAI-compatible endpoint."""

    def __init__(self, base_url: str, api_key: str, model: str):
        # Trailing slashes would yield double-slash URLs once the SDK
        # appends resource paths, so normalize the base URL up front.
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.model = model
        self.client = OpenAI(
            base_url=self.base_url,
            api_key=self.api_key,
            timeout=REQUEST_TIMEOUT,
        )
        for line in (
            "Connected using OpenAI client",
            f"Endpoint: {self.base_url}",
            f"Model: {self.model}\n",
        ):
            print(line)

    def chat_completion(self, messages: list, temperature: float = 0.7, max_tokens: int = 2000) -> str:
        """Send *messages* and return the assistant's text ("" if absent)."""
        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=False,
        )
        text = response.choices[0].message.content
        return text or ""
class CodeReviewAgent:
    """LLM-driven review agent that works through three fixed phases:

    1. ``add_comment``     - flag every bug with an exact line number
    2. ``suggest_fix``     - propose a concrete fix per flagged bug
    3. ``request_changes`` - emit the final decision

    ``get_action`` returns a JSON action string; ``parse_action`` and
    ``validate_action`` turn that string into a sanitized action dict.

    Fix over previous revision: ``validate_action`` previously passed raw
    LLM output straight into ``min()``, so a ``null``/string line number
    (or a non-dict list entry) raised TypeError — in the very method meant
    to sanitize model output. Line numbers are now safely coerced and
    clamped, and non-dict entries are dropped.
    """

    def __init__(self):
        self.client = LLMClient(API_BASE_URL, API_KEY, MODEL_NAME)
        self.history = []  # reserved for a transcript; not populated yet
        self.phase = 1     # current review phase; values past 3 behave as phase 3

    def get_action(self, observation: Dict[str, Any]) -> str:
        """Build the phase-specific prompt, query the LLM, return a JSON action string.

        Falls back to FALLBACK_ACTION when the reply cannot be parsed or
        the request fails. Advances ``self.phase`` on every parsed (or
        unparseable-but-received) reply; a transport error does NOT
        advance the phase, so the same phase is retried next step.
        """
        system_prompt = """You are an expert code reviewer. You MUST follow this exact sequence:
PHASE 1 - Add Comments: Use action_type "add_comment" to identify ALL bugs with exact line numbers
PHASE 2 - Suggest Fixes: Use action_type "suggest_fix" to provide fixes for every bug found
PHASE 3 - Final Decision: Use action_type "request_changes" with final_decision "changes_requested"
RULES:
- NEVER skip straight to approve or request_changes without first adding comments and suggestions
- NEVER combine phases - each action should do ONE thing
- ALWAYS use the exact line numbers shown in the code diff
- ALWAYS set severity for comments: "critical", "high", "medium", or "low"
- If no bugs found in Phase 1, skip to Phase 3 with "approved"
Respond ONLY with a valid JSON object, no extra text:
{
"action_type": "add_comment" | "suggest_fix" | "approve" | "request_changes",
"comments": [
{
"line_number": <exact line number>,
"content": "Detailed explanation of the bug",
"is_issue": true,
"severity": "critical" | "high" | "medium" | "low"
}
],
"suggestions": [
{
"original_line": <exact line number>,
"suggested_code": "corrected code here",
"explanation": "why this fix works"
}
],
"final_decision": "approved" | "changes_requested"
}"""
        prev_comments = observation.get('previous_comments', [])
        prev_suggestions = observation.get('previous_suggestions', [])
        # Prior entries may be plain dicts or attribute-bearing objects
        # depending on the environment version; support both shapes.
        comments_text = "\n".join([
            f" Line {c.get('line_number') if isinstance(c, dict) else c.line_number}: "
            f"{c.get('content') if isinstance(c, dict) else c.content}"
            for c in prev_comments
        ]) or "None yet"
        suggestions_text = "\n".join([
            f" Line {s.get('original_line') if isinstance(s, dict) else s.original_line}: "
            f"{s.get('suggested_code') if isinstance(s, dict) else s.suggested_code}"
            for s in prev_suggestions
        ]) or "None yet"
        # Select the per-phase task instruction; any phase beyond 2 gets the
        # final-decision instruction.
        if self.phase == 1:
            phase_instruction = """
YOUR TASK NOW (Phase 1 - Add Comments):
- action_type MUST be "add_comment"
- Carefully read the code diff line by line
- Find ALL bugs, vulnerabilities, or issues
- Comment on each one with the EXACT line number shown
- Do NOT make a final decision yet
- Do NOT suggest fixes yet
"""
        elif self.phase == 2:
            phase_instruction = """
YOUR TASK NOW (Phase 2 - Suggest Fixes):
- action_type MUST be "suggest_fix"
- For every bug you commented on, provide a concrete code fix
- Use the same line numbers as your comments
- Do NOT make a final decision yet
"""
        else:
            phase_instruction = """
YOUR TASK NOW (Phase 3 - Final Decision):
- action_type MUST be "request_changes"
- Set final_decision to "changes_requested"
- No new comments or suggestions needed
"""
        user_prompt = f"""
Code Review Task:
{observation.get('task_description', 'Review the following code changes')}
Code Diff (USE THESE EXACT LINE NUMBERS in your response):
{add_line_numbers(observation.get('code_diff', ''))}
File Context:
{observation.get('file_context', '')}
Current Step: {observation.get('current_step', 0)}/{observation.get('max_steps', 50)}
Comments already made:
{comments_text}
Suggestions already made:
{suggestions_text}
{phase_instruction}
Respond with JSON only.
"""
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
        try:
            response = self.client.chat_completion(messages, TEMPERATURE, MAX_TOKENS)
            response = response.strip()
            # Strip a Markdown code fence if the model wrapped its JSON in one.
            if "```json" in response:
                response = response.split("```json")[1].split("```")[0]
            elif "```" in response:
                response = response.split("```")[1].split("```")[0]
            action_data = json.loads(response.strip())
            # Backfill required keys so downstream code can rely on them.
            if "action_type" not in action_data:
                action_data["action_type"] = "request_changes"
            if "comments" not in action_data:
                action_data["comments"] = []
            if "suggestions" not in action_data:
                action_data["suggestions"] = []
            self.phase += 1
            return json.dumps(action_data)
        except json.JSONDecodeError as e:
            print(f"Failed to parse JSON response: {e}")
            print(f"Raw response: {response[:200]}...")
            self.phase += 1
            return FALLBACK_ACTION
        except Exception as e:
            # Transport/SDK failure: keep the current phase so it is retried.
            print(f"Error getting action from LLM: {e}")
            return FALLBACK_ACTION

    @staticmethod
    def _coerce_line(value: Any, line_count: int) -> int:
        """Best-effort int coercion of an LLM-supplied line number.

        Non-coercible values (None, "abc", missing) default to 1; the
        result is clamped to [1, line_count].
        """
        try:
            number = int(value)
        except (TypeError, ValueError):
            number = 1
        return max(1, min(number, line_count))

    def validate_action(self, action: Dict, observation: Dict) -> Dict:
        """Sanitize an action dict in place and return it.

        Clamps comment/suggestion line numbers to the diff's line count,
        defaults ``severity`` to "medium" and ``is_issue`` to True, and
        drops list entries that are not dicts (malformed LLM output).
        """
        line_count = observation.get('line_count', 999)
        comments = [c for c in action.get("comments", []) if isinstance(c, dict)]
        for comment in comments:
            comment["line_number"] = self._coerce_line(comment.get("line_number", 1), line_count)
            if not comment.get("severity"):
                comment["severity"] = "medium"
            if "is_issue" not in comment:
                comment["is_issue"] = True
        if "comments" in action:
            action["comments"] = comments
        suggestions = [s for s in action.get("suggestions", []) if isinstance(s, dict)]
        for suggestion in suggestions:
            suggestion["original_line"] = self._coerce_line(suggestion.get("original_line", 1), line_count)
        if "suggestions" in action:
            action["suggestions"] = suggestions
        return action

    def parse_action(self, action_str: str) -> Dict[str, Any]:
        """Decode a JSON action string; fall back to a bare request_changes action."""
        try:
            return json.loads(action_str)
        except json.JSONDecodeError:
            return {"action_type": "request_changes", "comments": [], "suggestions": []}
def main():
    """CLI entry point: run the review agent on one task and save a results JSON.

    Imports the environment lazily (relative to the working directory) so
    a missing install produces a clear error instead of a bare traceback.
    """
    sys.path.append('.')
    try:
        from environment.env import CodeReviewEnv
    except ImportError as e:
        print(f"Failed to import environment: {e}")
        print("Make sure you're in the correct directory and environment is installed.")
        sys.exit(1)
    parser = argparse.ArgumentParser(description="Run code review agent")
    parser.add_argument("--task-id", type=str, default="bug_detection_easy_1")
    parser.add_argument("--max-steps", type=int, default=50)
    parser.add_argument("--output", type=str, default="baseline_results.json")
    args = parser.parse_args()
    print("=" * 60)
    print("Code Review Agent")
    print("=" * 60)
    env = CodeReviewEnv()
    env.max_steps = args.max_steps
    agent = CodeReviewAgent()
    obs = env.reset(task_id=args.task_id)
    done = False
    step = 0
    total_reward = 0.0
    print(f"\nTask : {args.task_id}")
    print(f"Desc : {obs.get('task_description', 'N/A')}")
    print(f"Model : {MODEL_NAME}")
    print("-" * 60)
    # Episode loop: ask the agent for an action, sanitize it, step the env.
    # Terminates on env `done` or the CLI step budget, whichever is first.
    while not done and step < args.max_steps:
        action_str = agent.get_action(obs)
        action = agent.parse_action(action_str)
        action = agent.validate_action(action, obs)
        obs, reward, done, info = env.step(action)
        total_reward += reward
        step += 1
        print(f"\nStep {step}/{args.max_steps}:")
        # phase was already advanced inside get_action, so report phase - 1
        print(f" Phase : {agent.phase - 1}")
        print(f" Action : {action.get('action_type')}")
        print(f" Comments : {len(action.get('comments', []))}")
        print(f" Suggestions : {len(action.get('suggestions', []))}")
        print(f" Reward : {reward:.3f}")
        print(f" Total : {total_reward:.3f}")
        print(f" Score : {info.get('task_score', 0):.3f}")
        if info.get('last_action_valid') is False:
            print(f" Warning : {info.get('error', 'Invalid action')}")
    final_score = env.get_task_score()
    print("\n" + "=" * 60)
    print("Final Results:")
    print(f" Task : {args.task_id}")
    print(f" Total Reward : {total_reward:.3f}")
    print(f" Task Score : {final_score:.3f}/1.0")
    print(f" Steps : {step}")
    print("=" * 60)
    env.close()
    # Persist a machine-readable summary for downstream aggregation.
    results = {
        "task_id": args.task_id,
        "total_reward": round(total_reward, 4),
        "task_score": round(final_score, 4),
        "steps": step,
        "max_steps": args.max_steps,
        "provider": "openai-client",
        "model": MODEL_NAME,
        "api_base_url": API_BASE_URL
    }
    with open(args.output, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to {args.output}")
if __name__ == "__main__":
    main()