File size: 3,375 Bytes
6e7ce30
 
 
 
 
 
 
d934356
6e7ce30
 
 
 
 
 
d934356
6e7ce30
 
 
 
d934356
6e7ce30
d934356
6e7ce30
 
 
 
a089c46
d934356
6e7ce30
 
 
 
d934356
6e7ce30
 
 
d934356
 
6e7ce30
 
 
 
 
 
 
d934356
 
 
 
6e7ce30
d934356
6e7ce30
 
 
 
 
 
d934356
6e7ce30
d934356
 
6e7ce30
d934356
 
 
 
 
 
6e7ce30
 
 
 
 
d934356
6e7ce30
d934356
 
6e7ce30
 
 
d934356
 
6e7ce30
d934356
6e7ce30
d934356
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import json
import logging
import requests
from openai import OpenAI
from environment.models import Action, Issue

# Better logging instead of quiet failures
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# LLM endpoint configuration — every value is overridable via environment variables.
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
# First key found wins: Groq, then Hugging Face, then OpenAI.
API_KEY = os.getenv("GROQ_API_KEY") or os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY")
MODEL_NAME = os.getenv("MODEL_NAME", "llama3-70b-8192")
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")  # Set this for HF Spaces

# Shared module-level client; the default base URL targets Groq's OpenAI-compatible API.
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

def parse_llm_response(text: str) -> Action:
    """Parse raw LLM output into an Action.

    Expects a JSON list of issue objects, optionally wrapped in a markdown
    code fence. Also tolerates models that wrap the list in an object as
    {"issues": [...]}.

    Args:
        text: Raw completion text from the model.

    Returns:
        Action with the parsed issues and final=True. On any parse failure
        an Action with an empty issue list is returned (the run continues;
        the model simply scores zero issues).
    """
    try:
        # Strip a markdown code fence if the model wrapped its answer in one.
        if "```json" in text:
            text = text.split("```json")[1].split("```")[0]
        elif "```" in text:
            text = text.split("```")[1].split("```")[0]

        data = json.loads(text.strip())
        # Some models return {"issues": [...]} instead of a bare list;
        # unwrap it so the Issue construction below gets dicts, not keys.
        if isinstance(data, dict):
            data = data.get("issues", [])
        if not isinstance(data, list):
            raise ValueError(f"Expected a JSON list, got {type(data).__name__}")
        issues = [Issue(**item) for item in data]
        return Action(issues=issues, final=True)
    except Exception:
        # Broad catch is deliberate: this is the boundary for untrusted model
        # output. logger.exception keeps the traceback for debugging.
        logger.exception("Failed to parse LLM response")
        # Return an empty list indicating the model failed to find issues properly
        return Action(issues=[], final=True)

def run_task(task_id: str) -> float:
    """Run one code-review task end-to-end against the environment server.

    Flow: reset the environment to obtain the code under review, ask the LLM
    for a JSON issue list, then submit the parsed Action in a single step.

    Args:
        task_id: Environment task identifier (e.g. "easy", "medium", "hard").

    Returns:
        The final reward value reported by the environment.

    Raises:
        requests.HTTPError: if the environment returns an error status.
        requests.Timeout: if the environment does not respond in time.
    """
    # 1. Reset environment to get initial observation and session_id.
    # Explicit timeout: without one, requests waits forever on a stalled server.
    logger.info(f"Running task: {task_id}")
    resp = requests.post(f"{ENV_URL}/reset", json={"task_id": task_id}, timeout=60)
    resp.raise_for_status()
    reset_data = resp.json()

    session_id = reset_data["session_id"]
    obs = reset_data["observation"]

    # 2. Build prompt using the code from the observation
    prompt = f"""You are a code reviewer. Analyze the following Python code and list all issues (bugs, style, security, performance, documentation). 
    Return a JSON list where each item has: "line" (int), "category" (one of: bug, style, security, performance, documentation), "description" (string). 
    Example: [{{"line": 5, "category": "bug", "description": "Division by zero"}}]

Code:
{obs['code']}
"""
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0  # Reproducibility
        )
        raw = response.choices[0].message.content
        logger.debug(f"Raw Output: {raw}")
    except Exception as e:
        # Best-effort: a failed completion scores as "no issues found".
        logger.error(f"OpenAI completion error: {e}")
        raw = "[]"

    action = parse_llm_response(raw)

    # 3. Take step using the session_id (timeout for the same stall-safety reason)
    step_resp = requests.post(f"{ENV_URL}/step", json={
        "session_id": session_id,
        "action": action.dict()
    }, timeout=60)
    step_resp.raise_for_status()
    data = step_resp.json()

    final_reward = data["reward"]["value"]
    logger.info(f"Task {task_id}: Final Score = {final_reward:.3f}")
    return final_reward

if __name__ == "__main__":
    def _safe_run(task: str) -> float:
        """Run one task, scoring 0.0 on failure so the summary still prints."""
        try:
            return run_task(task)
        except Exception as e:
            logger.error(f"Task execution failed ({task}): {e}")
            return 0.0

    # Evaluate every difficulty level, collecting name -> score.
    scores = {task: _safe_run(task) for task in ("easy", "medium", "hard")}

    print("\n=== Baseline Results ===")
    for task, score in scores.items():
        print(f"{task}: {score:.3f}")