SyamSashank committed on
Commit
d934356
·
verified ·
1 Parent(s): effb609

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +34 -65
inference.py CHANGED
@@ -5,51 +5,37 @@ import requests
5
  from openai import OpenAI
6
  from environment.models import Action, Issue
7
 
8
# Configure logging so failures are visible in the Hugging Face Space logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Configuration (judges supply these via environment variables) ---
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
# First available credential wins: Groq, then Hugging Face, then OpenAI.
API_KEY = os.getenv("GROQ_API_KEY") or os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY")
MODEL_NAME = os.getenv("MODEL_NAME", "llama3-70b-8192")

# Points directly at the evaluation Space URL by default.
ENV_URL = os.getenv("ENV_URL", "https://syam-sashank-codereview-env.hf.space")

# One OpenAI-compatible client, reused for every completion call.
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
23
 
24
def parse_llm_response(text: str) -> Action:
    """Parse the LLM's string output into a structured Action object.

    Expects a JSON list of issue dicts, optionally wrapped in a Markdown
    code fence. On any parse or validation failure, returns an empty final
    Action so the grader can still run (and likely score 0.0).
    """
    try:
        # Strip a Markdown code fence if the model wrapped its JSON in one.
        if "```json" in text:
            text = text.split("```json", 1)[1].split("```", 1)[0]
        elif "```" in text:
            text = text.split("```", 1)[1].split("```", 1)[0]

        data = json.loads(text.strip())
        # Guard: a bare object/string would otherwise fail confusingly in the
        # comprehension below — the grader contract is a JSON *list*.
        if not isinstance(data, list):
            raise ValueError(f"expected a JSON list, got {type(data).__name__}")

        # Validate items against the Issue model.
        issues = [Issue(**item) for item in data]
        return Action(issues=issues, final=True)
    except Exception as e:
        # Deliberate best-effort: never crash the run on a malformed reply.
        logger.error(f"Failed to parse LLM response: {e}")
        return Action(issues=[], final=True)
45
 
46
def run_task(task_id: str) -> float:
    """Execute a single task: Reset -> LLM Inference -> Step -> Return Reward.

    Raises requests.HTTPError (via raise_for_status) or requests.Timeout if
    the environment misbehaves; the caller treats any exception as a 0.0.
    """
    logger.info(f"--- Starting Task: {task_id} ---")

    # 1. Reset environment. Timeout prevents a stalled Space from hanging
    # the whole baseline run indefinitely.
    resp = requests.post(f"{ENV_URL}/reset", json={"task_id": task_id}, timeout=60)
    resp.raise_for_status()
    reset_data = resp.json()
    session_id = reset_data["session_id"]
    obs = reset_data["observation"]

    # 2. Build the review prompt around the observed code.
    prompt = f"""You are a professional security and code reviewer.
Analyze the following Python code and identify all bugs, style issues, security flaws, performance anti-patterns, and missing documentation.

Return ONLY a JSON list where each item has:
- "line": (integer)
- "category": (one of: bug, style, security, performance, documentation)
- "description": (string, max 200 chars)

Code to review:
{obs['code']}
"""
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,  # Crucial for reproducible baseline scores
        )
        raw_content = response.choices[0].message.content
    except Exception as e:
        # Fall back to "no issues" so the step still happens (scores 0.0).
        logger.error(f"LLM Completion error: {e}")
        raw_content = "[]"

    # Convert LLM text to an Action object.
    action = parse_llm_response(raw_content)

    # 3. Take the step in the environment and extract the F1-based reward.
    step_resp = requests.post(
        f"{ENV_URL}/step",
        json={"session_id": session_id, "action": action.dict()},
        timeout=60,
    )
    step_resp.raise_for_status()
    result_data = step_resp.json()

    final_reward = result_data["reward"]["value"]
    logger.info(f"Result for {task_id}: Score = {final_reward:.3f}")
    return final_reward
100
 
