SyamSashank commited on
Commit
a089c46
·
verified ·
1 Parent(s): 3e0331f

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +65 -34
inference.py CHANGED
@@ -5,37 +5,51 @@ import requests
5
  from openai import OpenAI
6
  from environment.models import Action, Issue
7
 
8
- # Better logging instead of quiet failures
9
  logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__name__)
11
 
 
 
12
  API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
13
  API_KEY = os.getenv("GROQ_API_KEY") or os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY")
14
  MODEL_NAME = os.getenv("MODEL_NAME", "llama3-70b-8192")
15
- ENV_URL = os.getenv("ENV_URL", "http://localhost:7860") # Set this for HF Spaces
16
 
 
 
 
 
17
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
18
 
19
  def parse_llm_response(text: str) -> Action:
20
- """Parse LLM output into an Action. Expects JSON list of issues."""
 
 
 
21
  try:
22
- # Extract JSON from markdown blocks
23
  if "```json" in text:
24
  text = text.split("```json")[1].split("```")[0]
25
  elif "```" in text:
26
  text = text.split("```")[1].split("```")[0]
27
-
28
  data = json.loads(text.strip())
 
 
29
  issues = [Issue(**item) for item in data]
30
  return Action(issues=issues, final=True)
31
  except Exception as e:
32
  logger.error(f"Failed to parse LLM response: {e}")
33
- # Return an empty list indicating the model failed to find issues properly
34
  return Action(issues=[], final=True)
35
 
36
  def run_task(task_id: str) -> float:
37
- # 1. Reset environment to get initial observation and session_id
38
- logger.info(f"Running task: {task_id}")
 
 
 
 
39
  resp = requests.post(f"{ENV_URL}/reset", json={"task_id": task_id})
40
  resp.raise_for_status()
41
  reset_data = resp.json()
@@ -43,49 +57,66 @@ def run_task(task_id: str) -> float:
43
  session_id = reset_data["session_id"]
44
  obs = reset_data["observation"]
45
 
46
- # 2. Build prompt using the code from the observation
47
- prompt = f"""You are a code reviewer. Analyze the following Python code and list all issues (bugs, style, security, performance, documentation).
48
- Return a JSON list where each item has: "line" (int), "category" (one of: bug, style, security, performance, documentation), "description" (string).
49
- Example: [{{"line": 5, "category": "bug", "description": "Division by zero"}}]
 
 
 
 
50
 
51
- Code:
52
  {obs['code']}
53
  """
 
54
  try:
55
  response = client.chat.completions.create(
56
  model=MODEL_NAME,
57
  messages=[{"role": "user", "content": prompt}],
58
- temperature=0.0 # Reproducibility
59
  )
60
- raw = response.choices[0].message.content
61
- logger.debug(f"Raw Output: {raw}")
62
  except Exception as e:
63
- logger.error(f"OpenAI completion error: {e}")
64
- raw = "[]"
65
-
66
- action = parse_llm_response(raw)
67
-
68
- # 3. Take step using the session_id
 
69
  step_resp = requests.post(f"{ENV_URL}/step", json={
70
  "session_id": session_id,
71
  "action": action.dict()
72
  })
73
  step_resp.raise_for_status()
74
- data = step_resp.json()
 
 
 
 
75
 
76
- final_reward = data["reward"]["value"]
77
- logger.info(f"Task {task_id}: Final Score = {final_reward:.3f}")
78
  return final_reward
79
 
80
  if __name__ == "__main__":
81
- scores = {}
82
- for task in ["easy", "medium", "hard"]:
 
 
 
 
 
83
  try:
84
- scores[task] = run_task(task)
 
85
  except Exception as e:
86
- logger.error(f"Task execution failed ({task}): {e}")
87
- scores[task] = 0.0
88
-
89
- print("\n=== Baseline Results ===")
90
- for task, score in scores.items():
91
- print(f"{task}: {score:.3f}")
 
 
 
 
 
from openai import OpenAI
from environment.models import Action, Issue

# Configure logging for better visibility in Hugging Face Logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- CONFIGURATION ---
# The judges provide these via environment variables; sensible defaults below.
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
# First available key wins: Groq, then Hugging Face, then OpenAI.
# NOTE(review): if none of the three is set, API_KEY is None and the client
# will only fail later, at request time -- confirm a key is always provided.
API_KEY = os.getenv("GROQ_API_KEY") or os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY")
MODEL_NAME = os.getenv("MODEL_NAME", "llama3-70b-8192")

