maxxie114 Claude Sonnet 4.6 committed on
Commit
c783c53
·
1 Parent(s): 80d8c84

Add GPU/unsloth support and local scientist runtime for HF Spaces

Browse files

- Load LoRA model from SCIENTIST_HF_MODEL env var (HF Hub ID)
- Add 'local' runtime that uses fine-tuned model for /agent-step
- Install torch (CUDA 12.1) + unsloth + peft in Dockerfile

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (2) hide show
  1. Dockerfile +19 -2
  2. server/app.py +33 -12
Dockerfile CHANGED
@@ -20,12 +20,29 @@ FROM python:3.11-slim
20
 
21
  WORKDIR /app
22
 
23
- # Install system deps
24
  RUN apt-get update && apt-get install -y --no-install-recommends \
25
  build-essential \
 
 
26
  && rm -rf /var/lib/apt/lists/*
27
 
28
- # Install Python dependencies first for better layer caching
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  COPY server/requirements.txt ./server/requirements.txt
30
  RUN pip install --no-cache-dir -r server/requirements.txt
31
 
 
20
 
21
  WORKDIR /app
22
 
23
+ # Install system deps (curl needed for HEALTHCHECK)
24
  RUN apt-get update && apt-get install -y --no-install-recommends \
25
  build-essential \
26
+ curl \
27
+ git \
28
  && rm -rf /var/lib/apt/lists/*
29
 
30
+ # Install PyTorch with CUDA 12.1 support (works on T4/A10 GPU Spaces;
31
+ # falls back to CPU silently if no GPU is present)
32
+ RUN pip install --no-cache-dir \
33
+ torch \
34
+ --index-url https://download.pytorch.org/whl/cu121
35
+
36
+ # Install unsloth + model-serving dependencies
37
+ RUN pip install --no-cache-dir \
38
+ unsloth \
39
+ transformers \
40
+ peft \
41
+ accelerate \
42
+ bitsandbytes \
43
+ huggingface_hub
44
+
45
+ # Install server dependencies
46
  COPY server/requirements.txt ./server/requirements.txt
47
  RUN pip install --no-cache-dir -r server/requirements.txt
48
 
server/app.py CHANGED
@@ -117,6 +117,9 @@ log = logging.getLogger("replicalab.server")
117
  # Scientist model — loaded once at startup from the GRPO checkpoint
118
  # ---------------------------------------------------------------------------
119
 
 
 
 
120
  _SCIENTIST_CHECKPOINT = os.environ.get(
121
  "SCIENTIST_CHECKPOINT",
122
  "/home/jovyan/replicalab-qwen3.5-grpo/checkpoint-200",
@@ -128,21 +131,32 @@ _scientist_ready = threading.Event() # set when load attempt completes
128
 
129
 
130
  def _load_scientist_model() -> None:
131
- """Load the fine-tuned Qwen LoRA adapter in a background thread."""
 
 
 
 
132
  global _scientist_model, _scientist_tokenizer
133
- checkpoint = Path(_SCIENTIST_CHECKPOINT)
134
- if not checkpoint.exists():
135
- log.warning(
136
- "Scientist checkpoint not found at %s — suggest endpoint will use deterministic baseline",
137
- checkpoint,
138
- )
139
- _scientist_ready.set()
140
- return
 
 
 
 
 
 
 
141
  try:
142
  from unsloth import FastLanguageModel # type: ignore
143
- log.info("Loading Scientist model from %s …", checkpoint)
144
  model, tokenizer = FastLanguageModel.from_pretrained(
145
- model_name=str(checkpoint),
146
  max_seq_length=2048,
147
  load_in_4bit=False,
148
  )
@@ -151,7 +165,7 @@ def _load_scientist_model() -> None:
151
  _scientist_tokenizer = tokenizer
152
  log.info("Scientist model loaded ✓")
153
  except Exception:
154
- log.exception("Failed to load Scientist model — suggest endpoint will use deterministic baseline")
155
  _scientist_ready.set()
156
 
157
 
@@ -777,6 +791,11 @@ def _resolve_scientist_action(session: dict[str, Any]) -> tuple[ScientistAction,
777
  runtime = get_scientist_runtime()
778
  if runtime == "baseline":
779
  action = build_baseline_scientist_action(observation.scientist)
 
 
 
 
 
780
  else:
781
  policy = _get_scientist_policy()
782
  action = policy(
@@ -795,6 +814,8 @@ def _resolve_scientist_action(session: dict[str, Any]) -> tuple[ScientistAction,
795
  if runtime == "anthropic"
796
  else get_scientist_ollama_model()
797
  if runtime == "ollama"
 
 
798
  else "baseline-heuristic"
799
  ),
800
  "scientist_action": action.model_dump(mode="json"),
 
117
  # Scientist model — loaded once at startup from the GRPO checkpoint
118
  # ---------------------------------------------------------------------------
119
 
120
+ # SCIENTIST_HF_MODEL: HuggingFace model ID (e.g. "openenv-community/replicalab-scientist-grpo-lora")
121
+ # Takes priority over SCIENTIST_CHECKPOINT local path.
122
+ _SCIENTIST_HF_MODEL = os.environ.get("SCIENTIST_HF_MODEL", "").strip()
123
  _SCIENTIST_CHECKPOINT = os.environ.get(
124
  "SCIENTIST_CHECKPOINT",
125
  "/home/jovyan/replicalab-qwen3.5-grpo/checkpoint-200",
 
131
 
132
 
133
  def _load_scientist_model() -> None:
134
+ """Load the fine-tuned Qwen LoRA adapter in a background thread.
135
+
136
+ Loads from SCIENTIST_HF_MODEL (HF Hub ID) if set, otherwise falls back
137
+ to the local SCIENTIST_CHECKPOINT path.
138
+ """
139
  global _scientist_model, _scientist_tokenizer
140
+
141
+ # Determine source: HF model ID takes priority over local path
142
+ if _SCIENTIST_HF_MODEL:
143
+ model_source = _SCIENTIST_HF_MODEL
144
+ else:
145
+ checkpoint = Path(_SCIENTIST_CHECKPOINT)
146
+ if not checkpoint.exists():
147
+ log.warning(
148
+ "Scientist checkpoint not found at %s — suggest endpoint will use deterministic baseline",
149
+ _SCIENTIST_CHECKPOINT,
150
+ )
151
+ _scientist_ready.set()
152
+ return
153
+ model_source = str(checkpoint)
154
+
155
  try:
156
  from unsloth import FastLanguageModel # type: ignore
157
+ log.info("Loading Scientist model from %s …", model_source)
158
  model, tokenizer = FastLanguageModel.from_pretrained(
159
+ model_name=model_source,
160
  max_seq_length=2048,
161
  load_in_4bit=False,
162
  )
 
165
  _scientist_tokenizer = tokenizer
166
  log.info("Scientist model loaded ✓")
167
  except Exception:
168
+ log.exception("Failed to load Scientist model from %s — suggest endpoint will use deterministic baseline", model_source)
169
  _scientist_ready.set()
170
 
171
 
 
791
  runtime = get_scientist_runtime()
792
  if runtime == "baseline":
793
  action = build_baseline_scientist_action(observation.scientist)
794
+ elif runtime == "local":
795
+ # Use fine-tuned LoRA model loaded at startup
796
+ _scientist_ready.wait(timeout=30)
797
+ scenario_pack = getattr(session.get("env"), "_scenario_pack", None)
798
+ action = _run_scientist_inference(observation.scientist, scenario_pack)
799
  else:
800
  policy = _get_scientist_policy()
801
  action = policy(
 
814
  if runtime == "anthropic"
815
  else get_scientist_ollama_model()
816
  if runtime == "ollama"
817
+ else (_SCIENTIST_HF_MODEL or _SCIENTIST_CHECKPOINT)
818
+ if runtime == "local"
819
  else "baseline-heuristic"
820
  ),
821
  "scientist_action": action.model_dump(mode="json"),