101
if __name__ == "__main__":
    # The competition requires scores for at least 3 tasks.
    task_list = ["easy", "medium", "hard"]
    final_scores = {}

    print(f"Connecting to environment at: {ENV_URL}")

    for task in task_list:
        try:
            final_scores[task] = run_task(task)
        except Exception as e:
            # One failed task must not abort the remaining tasks.
            logger.error(f"Task {task} failed to execute: {e}")
            final_scores[task] = 0.0

    # Final summary for the logs.
    print("\n" + "=" * 30)
    print(" BASELINE PERFORMANCE REPORT ")
    print("=" * 30)
    for task, score in final_scores.items():
        print(f"Task: {task:8} | Score: {score:.3f}")
    print("=" * 30)
 
5
  from openai import OpenAI
6
  from environment.models import Action, Issue
7
 
8
# Better logging instead of quiet failures.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Endpoint/model configuration, all overridable via environment variables.
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
# First available credential wins: Groq, then Hugging Face, then OpenAI.
API_KEY = os.getenv("GROQ_API_KEY") or os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY")
MODEL_NAME = os.getenv("MODEL_NAME", "llama3-70b-8192")
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")  # Set this for HF Spaces

# One OpenAI-compatible client, reused for every completion call.
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
18
 
19
def parse_llm_response(text: str) -> Action:
    """Parse LLM output into an Action. Expects JSON list of issues.

    Handles a Markdown code fence around the JSON. Any parse or validation
    failure yields an empty final Action instead of raising.
    """
    try:
        # Extract JSON from markdown blocks.
        if "```json" in text:
            text = text.split("```json", 1)[1].split("```", 1)[0]
        elif "```" in text:
            text = text.split("```", 1)[1].split("```", 1)[0]

        data = json.loads(text.strip())
        # Guard: the comprehension below assumes a JSON *list*; reject a
        # bare object/string explicitly for a clearer error message.
        if not isinstance(data, list):
            raise ValueError(f"expected a JSON list, got {type(data).__name__}")

        issues = [Issue(**item) for item in data]
        return Action(issues=issues, final=True)
    except Exception as e:
        logger.error(f"Failed to parse LLM response: {e}")
        # Return an empty list indicating the model failed to find issues properly.
        return Action(issues=[], final=True)
35
 
36
def run_task(task_id: str) -> float:
    """Run one task end-to-end: reset -> LLM inference -> step -> reward."""
    logger.info(f"Running task: {task_id}")

    # 1. Reset environment to get initial observation and session_id.
    # Timeout keeps a stalled environment from hanging the run forever.
    resp = requests.post(f"{ENV_URL}/reset", json={"task_id": task_id}, timeout=60)
    resp.raise_for_status()
    reset_data = resp.json()
    session_id = reset_data["session_id"]
    obs = reset_data["observation"]

    # 2. Build prompt using the code from the observation.
    prompt = f"""You are a code reviewer. Analyze the following Python code and list all issues (bugs, style, security, performance, documentation).
Return a JSON list where each item has: "line" (int), "category" (one of: bug, style, security, performance, documentation), "description" (string).
Example: [{{"line": 5, "category": "bug", "description": "Division by zero"}}]

Code:
{obs['code']}
"""
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,  # Reproducibility
        )
        raw = response.choices[0].message.content
        logger.debug(f"Raw Output: {raw}")
    except Exception as e:
        # Fall back to "no issues" so the step still happens (scores 0.0).
        logger.error(f"OpenAI completion error: {e}")
        raw = "[]"

    action = parse_llm_response(raw)

    # 3. Take step using the session_id.
    step_resp = requests.post(
        f"{ENV_URL}/step",
        json={"session_id": session_id, "action": action.dict()},
        timeout=60,
    )
    step_resp.raise_for_status()
    data = step_resp.json()

    final_reward = data["reward"]["value"]
    logger.info(f"Task {task_id}: Final Score = {final_reward:.3f}")
    return final_reward
79
 
80
if __name__ == "__main__":
    # Score every required difficulty; a failed task counts as 0.0 rather
    # than aborting the remaining tasks.
    scores = {}
    for task in ["easy", "medium", "hard"]:
        try:
            scores[task] = run_task(task)
        except Exception as e:
            logger.error(f"Task execution failed ({task}): {e}")
            scores[task] = 0.0

    # Summary for the logs.
    print("\n=== Baseline Results ===")
    for task, score in scores.items():
        print(f"{task}: {score:.3f}")