WHOAM-EYE commited on
Commit
aee090e
·
verified ·
1 Parent(s): 38e5c55

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +47 -4
inference.py CHANGED
@@ -22,8 +22,10 @@ API_BASE_URL = os.getenv("API_BASE_URL")
22
  MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b")
23
  API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY") or os.getenv("HF_TOKEN")
24
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "network-forensics-env:latest")
25
- ENV_MODE = (os.getenv("NETWORK_FORENSICS_ENV_MODE") or os.getenv("ENV_MODE") or "docker").lower()
26
  ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")
 
 
27
  DOCKER_READY_TIMEOUT_S = float(os.getenv("DOCKER_READY_TIMEOUT_S", "120"))
28
  _ASYNC_LOOP: asyncio.AbstractEventLoop | None = None
29
 
@@ -52,10 +54,12 @@ def validate_config() -> None:
52
  missing.append("API_BASE_URL")
53
  if not API_KEY:
54
  missing.append("OPENAI_API_KEY/API_KEY/HF_TOKEN")
 
 
55
  if missing:
56
  raise RuntimeError(f"Missing required environment variables: {', '.join(missing)}")
57
- if ENV_MODE not in {"server", "docker"}:
58
- raise RuntimeError("NETWORK_FORENSICS_ENV_MODE must be one of: server, docker")
59
 
60
 
61
  def format_action(action: NetworkForensicsAction) -> str:
@@ -360,14 +364,53 @@ def resolve_maybe_awaitable(value: Any) -> Any:
360
 
361
 
362
  def create_env() -> NetworkForensicsEnv:
 
 
 
 
 
 
 
363
  if ENV_MODE == "docker":
364
  provider = ExtendedWaitDockerProvider()
365
  return resolve_maybe_awaitable(
366
  NetworkForensicsEnv.from_docker_image(LOCAL_IMAGE_NAME, provider=provider)
367
  )
 
 
 
 
368
  return NetworkForensicsEnv(base_url=ENV_BASE_URL)
369
 
370
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
  def reset_env(env: NetworkForensicsEnv, task_name: str) -> Any:
372
  result = resolve_maybe_awaitable(env.reset(task_id=task_name))
373
  return result
@@ -405,7 +448,7 @@ def run_task(task_name: str) -> None:
405
  print(f"[START] task={task_name} env=network_forensics model={MODEL_NAME}")
406
 
407
  try:
408
- env = create_env()
409
  reset_result = reset_env(env, task_name)
410
  obs = reset_result.observation
411
  sync_agent_state(obs, agent_state)
 
22
  MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b")
23
  API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY") or os.getenv("HF_TOKEN")
24
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "network-forensics-env:latest")
25
+ ENV_MODE = (os.getenv("NETWORK_FORENSICS_ENV_MODE") or os.getenv("ENV_MODE") or "hf").lower()
26
  ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")
27
+ HF_SPACE_ID = os.getenv("HF_SPACE_ID") or os.getenv("SPACE_ID") or "WHOAM-EYE/network_forensics"
28
+ HF_SPACE_URL = os.getenv("HF_SPACE_URL", "https://whoam-eye-network-forensics.hf.space")
29
  DOCKER_READY_TIMEOUT_S = float(os.getenv("DOCKER_READY_TIMEOUT_S", "120"))
30
  _ASYNC_LOOP: asyncio.AbstractEventLoop | None = None
31
 
 
54
  missing.append("API_BASE_URL")
55
  if not API_KEY:
56
  missing.append("OPENAI_API_KEY/API_KEY/HF_TOKEN")
57
+ if ENV_MODE == "hf" and not (HF_SPACE_URL or HF_SPACE_ID):
58
+ missing.append("HF_SPACE_URL or HF_SPACE_ID/SPACE_ID")
59
  if missing:
60
  raise RuntimeError(f"Missing required environment variables: {', '.join(missing)}")
61
+ if ENV_MODE not in {"server", "docker", "hf"}:
62
+ raise RuntimeError("NETWORK_FORENSICS_ENV_MODE must be one of: server, docker, hf")
63
 
64
 
65
  def format_action(action: NetworkForensicsAction) -> str:
 
364
 
365
 
366
  def create_env() -> NetworkForensicsEnv:
367
+ # Preferred path: Hugging Face Space.
368
+ if ENV_MODE == "hf":
369
+ if HF_SPACE_URL:
370
+ return NetworkForensicsEnv(base_url=HF_SPACE_URL.rstrip("/"))
371
+ space_slug = HF_SPACE_ID.lower().replace("/", "-").replace("_", "-")
372
+ return NetworkForensicsEnv(base_url=f"https://{space_slug}.hf.space")
373
+
374
  if ENV_MODE == "docker":
375
  provider = ExtendedWaitDockerProvider()
376
  return resolve_maybe_awaitable(
377
  NetworkForensicsEnv.from_docker_image(LOCAL_IMAGE_NAME, provider=provider)
378
  )
379
+
380
+ if ENV_MODE == "server":
381
+ return NetworkForensicsEnv(base_url=ENV_BASE_URL)
382
+
383
  return NetworkForensicsEnv(base_url=ENV_BASE_URL)
384
 
385
 
386
+ def create_env_with_fallback() -> NetworkForensicsEnv:
387
+ # 1) Try HF Space.
388
+ try:
389
+ env = NetworkForensicsEnv(base_url=HF_SPACE_URL.rstrip("/"))
390
+ _ = reset_env(env, "easy")
391
+ return env
392
+ except Exception as exc:
393
+ print(f"[WARN] HF space failed ({exc}); trying Docker.")
394
+
395
+ # 2) Try Docker.
396
+ try:
397
+ provider = ExtendedWaitDockerProvider()
398
+ env = resolve_maybe_awaitable(
399
+ NetworkForensicsEnv.from_docker_image(LOCAL_IMAGE_NAME, provider=provider)
400
+ )
401
+ _ = reset_env(env, "easy")
402
+ return env
403
+ except Exception as exc:
404
+ print(f"[WARN] Docker failed ({exc}); trying local server.")
405
+
406
+ # 3) Last resort: in-process environment.
407
+ try:
408
+ from server.network_forensics_environment import NetworkForensicsEnvironment
409
+ return NetworkForensicsEnvironment(task_id="easy") # type: ignore[return-value]
410
+ except Exception as exc:
411
+ raise RuntimeError(f"All environment backends failed: {exc}") from exc
412
+
413
+
414
  def reset_env(env: NetworkForensicsEnv, task_name: str) -> Any:
415
  result = resolve_maybe_awaitable(env.reset(task_id=task_name))
416
  return result
 
448
  print(f"[START] task={task_name} env=network_forensics model={MODEL_NAME}")
449
 
450
  try:
451
+ env = create_env_with_fallback()
452
  reset_result = reset_env(env, task_name)
453
  obs = reset_result.observation
454
  sync_agent_state(obs, agent_state)