Shardul Dhekane committed on
Commit
01cd7e9
·
1 Parent(s): 60fef21
Files changed (2) hide show
  1. Dockerfile +18 -22
  2. inference.py +228 -1
Dockerfile CHANGED
@@ -2,47 +2,43 @@ FROM python:3.10-slim
2
 
3
  WORKDIR /app
4
 
5
- # Install system dependencies
6
  RUN apt-get update && apt-get install -y \
7
  git \
8
  curl \
9
  && rm -rf /var/lib/apt/lists/*
10
 
11
- # Copy requirements first for better caching
12
  COPY requirements.txt .
13
  RUN pip install --no-cache-dir -r requirements.txt
14
 
15
- # Copy the rest of the application
16
  COPY . .
17
 
18
- # Create a wrapper script to handle API configuration
19
  RUN echo '#!/bin/bash\n\
20
- # Check if API configuration is provided\n\
21
  if [ -z "$API_BASE_URL" ]; then\n\
22
- echo " WARNING: API_BASE_URL not set!"\n\
23
- echo " Please set the following environment variables:"\n\
24
- echo " API_BASE_URL - Your API endpoint"\n\
25
- echo " MODEL_NAME - Model identifier"\n\
26
- echo " API_KEY - Your API key"\n\
27
  echo ""\n\
28
  echo "Examples:"\n\
29
- echo " OpenAI: API_BASE_URL=https://api.openai.com/v1 MODEL_NAME=gpt-4"\n\
30
- echo " Gemini: API_BASE_URL=https://generativelanguage.googleapis.com MODEL_NAME=gemini-1.5-pro"\n\
31
- echo " Local: API_BASE_URL=http://localhost:11434/v1 MODEL_NAME=llama2"\n\
32
- echo ""\n\
 
 
 
 
33
  exit 1\n\
34
  fi\n\
35
  \n\
36
- # Show current configuration\n\
37
- echo "🔧 Running with configuration:"\n\
38
- echo " API_BASE_URL: ${API_BASE_URL}"\n\
39
- echo " MODEL_NAME: ${MODEL_NAME}"\n\
40
- echo " TEMPERATURE: ${TEMPERATURE:-0.7}"\n\
41
- echo " MAX_TOKENS: ${MAX_TOKENS:-2000}"\n\
42
  echo ""\n\
43
  \n\
44
- # Run inference with provided configuration\n\
45
  python inference.py "$@"' > /usr/local/bin/run-agent && chmod +x /usr/local/bin/run-agent
46
 
47
- # Default command (will be overridden by docker run arguments)
48
  CMD ["run-agent", "--task-id", "bug_detection_easy"]
 
2
 
3
  WORKDIR /app
4
 
 
5
  RUN apt-get update && apt-get install -y \
6
  git \
7
  curl \
8
  && rm -rf /var/lib/apt/lists/*
9
 
 
10
  COPY requirements.txt .
11
  RUN pip install --no-cache-dir -r requirements.txt
12
 
 
13
  COPY . .
14
 
 
15
  RUN echo '#!/bin/bash\n\
 
16
  if [ -z "$API_BASE_URL" ]; then\n\
17
+ echo "ERROR: API_BASE_URL environment variable is required"\n\
18
+ echo "Please set:"\n\
19
+ echo " API_BASE_URL - Your API endpoint"\n\
20
+ echo " MODEL_NAME - Model identifier"\n\
21
+ echo " API_KEY - Your API key"\n\
22
  echo ""\n\
23
  echo "Examples:"\n\
24
+ echo " OpenAI: API_BASE_URL=https://api.openai.com/v1 MODEL_NAME=gpt-4"\n\
25
+ echo " Gemini: API_BASE_URL=https://generativelanguage.googleapis.com MODEL_NAME=gemini-1.5-pro"\n\
26
+ echo " Local: API_BASE_URL=http://localhost:11434/v1 MODEL_NAME=llama2"\n\
27
+ exit 1\n\
28
+ fi\n\
29
+ \n\
30
+ if [ -z "$MODEL_NAME" ]; then\n\
31
+ echo "ERROR: MODEL_NAME environment variable is required"\n\
32
  exit 1\n\
33
  fi\n\
34
  \n\
35
+ echo "Configuration:"\n\
36
+ echo " API_BASE_URL: ${API_BASE_URL}"\n\
37
+ echo " MODEL_NAME: ${MODEL_NAME}"\n\
38
+ echo " TEMPERATURE: ${TEMPERATURE:-0.7}"\n\
39
+ echo " MAX_TOKENS: ${MAX_TOKENS:-2000}"\n\
 
40
  echo ""\n\
41
  \n\
 
42
  python inference.py "$@"' > /usr/local/bin/run-agent && chmod +x /usr/local/bin/run-agent
43
 
 
44
  CMD ["run-agent", "--task-id", "bug_detection_easy"]
inference.py CHANGED
@@ -60,4 +60,231 @@ class LLMClient:
60
  return result["choices"][0]["message"]["content"]
61
  except Exception as e:
62
  print(f"OpenAI API error: {e}")
63
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  return result["choices"][0]["message"]["content"]
61
  except Exception as e:
62
  print(f"OpenAI API error: {e}")
63
+ raise
64
+
65
def _gemini_completion(self, messages: list, temperature: float, max_tokens: int) -> str:
    """Request a completion from the Gemini generateContent REST endpoint.

    The chat-style message list is flattened into a single text prompt via
    _convert_to_gemini_format, since this code path uses the plain-text
    Gemini payload shape. Logs and re-raises on any HTTP or parsing error.
    """
    prompt_text = self._convert_to_gemini_format(messages)
    endpoint = f"{self.base_url}/v1/models/{self.model}:generateContent"
    request_headers = {
        "Content-Type": "application/json",
        "x-goog-api-key": self.api_key,
    }
    body = {
        "contents": [{"parts": [{"text": prompt_text}]}],
        "generationConfig": {
            "temperature": temperature,
            "maxOutputTokens": max_tokens,
        },
    }
    try:
        resp = requests.post(endpoint, headers=request_headers, json=body, timeout=30)
        resp.raise_for_status()
        data = resp.json()
        # Gemini nests the generated text under candidates -> content -> parts.
        return data["candidates"][0]["content"]["parts"][0]["text"]
    except Exception as e:
        print(f"Gemini API error: {e}")
        raise
93
+
94
+ def _convert_to_gemini_format(self, messages: list) -> str:
95
+ prompt_parts = []
96
+ for msg in messages:
97
+ role = msg["role"]
98
+ content = msg["content"]
99
+
100
+ if role == "system":
101
+ prompt_parts.append(f"System: {content}")
102
+ elif role == "user":
103
+ prompt_parts.append(f"User: {content}")
104
+ elif role == "assistant":
105
+ prompt_parts.append(f"Assistant: {content}")
106
+
107
+ return "\n\n".join(prompt_parts)
108
+
109
+
110
class CodeReviewAgent:
    """LLM-backed agent that reviews code diffs and emits JSON review actions."""

    def __init__(self):
        # Module-level configuration (API_BASE_URL, API_KEY, MODEL_NAME) is
        # validated in the __main__ guard before this is constructed.
        self.client = LLMClient(API_BASE_URL, API_KEY, MODEL_NAME)
        self.history = []

    def get_action(self, observation: Dict[str, Any]) -> str:
        """Ask the LLM for the next review action given the environment observation.

        Returns a JSON string with at least the keys `action_type`, `comments`,
        and `suggestions`. On any LLM or parsing failure, returns the
        module-level FALLBACK_ACTION instead of raising.
        """
        system_prompt = """You are an expert code reviewer. Your task is to review code changes and provide feedback.

Review the code diff and identify issues. You can:
1. ADD_COMMENT: Add a comment about an issue on a specific line
2. SUGGEST_FIX: Suggest a specific code fix for an issue
3. APPROVE: Approve the code changes (only if no critical issues)
4. REQUEST_CHANGES: Request changes (if issues are found)

Respond with a JSON object in this format:
{
    "action_type": "add_comment" | "suggest_fix" | "approve" | "request_changes",
    "comments": [
        {
            "line_number": 10,
            "content": "This line has a potential bug...",
            "is_issue": true,
            "severity": "high"
        }
    ],
    "suggestions": [
        {
            "original_line": 10,
            "suggested_code": "if x != 0:",
            "explanation": "Prevents division by zero"
        }
    ],
    "final_decision": "approved" | "changes_requested" (only if action_type is approve or request_changes)
}

Be thorough but concise. Focus on real issues like bugs, security vulnerabilities, performance problems, and code quality."""

        user_prompt = f"""
Code Review Task:
{observation.get('task_description', 'Review the following code changes')}

Code Diff:
{observation.get('code_diff', '')}

File Context:
{observation.get('file_context', '')}


Current step: {observation.get('current_step', 0)}/{observation.get('max_steps', 50)}
Previous actions taken: {len(observation.get('previous_comments', []))} comments, {len(observation.get('previous_suggestions', []))} suggestions

Please provide your review action as JSON.
"""

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]

        try:
            response = self.client.chat_completion(messages, TEMPERATURE, MAX_TOKENS)

            # Strip optional markdown code fences the model may wrap around
            # its JSON payload.
            response = response.strip()
            if response.startswith("```json"):
                response = response[7:]
            if response.startswith("```"):
                response = response[3:]
            if response.endswith("```"):
                response = response[:-3]

            action_data = json.loads(response.strip())

            # Guarantee the keys downstream consumers rely on.
            action_data.setdefault("action_type", "request_changes")
            action_data.setdefault("comments", [])
            action_data.setdefault("suggestions", [])

            return json.dumps(action_data)

        except Exception as e:
            print(f"Error getting action from LLM: {e}")
            return FALLBACK_ACTION

    def parse_action(self, action_str: str) -> Dict[str, Any]:
        """Parse the agent's JSON action string into a dict.

        Falls back to a conservative `request_changes` action on malformed
        input. The original used a bare `except:`, which would also swallow
        KeyboardInterrupt/SystemExit; only decoding failures are caught here.
        """
        try:
            return json.loads(action_str)
        except (json.JSONDecodeError, TypeError):
            return {"action_type": "request_changes", "comments": [], "suggestions": []}
200
+
201
+
202
def main():
    """Run one code-review episode against CodeReviewEnv and save the results.

    Parses --task-id / --max-steps, drives the agent/environment loop until
    the episode ends or the step budget is exhausted, prints per-step
    progress, and writes a summary to baseline_results.json.
    """
    import sys
    sys.path.append('.')

    from environment.env import CodeReviewEnv

    parser = argparse.ArgumentParser()
    parser.add_argument("--task-id", type=str, default="bug_detection_easy",
                        help="Task ID to run")
    parser.add_argument("--max-steps", type=int, default=50,
                        help="Maximum steps per episode")
    args = parser.parse_args()

    env = CodeReviewEnv()
    env.max_steps = args.max_steps
    agent = CodeReviewAgent()

    obs = env.reset(task_id=args.task_id)
    done, step, total_reward = False, 0, 0.0

    print(f"Starting code review for task: {args.task_id}")
    print(f"Task description: {obs.get('task_description', 'N/A')}")
    print("-" * 50)

    # Agent/environment loop: each iteration is one review action.
    while not done and step < args.max_steps:
        action = agent.parse_action(agent.get_action(obs))
        obs, reward, done, info = env.step(action)

        total_reward += reward
        step += 1

        print(f"Step {step}:")
        print(f"  Action: {action.get('action_type')}")
        print(f"  Comments added: {len(action.get('comments', []))}")
        print(f"  Suggestions: {len(action.get('suggestions', []))}")
        print(f"  Reward: {reward:.3f}")
        print(f"  Cumulative: {total_reward:.3f}")
        print(f"  Done: {done}")

        # `is False` distinguishes an explicit invalid-action flag from the
        # key simply being absent.
        if info.get('last_action_valid') is False:
            print(f"  Warning: Invalid action! {info.get('error', '')}")

        print("-" * 50)

    final_score = env.get_task_score()
    print(f"\nFinal Results:")
    print(f"  Total reward: {total_reward:.3f}")
    print(f"  Task score: {final_score:.3f}/1.0")
    print(f"  Steps taken: {step}")

    env.close()

    # Persist a machine-readable summary alongside the console output.
    results = {
        "task_id": args.task_id,
        "total_reward": total_reward,
        "task_score": final_score,
        "steps": step,
        "max_steps": args.max_steps,
    }
    with open("baseline_results.json", "w") as f:
        json.dump(results, f, indent=2)

    print("\nResults saved to baseline_results.json")
272
+
273
+
274
if __name__ == "__main__":
    # NOTE(review): the only other visible `import sys` is local to main(),
    # so without this import the sys.exit(1) calls below could raise
    # NameError — TODO confirm whether the unseen module header imports sys.
    import sys

    # Fail fast with actionable messages before constructing any clients.
    if not API_BASE_URL:
        print("ERROR: API_BASE_URL environment variable not set")
        print("Example: export API_BASE_URL=https://generativelanguage.googleapis.com")
        sys.exit(1)

    if not API_KEY:
        print("ERROR: API_KEY environment variable not set")
        print("Example: export API_KEY=your-api-key-here")
        sys.exit(1)

    if not MODEL_NAME:
        print("ERROR: MODEL_NAME environment variable not set")
        print("Example: export MODEL_NAME=gemini-1.5-pro")
        sys.exit(1)

    main()