Nitish commited on
Commit
98bf903
Β·
1 Parent(s): 9e52d37

fix: safe client init + pyproject.toml server entry point

Browse files
Files changed (1) hide show
  1. inference.py +36 -21
inference.py CHANGED
@@ -28,11 +28,6 @@ HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
28
  ENV_URL = os.getenv("ENV_URL") or "http://localhost:7860"
29
  BENCHMARK = "code-security-review"
30
 
31
- if not HF_TOKEN:
32
- raise ValueError("HF_TOKEN or API_KEY must be set.")
33
-
34
- client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
35
-
36
  SYSTEM_PROMPT = """You are a senior security-focused code reviewer.
37
 
38
  When given a code snippet, carefully analyse it for bugs and security issues.
@@ -108,7 +103,7 @@ def build_prompt(obs: dict) -> str:
108
 
109
  # ── Task runner ───────────────────────────────────────────────────────────────
110
 
111
- def run_task(task_id: str, task_num: int) -> dict:
112
  cumulative_reward = 0.0
113
  step_num = 0
114
  done = False
@@ -131,20 +126,32 @@ def run_task(task_id: str, task_num: int) -> dict:
131
 
132
  # ── LLM call ──────────────────────────────────────────────────────────
133
  try:
134
- response = client.chat.completions.create(
135
- model=MODEL_NAME,
136
- messages=[
137
- {"role": "system", "content": SYSTEM_PROMPT},
138
- {"role": "user", "content": prompt},
139
- ],
140
- temperature=0.1,
141
- max_tokens=600,
142
- stream=False,
143
- )
144
- raw = response.choices[0].message.content
145
- action_dict = parse_json_from_llm(raw)
146
- action_str = json.dumps(action_dict)
147
- error = None
 
 
 
 
 
 
 
 
 
 
 
 
148
  except Exception as exc:
149
  error = str(exc).replace("\n", " ")
150
  action_dict = {
@@ -187,6 +194,14 @@ def run_task(task_id: str, task_num: int) -> dict:
187
  def main():
188
  print(f"[INFO] Initializing inference on {BENCHMARK} using {MODEL_NAME}", flush=True)
189
 
 
 
 
 
 
 
 
 
190
  TASK_FILTER = os.environ.get("TASK")
191
 
192
  all_tasks = [
@@ -204,7 +219,7 @@ def main():
204
 
205
  for task_id, task_num, _ in tasks:
206
  try:
207
- r = run_task(task_id, task_num)
208
  except Exception as exc:
209
  print(f"[ERROR] task_id={task_id} error={exc}", flush=True)
210
  r = {"task_num": task_num, "task_id": task_id, "score": 0.0, "success": False}
 
28
  ENV_URL = os.getenv("ENV_URL") or "http://localhost:7860"
29
  BENCHMARK = "code-security-review"
30
 
 
 
 
 
 
31
  SYSTEM_PROMPT = """You are a senior security-focused code reviewer.
32
 
33
  When given a code snippet, carefully analyse it for bugs and security issues.
 
103
 
104
  # ── Task runner ───────────────────────────────────────────────────────────────
105
 
106
+ def run_task(task_id: str, task_num: int, client=None) -> dict:
107
  cumulative_reward = 0.0
108
  step_num = 0
109
  done = False
 
126
 
127
  # ── LLM call ──────────────────────────────────────────────────────────
128
  try:
129
+ if client is None:
130
+ action_dict = {
131
+ "bug_identified": True,
132
+ "bug_location": "unknown",
133
+ "bug_type": "security-vulnerability",
134
+ "bug_description": "Fallback deterministic action",
135
+ "severity": "high",
136
+ "suggested_fix": "Fix vulnerability"
137
+ }
138
+ action_str = json.dumps(action_dict)
139
+ error = None
140
+ else:
141
+ response = client.chat.completions.create(
142
+ model=MODEL_NAME,
143
+ messages=[
144
+ {"role": "system", "content": SYSTEM_PROMPT},
145
+ {"role": "user", "content": prompt},
146
+ ],
147
+ temperature=0.1,
148
+ max_tokens=600,
149
+ stream=False,
150
+ )
151
+ raw = response.choices[0].message.content
152
+ action_dict = parse_json_from_llm(raw)
153
+ action_str = json.dumps(action_dict)
154
+ error = None
155
  except Exception as exc:
156
  error = str(exc).replace("\n", " ")
157
  action_dict = {
 
194
  def main():
195
  print(f"[INFO] Initializing inference on {BENCHMARK} using {MODEL_NAME}", flush=True)
196
 
197
+ client = None
198
+ try:
199
+ if not HF_TOKEN:
200
+ raise ValueError("HF_TOKEN or API_KEY must be set.")
201
+ client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
202
+ except Exception as exc:
203
+ print(f"[WARN] Client init failed: {exc}. Using deterministic fallback.", flush=True)
204
+
205
  TASK_FILTER = os.environ.get("TASK")
206
 
207
  all_tasks = [
 
219
 
220
  for task_id, task_num, _ in tasks:
221
  try:
222
+ r = run_task(task_id, task_num, client=client)
223
  except Exception as exc:
224
  print(f"[ERROR] task_id={task_id} error={exc}", flush=True)
225
  r = {"task_num": task_num, "task_id": task_id, "score": 0.0, "success": False}