Athmabhiram1 commited on
Commit
8d25213
·
1 Parent(s): 299ab86

fix: allow inference training runs without local gguf weights in hosted env

Browse files
Files changed (1) hide show
  1. code-review-env/inference.py +69 -41
code-review-env/inference.py CHANGED
@@ -17,6 +17,16 @@ from training.run_manager import TrainingRunManager
17
  from training.weights import WeightSafetyManager
18
 
19
 
 
 
 
 
 
 
 
 
 
 
20
  def _build_parser() -> argparse.ArgumentParser:
21
  parser = argparse.ArgumentParser(description="GraphReview deterministic inference/training harness")
22
  parser.add_argument("target", help="Path to target Python project")
@@ -79,9 +89,9 @@ def _build_agent_prompt(module_id: str, code: str, ast_summary: str) -> str:
79
 
80
 
81
  def _extract_agent_findings(store: Store, config) -> set[str]:
82
- model = config.llm_model_training
83
- base_url = config.llm_base_url
84
- api_key = config.llm_api_key
85
  enabled = os.getenv("GRAPHREVIEW_AGENT_INFERENCE_ENABLED", "true").strip().lower() == "true"
86
 
87
  findings: set[str] = set()
@@ -166,49 +176,65 @@ def _extract_agent_findings(store: Store, config) -> set[str]:
166
  def main() -> None:
167
  args = _build_parser().parse_args()
168
  config = load_runtime_config()
169
- model_name = os.getenv("MODEL_NAME", "gemma4:e4b")
170
 
171
  target = Path(args.target).resolve()
172
- print(f"[START] target={target} model={config.llm_model_training} mode=deterministic-ground-truth")
173
 
174
  weight_manager = WeightSafetyManager(Path(config.llm_weight_manifest_dir))
 
175
  if args.register_weights:
176
- manifest = weight_manager.register_existing(
177
- model_name=model_name,
178
- weight_path=Path(config.llm_model_agent_path),
179
- )
180
- print(
181
- "[STEP] weights_registered "
182
- + json.dumps(
183
- {
184
- "model": manifest.model_name,
185
- "sha256": manifest.sha256,
186
- "size_bytes": manifest.size_bytes,
187
- },
188
- sort_keys=True,
 
 
 
 
 
 
 
189
  )
190
- )
191
 
192
  try:
193
- verified_weight_path = weight_manager.load_verified(model_name)
194
  except FileNotFoundError:
