Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- inference.py +47 -4
inference.py
CHANGED
|
@@ -22,8 +22,10 @@ API_BASE_URL = os.getenv("API_BASE_URL")
|
|
| 22 |
MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b")
|
| 23 |
API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY") or os.getenv("HF_TOKEN")
|
| 24 |
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "network-forensics-env:latest")
|
| 25 |
-
ENV_MODE = (os.getenv("NETWORK_FORENSICS_ENV_MODE") or os.getenv("ENV_MODE") or "
|
| 26 |
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")
|
|
|
|
|
|
|
| 27 |
DOCKER_READY_TIMEOUT_S = float(os.getenv("DOCKER_READY_TIMEOUT_S", "120"))
|
| 28 |
_ASYNC_LOOP: asyncio.AbstractEventLoop | None = None
|
| 29 |
|
|
@@ -52,10 +54,12 @@ def validate_config() -> None:
|
|
| 52 |
missing.append("API_BASE_URL")
|
| 53 |
if not API_KEY:
|
| 54 |
missing.append("OPENAI_API_KEY/API_KEY/HF_TOKEN")
|
|
|
|
|
|
|
| 55 |
if missing:
|
| 56 |
raise RuntimeError(f"Missing required environment variables: {', '.join(missing)}")
|
| 57 |
-
if ENV_MODE not in {"server", "docker"}:
|
| 58 |
-
raise RuntimeError("NETWORK_FORENSICS_ENV_MODE must be one of: server, docker")
|
| 59 |
|
| 60 |
|
| 61 |
def format_action(action: NetworkForensicsAction) -> str:
|
|
@@ -360,14 +364,53 @@ def resolve_maybe_awaitable(value: Any) -> Any:
|
|
| 360 |
|
| 361 |
|
| 362 |
def create_env() -> NetworkForensicsEnv:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
if ENV_MODE == "docker":
|
| 364 |
provider = ExtendedWaitDockerProvider()
|
| 365 |
return resolve_maybe_awaitable(
|
| 366 |
NetworkForensicsEnv.from_docker_image(LOCAL_IMAGE_NAME, provider=provider)
|
| 367 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
return NetworkForensicsEnv(base_url=ENV_BASE_URL)
|
| 369 |
|
| 370 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
def reset_env(env: NetworkForensicsEnv, task_name: str) -> Any:
|
| 372 |
result = resolve_maybe_awaitable(env.reset(task_id=task_name))
|
| 373 |
return result
|
|
@@ -405,7 +448,7 @@ def run_task(task_name: str) -> None:
|
|
| 405 |
print(f"[START] task={task_name} env=network_forensics model={MODEL_NAME}")
|
| 406 |
|
| 407 |
try:
|
| 408 |
-
env =
|
| 409 |
reset_result = reset_env(env, task_name)
|
| 410 |
obs = reset_result.observation
|
| 411 |
sync_agent_state(obs, agent_state)
|
|
|
|
| 22 |
MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b")
|
| 23 |
API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY") or os.getenv("HF_TOKEN")
|
| 24 |
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "network-forensics-env:latest")
|
| 25 |
+
ENV_MODE = (os.getenv("NETWORK_FORENSICS_ENV_MODE") or os.getenv("ENV_MODE") or "hf").lower()
|
| 26 |
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")
|
| 27 |
+
HF_SPACE_ID = os.getenv("HF_SPACE_ID") or os.getenv("SPACE_ID") or "WHOAM-EYE/network_forensics"
|
| 28 |
+
HF_SPACE_URL = os.getenv("HF_SPACE_URL", "https://whoam-eye-network-forensics.hf.space")
|
| 29 |
DOCKER_READY_TIMEOUT_S = float(os.getenv("DOCKER_READY_TIMEOUT_S", "120"))
|
| 30 |
_ASYNC_LOOP: asyncio.AbstractEventLoop | None = None
|
| 31 |
|
|
|
|
| 54 |
missing.append("API_BASE_URL")
|
| 55 |
if not API_KEY:
|
| 56 |
missing.append("OPENAI_API_KEY/API_KEY/HF_TOKEN")
|
| 57 |
+
if ENV_MODE == "hf" and not (HF_SPACE_URL or HF_SPACE_ID):
|
| 58 |
+
missing.append("HF_SPACE_URL or HF_SPACE_ID/SPACE_ID")
|
| 59 |
if missing:
|
| 60 |
raise RuntimeError(f"Missing required environment variables: {', '.join(missing)}")
|
| 61 |
+
if ENV_MODE not in {"server", "docker", "hf"}:
|
| 62 |
+
raise RuntimeError("NETWORK_FORENSICS_ENV_MODE must be one of: server, docker, hf")
|
| 63 |
|
| 64 |
|
| 65 |
def format_action(action: NetworkForensicsAction) -> str:
|
|
|
|
| 364 |
|
| 365 |
|
| 366 |
def create_env() -> NetworkForensicsEnv:
|
| 367 |
+
# Preferred path: Hugging Face Space.
|
| 368 |
+
if ENV_MODE == "hf":
|
| 369 |
+
if HF_SPACE_URL:
|
| 370 |
+
return NetworkForensicsEnv(base_url=HF_SPACE_URL.rstrip("/"))
|
| 371 |
+
space_slug = HF_SPACE_ID.lower().replace("/", "-").replace("_", "-")
|
| 372 |
+
return NetworkForensicsEnv(base_url=f"https://{space_slug}.hf.space")
|
| 373 |
+
|
| 374 |
if ENV_MODE == "docker":
|
| 375 |
provider = ExtendedWaitDockerProvider()
|
| 376 |
return resolve_maybe_awaitable(
|
| 377 |
NetworkForensicsEnv.from_docker_image(LOCAL_IMAGE_NAME, provider=provider)
|
| 378 |
)
|
| 379 |
+
|
| 380 |
+
if ENV_MODE == "server":
|
| 381 |
+
return NetworkForensicsEnv(base_url=ENV_BASE_URL)
|
| 382 |
+
|
| 383 |
return NetworkForensicsEnv(base_url=ENV_BASE_URL)
|
| 384 |
|
| 385 |
|
| 386 |
+
def create_env_with_fallback() -> NetworkForensicsEnv:
|
| 387 |
+
# 1) Try HF Space.
|
| 388 |
+
try:
|
| 389 |
+
env = NetworkForensicsEnv(base_url=HF_SPACE_URL.rstrip("/"))
|
| 390 |
+
_ = reset_env(env, "easy")
|
| 391 |
+
return env
|
| 392 |
+
except Exception as exc:
|
| 393 |
+
print(f"[WARN] HF space failed ({exc}); trying Docker.")
|
| 394 |
+
|
| 395 |
+
# 2) Try Docker.
|
| 396 |
+
try:
|
| 397 |
+
provider = ExtendedWaitDockerProvider()
|
| 398 |
+
env = resolve_maybe_awaitable(
|
| 399 |
+
NetworkForensicsEnv.from_docker_image(LOCAL_IMAGE_NAME, provider=provider)
|
| 400 |
+
)
|
| 401 |
+
_ = reset_env(env, "easy")
|
| 402 |
+
return env
|
| 403 |
+
except Exception as exc:
|
| 404 |
+
print(f"[WARN] Docker failed ({exc}); trying local server.")
|
| 405 |
+
|
| 406 |
+
# 3) Last resort: in-process environment.
|
| 407 |
+
try:
|
| 408 |
+
from server.network_forensics_environment import NetworkForensicsEnvironment
|
| 409 |
+
return NetworkForensicsEnvironment(task_id="easy") # type: ignore[return-value]
|
| 410 |
+
except Exception as exc:
|
| 411 |
+
raise RuntimeError(f"All environment backends failed: {exc}") from exc
|
| 412 |
+
|
| 413 |
+
|
| 414 |
def reset_env(env: NetworkForensicsEnv, task_name: str) -> Any:
|
| 415 |
result = resolve_maybe_awaitable(env.reset(task_id=task_name))
|
| 416 |
return result
|
|
|
|
| 448 |
print(f"[START] task={task_name} env=network_forensics model={MODEL_NAME}")
|
| 449 |
|
| 450 |
try:
|
| 451 |
+
env = create_env_with_fallback()
|
| 452 |
reset_result = reset_env(env, task_name)
|
| 453 |
obs = reset_result.observation
|
| 454 |
sync_agent_state(obs, agent_state)
|