| """ |
| Inference Script — TICE (Tumor Immune Control Environment) |
| ========================================================= |
| MANDATORY ENVIRONMENT VARIABLES: |
| API_BASE_URL The API endpoint for the LLM. |
| MODEL_NAME The model identifier to use for inference. |
| HF_TOKEN Your Hugging Face / API key. |
| LOCAL_IMAGE_NAME Docker image name for the environment. |
| |
| STDOUT FORMAT (strictly followed): |
| [START] task=<task_name> env=<benchmark> model=<model_name> |
| [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null> |
| [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn> |
| |
| This script runs 3 tasks (easy, medium, hard). Each task is a single multi-step episode: |
| reset() → repeatedly: LLM picks (B-cell action, T-cell action) → step() → log → done |
| |
| Final score per task = average reward across steps in that episode. |
| Overall score = average across all 3 tasks. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import asyncio |
| import json |
| import os |
| import textwrap |
| from pathlib import Path |
| from typing import Any, List, Optional, Tuple |
|
|
| from dotenv import load_dotenv |
| from openai import OpenAI |
|
|
| try: |
| from tice import TICEAction, TICEEnv |
| from tice.models import B_CELL_ACTIONS, T_CELL_ACTIONS |
| except (ImportError, ModuleNotFoundError): |
| from client import TICEEnv |
| from models import B_CELL_ACTIONS, TICEAction, T_CELL_ACTIONS |
|
|
|
|
| |
# Load variables from the ".env" file located in the same directory as this script.
load_dotenv(Path(__file__).resolve().parent / ".env")

# Docker image for the environment; LOCAL_IMAGE_NAME takes precedence over IMAGE_NAME.
IMAGE_NAME = os.environ.get("LOCAL_IMAGE_NAME") or os.environ.get("IMAGE_NAME")

# First non-empty credential among HF_TOKEN, API_KEY and OPENAI_API_KEY.
API_KEY = (
    os.environ.get("HF_TOKEN")
    or os.environ.get("API_KEY")
    or os.environ.get("OPENAI_API_KEY")
)

# LLM endpoint and model, each with a default used when the variable is unset.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")

# Benchmark name reported in the [START] line.
BENCHMARK = "tice"

# Sampling settings forwarded to the chat-completion request.
TEMPERATURE = float(os.environ.get("TICE_LLM_TEMPERATURE", "0.2"))
MAX_TOKENS = int(os.environ.get("TICE_LLM_MAX_TOKENS", "500"))

# An episode counts as successful when its mean reward reaches this value.
SUCCESS_SCORE_THRESHOLD = float(os.environ.get("TICE_SUCCESS_SCORE_THRESHOLD", "0.0"))

# One entry per episode to run: (task name, archetype, difficulty).
TASKS: List[Tuple[str, str, str]] = [
    ("easy", "immune_cold", "easy"),
    ("medium", "immune_hot", "medium"),
    ("hard", "high_mutation", "hard"),
]
|
|
|
|
# System prompt sent with every chat-completion request. The valid action
# names are interpolated from B_CELL_ACTIONS / T_CELL_ACTIONS, the same lists
# the model's reply is later checked against. The doubled braces are f-string
# escapes that render as literal "{" and "}" in the JSON reply template.
SYSTEM_PROMPT = textwrap.dedent(
    f"""
    You control a tumor immune therapy simulator. On each turn you must choose exactly one
    B-cell action and one T-cell action.

    Valid B-cell actions:
    {", ".join(B_CELL_ACTIONS)}

    Valid T-cell actions:
    {", ".join(T_CELL_ACTIONS)}

    Objective:
    - Reduce and eradicate the tumor before timeout.
    - Preserve energy and avoid excessive B-cell and T-cell fatigue.
    - B cells improve detection. T cells do the damage.
    - In early phase, overcommitting T cells before reliable detection is usually wasteful.
    - If T-cell fatigue is high, recovery may be better than aggression.

    You must reply with a valid JSON object and nothing else:
    {{
    "b_cell_action": "<one valid B-cell action>",
    "t_cell_action": "<one valid T-cell action>",
    "reasoning": "<brief reasoning>"
    }}

    Rules:
    - Use only the valid action strings.
    - Base decisions only on the provided observation.
    - Keep reasoning short.
    """
).strip()
|
|
|
|
| |
| |
| |
|
|
|
|
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record that opens a task's log.

    Args:
        task: Task name to report.
        env: Benchmark / environment name to report.
        model: Model identifier to report.
    """
    header = f"[START] task={task} env={env} model={model}"
    # Flush immediately so the record is visible even if the process dies later.
    print(header, flush=True)
