| """ |
| Shared vLLM client and tier-aware model calling utilities. |
| |
| All LLM inference across OncoAgent flows through this module, |
| ensuring consistent model selection, error handling, and |
| environment variable management. |
| |
| Design inspired by: |
| - Hermes Agent: structured tool calling with JSON output |
| - Claude Code: deterministic harness separating LLM from execution |
| |
| Production target: AMD Instinct MI300X via ROCm 7.2 + vLLM |
| Development fallback: Featherless.ai OpenAI-compatible API |
| """ |
|
|
import os
import re
import logging
from typing import Optional, Dict, Any
from dataclasses import dataclass

from openai import OpenAI
from dotenv import load_dotenv

load_dotenv()
|
|
logger = logging.getLogger(__name__)
|
|
@dataclass(frozen=True)
class TierSpec:
    """Immutable specification for a model tier."""

    tier_id: int
    name: str
    model_id: str
    description: str
    max_tokens: int
    temperature: float

    def __str__(self) -> str:
        return f"Tier {self.tier_id}: {self.name} ({self.model_id})"
|
|
|
|
TIER_SPECS: Dict[int, TierSpec] = {
    1: TierSpec(
        tier_id=1,
        name="Speed Triage",
        model_id=os.getenv("TIER1_MODEL_ID", "Qwen/Qwen3.5-9B"),
        description="Fast model for initial triage and low-complexity cases.",
        max_tokens=2048,
        temperature=0.1,
    ),
    2: TierSpec(
        tier_id=2,
        name="Deep Reasoning",
        model_id=os.getenv("TIER2_MODEL_ID", "Qwen/Qwen3.6-27B"),
        description="High-reasoning model for complex oncology cases and validation.",
        max_tokens=4096,
        temperature=0.0,
    ),
}
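
# Example (illustrative): deployments can point a tier at a different
# checkpoint without code changes by setting the environment variables read
# above, e.g. in the project's .env file:
#
#     TIER1_MODEL_ID=Qwen/Qwen3.5-9B
#     TIER2_MODEL_ID=Qwen/Qwen3.6-27B
#
# With the defaults, str(TIER_SPECS[1]) renders as
# "Tier 1: Speed Triage (Qwen/Qwen3.5-9B)".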
|
|
|
|
|
|
_THINK_PATTERN = re.compile(r"<think>.*?</think>", re.DOTALL)
|
|
|
|
def _strip_thinking_tokens(text: str) -> str:
    """Remove Qwen3 <think>...</think> blocks from model output.

    Qwen3 models use an internal reasoning mode that wraps chain-of-thought
    in <think> tags. We preserve only the final answer for the pipeline.

    Args:
        text: Raw model output potentially containing <think> blocks.

    Returns:
        Cleaned text with thinking blocks removed.
    """
    cleaned = _THINK_PATTERN.sub("", text).strip()
    return cleaned if cleaned else text.strip()
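
# Example (illustrative): for a raw completion such as
#     "<think>Weigh stage, biomarkers, comorbidities...</think>Refer to MDT."
# _strip_thinking_tokens() returns "Refer to MDT.". If the output consists
# only of a <think> block, the original text is returned unchanged so the
# caller can still log and inspect it.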
|
|
|
|
|
|
_vllm_client: Optional[OpenAI] = None
|
|
|
|
def get_vllm_client() -> OpenAI:
    """Return a cached OpenAI-compatible client pointing at vLLM.

    Reads ``VLLM_API_BASE`` and ``VLLM_API_KEY`` from environment.

    Returns:
        OpenAI client configured for the local vLLM server.
    """
    global _vllm_client
    if _vllm_client is None:
        api_base = os.getenv("VLLM_API_BASE", "http://localhost:8000/v1")
        api_key = os.getenv("VLLM_API_KEY", "EMPTY")
        _vllm_client = OpenAI(base_url=api_base, api_key=api_key)
        logger.info("vLLM client initialised → %s", api_base)
    return _vllm_client
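
# Example (illustrative): typical .env settings for the two deployment modes
# described in the module docstring. The Featherless URL is an assumption,
# not a value shipped with OncoAgent.
#
#     # Production: local vLLM server on the MI300X droplet
#     VLLM_API_BASE=http://localhost:8000/v1
#     VLLM_API_KEY=EMPTY
#
#     # Development: Featherless.ai OpenAI-compatible endpoint
#     VLLM_API_BASE=https://api.featherless.ai/v1
#     VLLM_API_KEY=<your Featherless API key>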
|
|
|
|
|
|
def _resolve_model_id(spec: TierSpec) -> str:
    """Resolve the actual model ID to use for API calls.

    In production (local vLLM), we use the exact model ID.
    In development (Featherless.ai), some models may not be available,
    so we check for configured fallbacks.

    Args:
        spec: The TierSpec for the requested tier.

    Returns:
        The model ID string to pass to the API.
    """
    api_base = os.getenv("VLLM_API_BASE", "http://localhost:8000/v1")
    is_featherless = "featherless" in api_base.lower()

    if is_featherless and spec.tier_id == 2:
        # The Tier 2 model may not be hosted on Featherless; use the configured fallback.
        fallback = os.getenv("TIER2_FEATHERLESS_FALLBACK", "Qwen/Qwen3.5-27B")
        logger.info(
            "Featherless detected: Tier 2 fallback %s → %s",
            spec.model_id, fallback,
        )
        return fallback

    return spec.model_id
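
# Example (illustrative): with VLLM_API_BASE pointing at Featherless and no
# TIER2_FEATHERLESS_FALLBACK set, a Tier 2 request is rewritten from
# "Qwen/Qwen3.6-27B" to the default fallback "Qwen/Qwen3.5-27B"; Tier 1 and
# all local-vLLM requests keep their configured model IDs.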
|
|
|
|
|
|
class LocalModelManager:
    """Singleton to manage local LoRA model loading and inference.

    Only used on the AMD droplet with working ROCm/GPU drivers.
    In development without GPU, this is skipped entirely.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.model = None
            cls._instance.tokenizer = None
            cls._instance.initialized = False
        return cls._instance
|
|
    def initialize(self) -> None:
        """Load the base model and LoRA adapters."""
        if self.initialized:
            return

        try:
            from transformers import AutoModelForCausalLM, AutoTokenizer
            from peft import PeftModel
            import torch
        except ImportError:
            logger.warning("Transformers/PEFT/Torch not installed. Local inference disabled.")
            return

        adapter_path = os.getenv("LOCAL_ADAPTER_PATH")
        base_model_id = os.getenv("BASE_MODEL_ID", "Qwen/Qwen3.5-9B")

        if not adapter_path or not os.path.exists(adapter_path):
            logger.error("Local adapter path not found: %s", adapter_path)
            return

        logger.info("Loading base model %s + adapters %s...", base_model_id, adapter_path)
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                base_model_id, trust_remote_code=True,
            )
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_id,
                dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
            )
            self.model = PeftModel.from_pretrained(base_model, adapter_path)
            self.model.eval()
            self.initialized = True
            logger.info("Local BF16 model ready on %s", os.getenv("DEVICE", "cuda"))
        except Exception as exc:
            logger.error("Failed to load local model: %s", exc)
|
|
    def generate(
        self,
        system_prompt: str,
        user_prompt: str,
        max_tokens: int,
        temperature: float,
    ) -> str:
        """Run inference using the loaded local model."""
        if not self.initialized:
            self.initialize()
        if not self.initialized:
            raise RuntimeError("Local model manager not initialized.")

        import torch

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        prompt_str = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True,
        )
        inputs = self.tokenizer(text=prompt_str, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=temperature > 0,
                use_cache=True,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        generated_ids = outputs[:, inputs.input_ids.shape[1]:]
        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return _strip_thinking_tokens(response)
|
|
|
|
_local_manager = LocalModelManager()
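
# Example (illustrative): direct use of the local manager. In the pipeline it
# is normally reached only through call_tier_model() when
# USE_LOCAL_ADAPTERS=true on the GPU droplet; the prompts below are made up.
#
#     _local_manager.initialize()
#     if _local_manager.initialized:
#         draft = _local_manager.generate(
#             system_prompt="You are an oncology triage assistant.",
#             user_prompt="Classify the urgency of this referral letter: ...",
#             max_tokens=512,
#             temperature=0.1,
#         )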
|
|
|
|
|
|
def call_tier_model(
    tier: int,
    system_prompt: str,
    user_prompt: str,
    max_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
    json_mode: bool = False,
) -> str:
    """Call the appropriate model based on the selected tier.

    This is the *single entry point* for all LLM inference in OncoAgent.
    Every node must call this function instead of instantiating clients.

    Flow:
        1. If USE_LOCAL_ADAPTERS=true AND tier=1 → try local PEFT inference
        2. If local fails or not enabled → route through vLLM/Featherless API

    Args:
        tier: Model tier (1 = fast 9B, 2 = deep 27B).
        system_prompt: System-level instructions.
        user_prompt: User-level content / query.
        max_tokens: Override the tier's default max_tokens.
        temperature: Override the tier's default temperature.
        json_mode: If True, request JSON response format.

    Returns:
        The model's text response (stripped of thinking tokens).

    Raises:
        ValueError: If the tier is not 1 or 2.
        RuntimeError: If the API call fails (e.g., the vLLM server is unreachable).
    """
    spec = TIER_SPECS.get(tier)
    if spec is None:
        raise ValueError(f"Invalid tier {tier}. Must be 1 or 2.")

    effective_max_tokens = max_tokens if max_tokens is not None else spec.max_tokens
    effective_temperature = temperature if temperature is not None else spec.temperature

    logger.info(
        "Calling %s (max_tokens=%d, temp=%.2f, json=%s)",
        spec, effective_max_tokens, effective_temperature, json_mode,
    )

    # 1. Optional fast path: local LoRA adapters for Tier 1.
    use_local = os.getenv("USE_LOCAL_ADAPTERS", "false").lower() == "true"
    if tier == 1 and use_local:
        try:
            logger.info("Routing Tier 1 to local LoRA adapters...")
            return _local_manager.generate(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                max_tokens=effective_max_tokens,
                temperature=effective_temperature,
            )
        except Exception as local_exc:
            logger.warning("Local inference failed, falling back to API: %s", local_exc)

    # 2. Default path: route through the vLLM / Featherless API.
    model_id = _resolve_model_id(spec)

    try:
        client = get_vllm_client()

        kwargs: Dict[str, Any] = {
            "model": model_id,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "temperature": effective_temperature,
            "max_tokens": effective_max_tokens,
        }

        if json_mode:
            kwargs["response_format"] = {"type": "json_object"}

        response = client.chat.completions.create(**kwargs)
        raw_text = response.choices[0].message.content or ""
        text = _strip_thinking_tokens(raw_text)

        if not text:
            logger.warning(
                "Model returned empty response (raw_len=%d). "
                "May be all <think> tokens. Returning raw.",
                len(raw_text),
            )
            text = raw_text.strip() if raw_text else ""

        logger.debug("Response length: %d chars", len(text))
        return text

    except Exception as exc:
        logger.error("vLLM call failed for %s: %s", spec, exc)
        raise RuntimeError(
            f"Error connecting to vLLM ({model_id}): {exc}"
        ) from exc
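
# Example (illustrative): how pipeline nodes are expected to call the entry
# point. The prompts and the downstream parsing are assumptions made for this
# sketch, not code that exists elsewhere in the repository.
#
#     import json
#
#     triage = call_tier_model(
#         tier=1,
#         system_prompt="You are an oncology triage assistant.",
#         user_prompt="Assign an urgency level to the following referral: ...",
#     )
#
#     review = call_tier_model(
#         tier=2,
#         system_prompt="Return a JSON object with keys 'valid' and 'reason'.",
#         user_prompt=f"Review this triage decision: {triage}",
#         json_mode=True,
#     )
#     verdict = json.loads(review)  # json_mode requests, but does not guarantee, JSON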
|
|
|
|
def get_tier_spec(tier: int) -> TierSpec:
    """Retrieve the TierSpec for the given tier number.

    Args:
        tier: 1 or 2.

    Returns:
        The corresponding TierSpec.

    Raises:
        ValueError: If the tier is not configured in TIER_SPECS.
    """
    spec = TIER_SPECS.get(tier)
    if spec is None:
        raise ValueError(f"Invalid tier {tier}. Available: {list(TIER_SPECS.keys())}")
    return spec
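
# Illustrative smoke test (assumption: a vLLM or Featherless endpoint is
# reachable with the settings above). Not part of the OncoAgent pipeline;
# provided only as a quick manual check of the tier plumbing.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    for tier_id in TIER_SPECS:
        print(get_tier_spec(tier_id))
    print(
        call_tier_model(
            tier=1,
            system_prompt="You are a concise assistant.",
            user_prompt="Reply with the single word OK.",
        )
    )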
|
|