File size: 3,513 Bytes
c0fff99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67651ce
 
 
 
 
 
 
c0fff99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""
Configuration for Diagnostic Devil's Advocate.
Controls model loading, quantization, and environment-specific settings.

Model loading priority:
  1. Local path (MODEL_LOCAL_DIR env var) — fully offline
  2. HF cache (auto-downloaded via huggingface-cli download) — offline after first download
  3. HF Hub (requires HF_TOKEN for gated models) — online fallback
"""

import os
from huggingface_hub import try_to_load_from_cache

def _env_bool(name: str, default: str) -> bool:
    """Read a boolean feature flag from the environment.

    Accepts "true"/"false" (case-insensitive); anything else is False.
    `default` is the string used when the variable is unset.
    """
    return os.environ.get(name, default).lower() == "true"

# --- Model Selection ---
USE_27B = _env_bool("USE_27B", "false")
QUANTIZE_4B = _env_bool("QUANTIZE_4B", "true")
ENABLE_MEDASR = _env_bool("ENABLE_MEDASR", "true")

# --- Performance Optimization ---
# torch.compile: JIT compilation — first inference is slow (compiling), later
# calls are 30-80% faster. Off by default: ZeroGPU cold starts would recompile
# on every launch, so the compile cost is never amortized.
ENABLE_TORCH_COMPILE = _env_bool("ENABLE_TORCH_COMPILE", "false")
# SDPA: optimized attention kernel — saves VRAM and speeds up inference with
# no compilation overhead.
ENABLE_SDPA = _env_bool("ENABLE_SDPA", "true")

# --- Prompt Repetition (arXiv:2512.14982) ---
# Repeating the user prompt improves non-reasoning LLM performance (47 wins, 0 losses
# across 70 benchmark-model combos). Only increases prefill tokens, no extra generation.
ENABLE_PROMPT_REPETITION = _env_bool("ENABLE_PROMPT_REPETITION", "true")

# --- HF Token (for gated models) ---
# Loaded from: env var > huggingface-cli login stored token (auto)
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# --- Model IDs (HF Hub) ---
_MEDGEMMA_4B_HUB_ID = "google/medgemma-1.5-4b-it"
_MEDGEMMA_27B_HUB_ID = "google/medgemma-27b-text-it"
_MEDSIGLIP_HUB_ID = "google/medsiglip-448"
_MEDASR_HUB_ID = "google/medasr"

# --- Optional local model directories (override HF Hub) ---
# Set these env vars to point to a local directory containing model weights.
# If not set, models load from HF cache (downloaded via `huggingface-cli download`).
MODEL_LOCAL_DIR = os.environ.get("MODEL_LOCAL_DIR", None)

def _resolve_model_path(hub_id: str, local_subdir: str | None = None) -> str:
    """Resolve a model reference to something the loaders can consume.

    Priority: explicit local directory > HF cache snapshot > plain Hub ID
    (the last of which triggers a download on first use).
    """
    # 1. Explicit local directory, when configured and actually on disk.
    if MODEL_LOCAL_DIR:
        subdir = local_subdir if local_subdir else hub_id.split("/")[-1]
        candidate = os.path.join(MODEL_LOCAL_DIR, subdir)
        if os.path.isdir(candidate):
            return candidate
    # 2. HF cache hit (weights previously fetched via `huggingface-cli download`).
    try:
        hit = try_to_load_from_cache(hub_id, "config.json")
    except Exception:
        hit = None
    # try_to_load_from_cache may return a non-str sentinel for "known missing";
    # only a real filepath string counts as a cache hit.
    if isinstance(hit, str):
        # config.json lives inside the snapshot dir — return its parent.
        return os.path.dirname(hit)
    # 3. Fall back to the bare Hub ID.
    return hub_id

# Resolved model locations: local dir > HF cache snapshot > Hub ID
# (see _resolve_model_path above).
MEDGEMMA_4B_MODEL_ID = _resolve_model_path(_MEDGEMMA_4B_HUB_ID, "medgemma-4b")
MEDGEMMA_27B_MODEL_ID = _resolve_model_path(_MEDGEMMA_27B_HUB_ID, "medgemma-27b")
MEDSIGLIP_MODEL_ID = _resolve_model_path(_MEDSIGLIP_HUB_ID, "medsiglip-448")
MEDASR_MODEL_ID = _resolve_model_path(_MEDASR_HUB_ID, "medasr")

# --- Generation Parameters ---
MAX_NEW_TOKENS_4B = 4096  # generation cap for the 4B model
MAX_NEW_TOKENS_27B = 6000  # generation cap for the 27B model
TEMPERATURE = 0.0  # 0.0 => greedy decoding
REPETITION_PENALTY = 1.2  # Prevent greedy decoding repetition loops

# --- Device ---
# Target device string; defaults to CUDA, override via DEVICE env var (e.g. "cpu").
DEVICE = os.environ.get("DEVICE", "cuda")

# --- Demo Cases Directory ---
# Paths are anchored to this file's directory so they work regardless of CWD.
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
DEMO_CASES_DIR = os.path.join(DATA_DIR, "demo_cases")