195
- manifest = weight_manager.register_existing(
196
- model_name=model_name,
197
- weight_path=Path(config.llm_model_agent_path),
198
- )
199
- print(
200
- "[STEP] weights_registered "
201
- + json.dumps(
202
- {
203
- "model": manifest.model_name,
204
- "sha256": manifest.sha256,
205
- "size_bytes": manifest.size_bytes,
206
- },
207
- sort_keys=True,
208
  )
209
- )
210
- verified_weight_path = weight_manager.load_verified(model_name)
211
- print(f"[STEP] weights_verified path={verified_weight_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
  seed_result = seed_project(target_dir=target, db_path=args.db_path, force=args.force_seed)
214
  print(f"[STEP] seeded {json.dumps(seed_result, sort_keys=True)}")
@@ -309,17 +335,19 @@ def main() -> None:
309
  run_id = f"tr-{datetime.now(UTC).strftime('%Y%m%d%H%M%S')}-{uuid.uuid4().hex[:8]}"
310
  run_config = {
311
  "target": str(target),
312
- "model": config.llm_model_agent,
313
  "model_path": config.llm_model_agent_path,
314
  "agent_inference_enabled": os.getenv("GRAPHREVIEW_AGENT_INFERENCE_ENABLED", "true"),
315
  "regression_tolerance": args.regression_tolerance,
316
  "baseline_precision": baseline_precision,
317
  "baseline_recall": baseline_recall,
318
  }
319
- sha256 = weight_manager.checksum(Path(verified_weight_path))
 
 
320
  store.create_training_run(
321
  run_id=run_id,
322
- model_name=config.llm_model_training,
323
  model_sha256=sha256,
324
  deterministic_findings=len(deterministic_keys),
325
  agent_findings=len(agent_keys),
@@ -341,8 +369,8 @@ def main() -> None:
341
  "ok": True,
342
  "deterministic_findings": len(deterministic_findings),
343
  "agent_findings": len(agent_keys),
344
- "model_weight": str(verified_weight_path),
345
- "model": config.llm_model_training,
346
  "precision": comparison.precision,
347
  "recall": comparison.recall,
348
  "run_id": run_id,
 
17
  from training.weights import WeightSafetyManager
18
 
19
 
20
+ # Submission-required runtime variables.
21
+ API_BASE_URL = os.getenv("API_BASE_URL", os.getenv("GRAPHREVIEW_LLM_BASE_URL", "http://localhost:11434/v1"))
22
+ MODEL_NAME = os.getenv("MODEL_NAME", "gemma4:e4b")
23
+ HF_TOKEN = os.getenv("HF_TOKEN")
24
+ LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
25
+
26
+ # Keep current behavior for local Ollama while supporting future hosted providers via HF_TOKEN.
27
+ API_KEY = HF_TOKEN or os.getenv("GRAPHREVIEW_LLM_API_KEY", "ollama")
28
+
29
+
30
  def _build_parser() -> argparse.ArgumentParser:
31
  parser = argparse.ArgumentParser(description="GraphReview deterministic inference/training harness")
32
  parser.add_argument("target", help="Path to target Python project")
 
89
 
90
 
91
  def _extract_agent_findings(store: Store, config) -> set[str]:
92
+ model = MODEL_NAME
93
+ base_url = API_BASE_URL
94
+ api_key = API_KEY
95
  enabled = os.getenv("GRAPHREVIEW_AGENT_INFERENCE_ENABLED", "true").strip().lower() == "true"
96
 
97
  findings: set[str] = set()
 
176
  def main() -> None:
177
  args = _build_parser().parse_args()
178
  config = load_runtime_config()
 
179
 
180
  target = Path(args.target).resolve()
181
+ print(f"[START] target={target} model={MODEL_NAME} mode=deterministic-ground-truth")
182
 
183
  weight_manager = WeightSafetyManager(Path(config.llm_weight_manifest_dir))
184
+ verified_weight_path: str | None = None
185
  if args.register_weights:
186
+ try:
187
+ manifest = weight_manager.register_existing(
188
+ model_name=MODEL_NAME,
189
+ weight_path=Path(config.llm_model_agent_path),
190
+ )
191
+ print(
192
+ "[STEP] weights_registered "
193
+ + json.dumps(
194
+ {
195
+ "model": manifest.model_name,
196
+ "sha256": manifest.sha256,
197
+ "size_bytes": manifest.size_bytes,
198
+ },
199
+ sort_keys=True,
200
+ )
201
+ )
202
+ except FileNotFoundError:
203
+ print(
204
+ f"[STEP] weights_register_skipped reason=missing-local-weights model={MODEL_NAME} "
205
+ f"path={config.llm_model_agent_path}"
206
  )
 
207
 
208
  try:
209
+ verified_weight_path = str(weight_manager.load_verified(MODEL_NAME))
210
  except FileNotFoundError:
211
+ try:
212
+ manifest = weight_manager.register_existing(
213
+ model_name=MODEL_NAME,
214
+ weight_path=Path(config.llm_model_agent_path),
 
 
 
 
 
 
 
 
 
215
  )
216
+ print(
217
+ "[STEP] weights_registered "
218
+ + json.dumps(
219
+ {
220
+ "model": manifest.model_name,
221
+ "sha256": manifest.sha256,
222
+ "size_bytes": manifest.size_bytes,
223
+ },
224
+ sort_keys=True,
225
+ )
226
+ )
227
+ verified_weight_path = str(weight_manager.load_verified(MODEL_NAME))
228
+ except FileNotFoundError:
229
+ print(
230
+ f"[STEP] weights_unavailable reason=missing-local-weights model={MODEL_NAME} "
231
+ f"path={config.llm_model_agent_path}"
232
+ )
233
+
234
+ if verified_weight_path is not None:
235
+ print(f"[STEP] weights_verified path={verified_weight_path}")
236
+ else:
237
+ print("[STEP] weights_verified path=unavailable mode=api-only")
238
 
239
  seed_result = seed_project(target_dir=target, db_path=args.db_path, force=args.force_seed)
240
  print(f"[STEP] seeded {json.dumps(seed_result, sort_keys=True)}")
 
335
  run_id = f"tr-{datetime.now(UTC).strftime('%Y%m%d%H%M%S')}-{uuid.uuid4().hex[:8]}"
336
  run_config = {
337
  "target": str(target),
338
+ "model": MODEL_NAME,
339
  "model_path": config.llm_model_agent_path,
340
  "agent_inference_enabled": os.getenv("GRAPHREVIEW_AGENT_INFERENCE_ENABLED", "true"),
341
  "regression_tolerance": args.regression_tolerance,
342
  "baseline_precision": baseline_precision,
343
  "baseline_recall": baseline_recall,
344
  }
345
+ sha256 = "unavailable"
346
+ if verified_weight_path is not None:
347
+ sha256 = weight_manager.checksum(Path(verified_weight_path))
348
  store.create_training_run(
349
  run_id=run_id,
350
+ model_name=MODEL_NAME,
351
  model_sha256=sha256,
352
  deterministic_findings=len(deterministic_keys),
353
  agent_findings=len(agent_keys),
 
369
  "ok": True,
370
  "deterministic_findings": len(deterministic_findings),
371
  "agent_findings": len(agent_keys),
372
+ "model_weight": verified_weight_path or "unavailable",
373
+ "model": MODEL_NAME,
374
  "precision": comparison.precision,
375
  "recall": comparison.recall,
376
  "run_id": run_id,