# Environment server URL; points directly to the deployed Space by default.
ENV_URL = os.getenv("ENV_URL", "https://syam-sashank-codereview-env.hf.space")

# OpenAI-compatible client aimed at whichever backend API_BASE_URL selects.
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
23
 
24
  def parse_llm_response(text: str) -> Action:
25
+ """
26
+ Parses the LLM's string output into a structured Action object.
27
+ Handles Markdown code blocks commonly used by LLMs.
28
+ """
29
  try:
30
+ # Clean up Markdown JSON blocks
31
  if "```json" in text:
32
  text = text.split("```json")[1].split("```")[0]
33
  elif "```" in text:
34
  text = text.split("```")[1].split("```")[0]
35
+
36
  data = json.loads(text.strip())
37
+
38
+ # Validate items against the Issue model
39
  issues = [Issue(**item) for item in data]
40
  return Action(issues=issues, final=True)
41
  except Exception as e:
42
  logger.error(f"Failed to parse LLM response: {e}")
43
+ # Return empty list so the grader can still run (and likely give 0.0)
44
  return Action(issues=[], final=True)
45
 
46
def run_task(task_id: str) -> float:
    """
    Execute a single task end-to-end: Reset -> LLM Inference -> Step -> Reward.

    Args:
        task_id: Identifier of the task to run (e.g. "easy").

    Returns:
        The reward value reported by the environment's /step response.

    Raises:
        requests.HTTPError: if /reset or /step returns an HTTP error status.
        requests.RequestException: on network failure or timeout.
    """
    logger.info(f"--- Starting Task: {task_id} ---")

    # 1. Reset the environment to obtain a session and the code to review.
    # FIX: requests has no default timeout; without one, a stalled server
    # would hang this run forever. 60s is generous for a Space cold start.
    resp = requests.post(f"{ENV_URL}/reset", json={"task_id": task_id}, timeout=60)
    resp.raise_for_status()
    reset_data = resp.json()
    session_id = reset_data["session_id"]
    obs = reset_data["observation"]

    # 2. Build the review prompt around the observed code.
    prompt = f"""You are a professional security and code reviewer.
Analyze the following Python code and identify all bugs, style issues, security flaws, performance anti-patterns, and missing documentation.

Return ONLY a JSON list where each item has:
- "line": (integer)
- "category": (one of: bug, style, security, performance, documentation)
- "description": (string, max 200 chars)

Code to review:
{obs['code']}
"""

    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0  # Crucial for reproducible baseline scores
        )
        raw_content = response.choices[0].message.content
    except Exception as e:
        # Best-effort fallback: an empty JSON list parses to an empty Action.
        logger.error(f"LLM Completion error: {e}")
        raw_content = "[]"

    # Convert the LLM text into a structured Action.
    action = parse_llm_response(raw_content)

    # 3. Submit the action for grading (same timeout rationale as /reset).
    step_resp = requests.post(
        f"{ENV_URL}/step",
        json={"session_id": session_id, "action": action.dict()},
        timeout=60,
    )
    step_resp.raise_for_status()
    result_data = step_resp.json()

    # Reward value (presumably F1-based, per the upstream comment -- confirm).
    final_reward = result_data["reward"]["value"]
    logger.info(f"Result for {task_id}: Score = {final_reward:.3f}")

    return final_reward
100
 
101
if __name__ == "__main__":
    # The competition requires scores for at least 3 tasks.
    tasks = ["easy", "medium", "hard"]
    results = {}

    print(f"Connecting to environment at: {ENV_URL}")

    # Run every task; a failure scores 0.0 instead of aborting the batch.
    for task in tasks:
        try:
            results[task] = run_task(task)
        except Exception as e:
            logger.error(f"Task {task} failed to execute: {e}")
            results[task] = 0.0

    # Final summary for the logs.
    separator = "=" * 30
    print("\n" + separator)
    print(" BASELINE PERFORMANCE REPORT ")
    print(separator)
    for task, score in results.items():
        print(f"Task: {task:8} | Score: {score:.3f}")
    print(separator)