|
|
|
|
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` record on a single stdout line.

    Both free-text fields (``action`` and ``error``) have line breaks removed
    so a multi-line value can never split the record across lines.

    Args:
        step: Step number to report.
        action: Action description; newlines become spaces, carriage returns
            are dropped, and the result is cut to 120 characters.
        reward: Reward for the step, printed with two decimals.
        done: Whether the episode finished; printed as ``true`` / ``false``.
        error: Error message, or ``None`` / empty for no error (printed as
            ``null``).
    """

    def _one_line(text: str) -> str:
        # Same flattening for every free-text field: "\n" -> " ", drop "\r".
        return text.replace("\n", " ").replace("\r", "")

    action_clean = _one_line(action)[:120]
    # A message that is empty after flattening is reported as "null" too.
    error_val = (_one_line(error).strip() if error else "") or "null"
    done_val = str(done).lower()
    print(
        f"[STEP] step={step} action={action_clean} reward={reward:.2f} "
        f"done={done_val} error={error_val}",
        flush=True,
    )
|
|
|
|
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` record that closes a task's log.

    Args:
        success: Whether the episode met the success criterion; printed as
            ``true`` / ``false``.
        steps: Number of steps reported for the episode.
        score: Final score, printed with three decimals.
        rewards: Per-step rewards, printed comma-separated with two decimals.
    """
    formatted_rewards = [f"{value:.2f}" for value in rewards]
    rewards_str = ",".join(formatted_rewards)
    success_val = str(success).lower()
    summary = (
        f"[END] success={success_val} steps={steps} "
        f"score={score:.3f} rewards={rewards_str}"
    )
    print(summary, flush=True)
|
|
|
|
def heuristic_action(observation) -> tuple[str, str]:
    """Pick a rule-based (B-cell action, T-cell action) pair.

    The rules are checked in order and the first match wins:

    1. ``episode_phase == "early"``      -> ``INCREASE_HIGH`` / ``REST``
    2. ``detection_signal < 0.4``        -> ``INCREASE_LOW`` / ``ATTACK_LOW``
    3. ``t_cell_fatigue > 0.6``          -> ``MAINTAIN`` / ``REST``
    4. ``tumor_trend == "increasing"``   -> ``MAINTAIN`` / ``ATTACK_MEDIUM``
    5. otherwise                         -> ``MAINTAIN`` / ``ATTACK_LOW``

    Args:
        observation: Object exposing the attributes named above. Attributes
            are read lazily, only as far as the first matching rule.

    Returns:
        Tuple ``(b_cell_action, t_cell_action)`` of action strings.
    """
    if observation.episode_phase == "early":
        b_choice, t_choice = "INCREASE_HIGH", "REST"
    elif observation.detection_signal < 0.4:
        b_choice, t_choice = "INCREASE_LOW", "ATTACK_LOW"
    elif observation.t_cell_fatigue > 0.6:
        b_choice, t_choice = "MAINTAIN", "REST"
    elif observation.tumor_trend == "increasing":
        b_choice, t_choice = "MAINTAIN", "ATTACK_MEDIUM"
    else:
        b_choice, t_choice = "MAINTAIN", "ATTACK_LOW"
    return b_choice, t_choice
|
|
|
|
def build_user_prompt(observation) -> str:
    """Render the per-turn user message describing the current observation.

    The prompt is assembled from explicit lines rather than by dedenting an
    indented template after interpolation. A multi-line field value (for
    example ``feedback``) therefore cannot change the layout of the
    surrounding lines.

    Args:
        observation: Object exposing the attributes listed in ``field_names``
            below; each value is rendered with its default string formatting.

    Returns:
        The prompt text: a header, one ``- name: value`` line per field, a
        blank line, and the two closing instruction lines.
    """
    # Order matters: it is the order in which the fields appear in the prompt.
    field_names = (
        "archetype",
        "difficulty",
        "timestep",
        "episode_phase",
        "tumor_trend",
        "detection_signal",
        "t_cell_effectiveness",
        "resource_level",
        "b_cell_fatigue",
        "t_cell_fatigue",
        "recent_outcome",
        "feedback",
    )

    lines = ["Current episode context:"]
    lines.extend(f"- {name}: {getattr(observation, name)}" for name in field_names)
    lines.extend(
        (
            "",
            "Choose the next B-cell and T-cell actions.",
            "Respond with JSON only.",
        )
    )
    return "\n".join(lines)
|
|
|
|
def sanitize_json_response(raw_response: str) -> str:
    """Reduce a raw LLM reply to the JSON object it contains.

    Two clean-ups are applied in order:

    1. If the reply starts with a Markdown code fence, every fence line
       (a line whose stripped text starts with three backticks) is removed.
    2. Any text before the first ``{`` or after the last ``}`` is dropped, so
       replies such as ``Here you go: {...} Hope that helps`` still parse.

    Replies that start with ``[`` (a JSON array) and replies that contain no
    ``{...}`` pair are returned unchanged apart from the stripping above.

    Args:
        raw_response: Text returned by the model.

    Returns:
        The cleaned text, ready to hand to ``json.loads``.
    """
    cleaned = raw_response.strip()
    if cleaned.startswith("```"):
        cleaned = "\n".join(
            line for line in cleaned.splitlines() if not line.strip().startswith("```")
        ).strip()

    # Leave top-level arrays alone; slicing would silently change their type.
    if cleaned.startswith("["):
        return cleaned

    first_brace = cleaned.find("{")
    last_brace = cleaned.rfind("}")
    if first_brace != -1 and last_brace > first_brace:
        return cleaned[first_brace : last_brace + 1]
    return cleaned
|
|
|
|
def coerce_action(raw_action: Any, valid_actions: list[str], fallback: str) -> str:
    """Map a model-supplied action value onto one of the valid action strings.

    The value is trimmed, upper-cased, and has ``-`` and spaces replaced by
    ``_`` before matching. An exact match against ``valid_actions`` returns
    the normalized text; otherwise the first entry whose upper-cased form
    equals the normalized text is returned as spelled in ``valid_actions``.

    Args:
        raw_action: Value taken from the model's reply; anything that is not
            a ``str`` yields ``fallback``.
        valid_actions: Accepted action strings.
        fallback: Value returned when nothing matches.

    Returns:
        A matching action string, or ``fallback``.
    """
    if not isinstance(raw_action, str):
        return fallback

    key = raw_action.strip().upper().replace("-", "_").replace(" ", "_")
    if key in valid_actions:
        return key

    # Case-insensitive pass for lists whose entries are not upper-case.
    return next(
        (option for option in valid_actions if option.upper() == key),
        fallback,
    )
|
|
|
|
def get_llm_action(client: OpenAI, observation) -> tuple[TICEAction, str]:
    """Ask the LLM for the next action pair, falling back to the heuristic.

    The heuristic pair is computed first and serves two purposes: it replaces
    any individual action the model gets wrong, and it becomes the whole
    action if the request or the parsing of the reply raises.

    Args:
        client: OpenAI-compatible client used for the chat-completion call.
        observation: Current environment observation.

    Returns:
        ``(action, note)`` where ``note`` is the model's reasoning,
        ``"no_reasoning"`` when the reply carried none, or
        ``"fallback:<ExceptionName>"`` when the heuristic was used.
    """
    default_b, default_t = heuristic_action(observation)

    try:
        chat_messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": build_user_prompt(observation)},
        ]
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=chat_messages,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
        )
        reply_text = response.choices[0].message.content or ""
        payload = json.loads(sanitize_json_response(reply_text))

        # Each field is validated on its own; an invalid one gets its heuristic value.
        chosen_b = coerce_action(payload.get("b_cell_action"), B_CELL_ACTIONS, default_b)
        chosen_t = coerce_action(payload.get("t_cell_action"), T_CELL_ACTIONS, default_t)
        note = str(payload.get("reasoning", "")).strip() or "no_reasoning"

        return TICEAction(b_cell_action=chosen_b, t_cell_action=chosen_t), note
    except Exception as exc:
        # Best effort by design: any failure yields the heuristic action, and
        # the exception type is surfaced through the returned note.
        heuristic = TICEAction(b_cell_action=default_b, t_cell_action=default_t)
        return heuristic, f"fallback:{type(exc).__name__}"
|
|
|
|
def require_api_key() -> str:
    """Return the configured API key.

    Returns:
        The value of the module-level ``API_KEY``.

    Raises:
        RuntimeError: If ``API_KEY`` is empty or unset.
    """
    if not API_KEY:
        raise RuntimeError(
            "Missing API key. Set HF_TOKEN, API_KEY, or OPENAI_API_KEY before running inference_llm.py."
        )
    return API_KEY
|
|
def require_image_name() -> str:
    """Return the configured Docker image name.

    Returns:
        The value of the module-level ``IMAGE_NAME``.

    Raises:
        RuntimeError: If ``IMAGE_NAME`` is empty or unset.
    """
    if not IMAGE_NAME:
        raise RuntimeError(
            "Missing docker image name. Set LOCAL_IMAGE_NAME (or IMAGE_NAME) before running inference_llm.py."
        )
    return IMAGE_NAME
|
|
|
|
async def run_task(task: str, archetype: str, difficulty: str, client: OpenAI) -> float:
    """Run one episode and emit its ``[START]`` / ``[STEP]`` / ``[END]`` records.

    The environment is created inside the guarded region, so a failure to
    start it is reported like any other error: every ``[START]`` is followed
    by at least one ``[STEP]`` and exactly one ``[END]``.

    Args:
        task: Task name used in the log records.
        archetype: Archetype passed to ``env.reset``.
        difficulty: Difficulty passed to ``env.reset``.
        client: OpenAI-compatible client handed to ``get_llm_action``.

    Returns:
        Mean reward over the recorded steps, rounded to four decimals
        (``0.0`` when no step was recorded).
    """
    import sys  # function-local: only used for stderr diagnostics below

    log_start(task=task, env=BENCHMARK, model=MODEL_NAME)

    env = None
    rewards: List[float] = []
    steps = 0
    score = 0.0
    completed = False  # set only when the episode loop ends without an error

    try:
        env = await TICEEnv.from_docker_image(require_image_name())
        result = await env.reset(archetype=archetype, difficulty=difficulty)
        obs = result.observation

        while not obs.done:
            action, reasoning = get_llm_action(client, obs)
            result = await env.step(action)
            obs = result.observation

            # A missing reward is scored as 0.0.
            reward = float(result.reward if result.reward is not None else 0.0)
            done = bool(result.done)

            steps += 1
            rewards.append(reward)

            action_summary = (
                f"{action.b_cell_action}|{action.t_cell_action}|"
                f"phase={obs.episode_phase}|trend={obs.tumor_trend}|note={reasoning[:40]}"
            )
            log_step(step=steps, action=action_summary, reward=reward, done=done, error=None)

        completed = True

    except Exception as e:
        # Collapse whitespace so the message stays on one line, and fall back
        # to the exception type when the message itself is empty.
        err = " ".join((str(e) or type(e).__name__).split())[:80]
        if steps == 0:
            # Nothing was logged yet: emit a single synthetic zero-reward step
            # that carries the error.
            log_step(step=1, action="error", reward=0.0, done=True, error=err)
            rewards = [0.0]
            steps = 1
        else:
            # Steps already logged stay as they are; report the failure on
            # stderr so the stdout records are not affected.
            print(
                f"[DEBUG] task={task} aborted after step {steps}: {err}",
                file=sys.stderr,
                flush=True,
            )

    finally:
        if env is not None:
            try:
                await env.close()
            except Exception as close_err:
                # Cleanup is best effort; a failed close must not lose the result.
                print(
                    f"[DEBUG] task={task} env.close() failed: {close_err!r}",
                    file=sys.stderr,
                    flush=True,
                )

        score = round(sum(rewards) / len(rewards), 4) if rewards else 0.0
        # An episode that ended with an error is never a success.
        success = completed and score >= SUCCESS_SCORE_THRESHOLD
        log_end(success=success, steps=steps, score=score, rewards=rewards)

    return float(score)
|
|
|
|
async def main() -> None:
    """Run every configured task in order and print the overall mean score.

    Raises:
        RuntimeError: If the Docker image name or the API key is not
            configured; both are checked before any task starts.
    """
    # Fail fast on missing configuration; the returned value is not needed here.
    require_image_name()
    llm_client = OpenAI(base_url=API_BASE_URL, api_key=require_api_key())

    # Tasks run strictly one after another; each await finishes before the next starts.
    task_scores: List[float] = [
        await run_task(
            task=task_name,
            archetype=archetype,
            difficulty=difficulty,
            client=llm_client,
        )
        for task_name, archetype, difficulty in TASKS
    ]

    if task_scores:
        overall = sum(task_scores) / len(task_scores)
    else:
        overall = 0.0
    print(f"[DEBUG] overall_score={overall:.3f}", flush=True)


# Script entry point: run the async driver on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())
|
|