pbanavara committed on
Commit
9cc586f
·
verified ·
1 Parent(s): 75a4eab

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. README.md +1 -1
  2. prana_grpo_qwen35_9b.ipynb +542 -0
  3. server/requirements.txt +1 -1
  4. setup.py +11 -0
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: purple
5
  colorTo: indigo
6
  sdk: docker
7
  pinned: false
8
- app_port: 8000
9
  base_path: /web
10
  tags:
11
  - openenv
 
5
  colorTo: indigo
6
  sdk: docker
7
  pinned: false
8
+ app_port: 7860
9
  base_path: /web
10
  tags:
11
  - openenv
prana_grpo_qwen35_9b.ipynb ADDED
@@ -0,0 +1,542 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# PRANA-Env: Reinforcement Learning with Qwen3.5-9B\n",
8
+ "\n",
9
+ "Fine-tune **Qwen3.5-9B** using **GRPO** on the PRANA kidney transplant administration environment.\n",
10
+ "\n",
11
+ "The agent must:\n",
12
+ "1. Query fragmented clinical datastores\n",
13
+ "2. Detect stale lab values (90-day KARS recency window)\n",
14
+ "3. Detect anomalous measurements (>25% change within 14 days)\n",
15
+ "4. File a complete KARS-compliant SRTR report\n",
16
+ "\n",
17
+ "Reward signal comes from the deterministic KARS validator in prana_env.\n",
18
+ "\n",
19
+ "**Hardware**: H100 80GB recommended. BF16 LoRA, no 4-bit quantization.\n",
20
+ "\n",
21
+ "**Baseline**: Qwen3:8b untuned scores **0.71 Pass@1** on temporal/anomaly tasks. \n",
22
+ "**Target**: ≥ 0.90 Pass@1 after GRPO fine-tuning."
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "markdown",
27
+ "metadata": {},
28
+ "source": [
29
+ "## 1. Installation"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "%%capture\n",
39
+ "import os, importlib.util\n",
40
+ "!pip install --upgrade -qqq uv\n",
41
+ "if importlib.util.find_spec('torch') is None or 'COLAB_' in ''.join(os.environ.keys()):\n",
42
+ " try: import numpy; get_numpy = f'numpy=={numpy.__version__}'\n",
43
+ " except: get_numpy = 'numpy'\n",
44
+ " !uv pip install -qqq \\\n",
45
+ " 'torch>=2.8.0' 'triton>=3.4.0' {get_numpy} torchvision bitsandbytes 'transformers==4.56.2' \\\n",
46
+ " 'unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo' \\\n",
47
+ " 'unsloth[base] @ git+https://github.com/unslothai/unsloth'\n",
48
+ "elif importlib.util.find_spec('unsloth') is None:\n",
49
+ " !uv pip install -qqq unsloth\n",
50
+ "!uv pip install --upgrade --no-deps transformers==4.56.2 tokenizers trl==0.22.2 unsloth unsloth_zoo"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": null,
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "%%capture\n",
60
+ "# Clone prana_env and install dependencies\n",
61
+ "!git clone https://github.com/pbanavara/prana_env.git\n",
62
+ "!uv pip install -qqq fastapi uvicorn websockets pydantic openenv requests\n",
63
+ "%cd prana_env\n",
64
+ "!uv pip install -qqq -e .\n",
65
+ "\n",
66
+ "import sys, os\n",
67
+ "sys.path.insert(0, '.')\n",
68
+ "working_directory = os.path.abspath('.')"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "markdown",
73
+ "metadata": {},
74
+ "source": [
75
+ "## 2. Load Qwen3.5-9B with LoRA"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": null,
81
+ "metadata": {},
82
+ "outputs": [],
83
+ "source": [
84
+ "from unsloth import FastLanguageModel\n",
85
+ "import torch\n",
86
+ "\n",
87
+ "max_seq_length = 2048 # Multi-turn clinical dialogue needs longer context\n",
88
+ "lora_rank = 16 # Higher rank than 2048-game notebook — more complex reasoning task\n",
89
+ "\n",
90
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
91
+ " model_name = 'unsloth/Qwen3.5-9B',\n",
92
+ " load_in_4bit = False, # BF16 — QLoRA not recommended for Qwen3.5\n",
93
+ " max_seq_length = max_seq_length,\n",
94
+ ")"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": null,
100
+ "metadata": {},
101
+ "outputs": [],
102
+ "source": [
103
+ "model = FastLanguageModel.get_peft_model(\n",
104
+ " model,\n",
105
+ " r = lora_rank,\n",
106
+ " target_modules = [\n",
107
+ " 'q_proj', 'k_proj', 'v_proj', 'o_proj',\n",
108
+ " 'gate_proj', 'up_proj', 'down_proj',\n",
109
+ " ],\n",
110
+ " lora_alpha = lora_rank * 2,\n",
111
+ " use_gradient_checkpointing = 'unsloth',\n",
112
+ " random_state = 3407,\n",
113
+ ")"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "markdown",
118
+ "metadata": {},
119
+ "source": [
120
+ "## 3. Launch prana_env server\n",
121
+ "\n",
122
+ "Start the FastAPI + WebSocket server as a local subprocess — same pattern as the OpenEnv 2048 notebook."
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "code",
127
+ "execution_count": null,
128
+ "metadata": {},
129
+ "outputs": [],
130
+ "source": [
131
+ "import subprocess, time, requests\n",
132
+ "\n",
133
+ "PRANA_PORT = 8000\n",
134
+ "PRANA_BASE_URL = f'http://localhost:{PRANA_PORT}'\n",
135
+ "_server_proc = None\n",
136
+ "\n",
137
+ "def launch_prana_server():\n",
138
+ " global _server_proc\n",
139
+ " if _server_proc is not None:\n",
140
+ " try:\n",
141
+ " requests.get(f'{PRANA_BASE_URL}/health', timeout=2)\n",
142
+ " return # already running\n",
143
+ " except Exception:\n",
144
+ " _server_proc.kill()\n",
145
+ " _server_proc = None\n",
146
+ "\n",
147
+ " _server_proc = subprocess.Popen(\n",
148
+ " ['uvicorn', 'server.app:app', '--host', '0.0.0.0', '--port', str(PRANA_PORT)],\n",
149
+ " cwd=working_directory,\n",
150
+ " stdout=subprocess.DEVNULL,\n",
151
+ " stderr=subprocess.DEVNULL,\n",
152
+ " )\n",
153
+ " # Wait for server to be ready\n",
154
+ " for _ in range(20):\n",
155
+ " try:\n",
156
+ " requests.get(f'{PRANA_BASE_URL}/health', timeout=2)\n",
157
+ " print(f'prana_env server ready at {PRANA_BASE_URL}')\n",
158
+ " return\n",
159
+ " except Exception:\n",
160
+ " time.sleep(1)\n",
161
+ " raise RuntimeError('prana_env server failed to start')\n",
162
+ "\n",
163
+ "launch_prana_server()"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "markdown",
168
+ "metadata": {},
169
+ "source": [
170
+ "## 4. PRANA-Env client helpers\n",
171
+ "\n",
172
+ "Thin wrappers around the WebSocket client for use in the reward function."
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": null,
178
+ "metadata": {},
179
+ "outputs": [],
180
+ "source": [
181
+ "import random\n",
182
+ "from prana_env.client import PranaEnv\n",
183
+ "from prana_env.models import PranaAction\n",
184
+ "\n",
185
+ "PATIENTS = ['P001', 'P002', 'P003']\n",
186
+ "\n",
187
+ "def run_episode(action_sequence: list[dict]) -> tuple[float, str]:\n",
188
+ " \"\"\"\n",
189
+ " Execute a list of parsed actions against prana_env and return (reward, kars_result).\n",
190
+ " action_sequence: list of dicts with keys matching PranaAction fields.\n",
191
+ " Returns (cumulative_reward, 'PASSED'|'FAILED'|'INCOMPLETE').\n",
192
+ " \"\"\"\n",
193
+ " launch_prana_server()\n",
194
+ " patient_id = random.choice(PATIENTS)\n",
195
+ " cumulative_reward = 0.0\n",
196
+ " kars_result = 'INCOMPLETE'\n",
197
+ "\n",
198
+ " with PranaEnv(base_url=PRANA_BASE_URL) as env:\n",
199
+ " obs = env.reset(patient_id=patient_id)\n",
200
+ " for action_dict in action_sequence:\n",
201
+ " try:\n",
202
+ " action = PranaAction(**action_dict)\n",
203
+ " result = env.step(action)\n",
204
+ " cumulative_reward += result.reward\n",
205
+ " if result.done:\n",
206
+ " kars_result = result.observation.kars_result or 'FAILED'\n",
207
+ " break\n",
208
+ " except Exception:\n",
209
+ " cumulative_reward -= 1.0 # malformed action penalty\n",
210
+ " continue\n",
211
+ "\n",
212
+ " return cumulative_reward, kars_result"
213
+ ]
214
+ },
215
+ {
216
+ "cell_type": "markdown",
217
+ "metadata": {},
218
+ "source": [
219
+ "## 5. Action parser\n",
220
+ "\n",
221
+ "The model outputs a structured action sequence in its response. We parse it into PranaAction dicts."
222
+ ]
223
+ },
224
+ {
225
+ "cell_type": "code",
226
+ "execution_count": null,
227
+ "metadata": {},
228
+ "outputs": [],
229
+ "source": [
230
+ "import json, re\n",
231
+ "\n",
232
+ "def extract_actions(response: str) -> list[dict]:\n",
233
+ " \"\"\"\n",
234
+ " Extract a JSON array of actions from the model response.\n",
235
+ " The model is instructed to output actions inside ```json ... ``` blocks.\n",
236
+ " \"\"\"\n",
237
+ " # Try fenced JSON block first\n",
238
+ " match = re.search(r'```json\\s*(\\[.*?\\])\\s*```', response, re.DOTALL)\n",
239
+ " if not match:\n",
240
+ " # Fallback: any JSON array in the response\n",
241
+ " match = re.search(r'(\\[\\s*\\{.*?\\}\\s*\\])', response, re.DOTALL)\n",
242
+ " if not match:\n",
243
+ " return []\n",
244
+ " try:\n",
245
+ " return json.loads(match.group(1))\n",
246
+ " except json.JSONDecodeError:\n",
247
+ " return []"
248
+ ]
249
+ },
250
+ {
251
+ "cell_type": "markdown",
252
+ "metadata": {},
253
+ "source": [
254
+ "## 6. GRPO prompt\n",
255
+ "\n",
256
+ "The prompt instructs the model to output a full action plan for KARS filing."
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": null,
262
+ "metadata": {},
263
+ "outputs": [],
264
+ "source": [
265
+ "SYSTEM_PROMPT = \"\"\"\n",
266
+ "You are a clinical administrative agent for a kidney transplant center.\n",
267
+ "Your task is to file a KARS-compliant SRTR report for a patient.\n",
268
+ "\n",
269
+ "Today's date is 2026-03-07 (filing date T5).\n",
270
+ "The patient has a record from approximately 4 months ago (T1). Some fields may be stale.\n",
271
+ "\n",
272
+ "KARS Required Fields:\n",
273
+ "- hba1c, gfr, creatinine (PatientDB) — time-sensitive, must be within 90 days of filing\n",
274
+ "- blood_type (PatientDB) — stable, no recency requirement\n",
275
+ "\n",
276
+ "OPTN Clinical Integrity Policy:\n",
277
+ "- If two measurements of the same field within 14 days differ by >25%, do NOT file.\n",
278
+ " Communicate the anomaly and recommend a confirmatory test.\n",
279
+ "\n",
280
+ "Actions available:\n",
281
+ "- query_db: {action_type: query_db, target: PatientDB, field: <field>, patient_id: <id>}\n",
282
+ "- record_value: {action_type: record_value, field: <field>, value: <value>}\n",
283
+ "- file_report: {action_type: file_report}\n",
284
+ "\n",
285
+ "Output your complete action plan as a JSON array inside ```json ... ``` tags.\n",
286
+ "Reason step by step before outputting actions.\n",
287
+ "\"\"\".strip()\n",
288
+ "\n",
289
+ "USER_PROMPT_TEMPLATE = \"\"\"\n",
290
+ "File a KARS-compliant SRTR report for patient {patient_id}.\n",
291
+ "The T1 snapshot from ~4 months ago is pre-loaded in the record.\n",
292
+ "Check which fields are stale or anomalous, re-query only what is needed, and file.\n",
293
+ "\"\"\".strip()\n",
294
+ "\n",
295
+ "def make_prompt(patient_id: str) -> list[dict]:\n",
296
+ " return [\n",
297
+ " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
298
+ " {\"role\": \"user\", \"content\": USER_PROMPT_TEMPLATE.format(patient_id=patient_id)},\n",
299
+ " ]"
300
+ ]
301
+ },
302
+ {
303
+ "cell_type": "markdown",
304
+ "metadata": {},
305
+ "source": [
306
+ "## 7. Reward functions\n",
307
+ "\n",
308
+ "Two reward signals fed to GRPOTrainer:\n",
309
+ "1. `actions_parseable` — model output is valid JSON with recognizable actions\n",
310
+ "2. `kars_reward` — KARS validator reward from prana_env (+15 first pass, +10 after correction, -5 fail)"
311
+ ]
312
+ },
313
+ {
314
+ "cell_type": "code",
315
+ "execution_count": null,
316
+ "metadata": {},
317
+ "outputs": [],
318
+ "source": [
319
+ "def actions_parseable(completions, **kwargs):\n",
320
+ " \"\"\"Reward 1.0 if the model outputs a parseable action list, -1.0 otherwise.\"\"\"\n",
321
+ " scores = []\n",
322
+ " for completion in completions:\n",
323
+ " response = completion[0]['content']\n",
324
+ " actions = extract_actions(response)\n",
325
+ " scores.append(1.0 if len(actions) > 0 else -1.0)\n",
326
+ " return scores\n",
327
+ "\n",
328
+ "\n",
329
+ "def kars_reward(completions, prompts, **kwargs):\n",
330
+ " \"\"\"\n",
331
+ " Execute the action sequence in prana_env and return the KARS reward.\n",
332
+ " Reward scale mirrors prana_env:\n",
333
+ " +15 KARS PASSED first attempt\n",
334
+ " +10 KARS PASSED after correction\n",
335
+ " -1 re-query of already-fresh field\n",
336
+ " -5 KARS FAILED\n",
337
+ " -10 unrecoverable (3 attempts)\n",
338
+ " Normalized to [-1, 1] for GRPO stability.\n",
339
+ " \"\"\"\n",
340
+ " scores = []\n",
341
+ " for completion, prompt in zip(completions, prompts):\n",
342
+ " response = completion[0]['content']\n",
343
+ " actions = extract_actions(response)\n",
344
+ "\n",
345
+ " if not actions:\n",
346
+ " scores.append(-1.0)\n",
347
+ " continue\n",
348
+ "\n",
349
+ " # Extract patient_id from the user message\n",
350
+ " patient_id = 'P001'\n",
351
+ " for msg in prompt:\n",
352
+ " if msg['role'] == 'user':\n",
353
+ " m = re.search(r'P00\\d', msg['content'])\n",
354
+ " if m:\n",
355
+ " patient_id = m.group(0)\n",
356
+ "\n",
357
+ " # Inject patient_id into query_db actions if missing\n",
358
+ " for a in actions:\n",
359
+ " if a.get('action_type') == 'query_db' and 'patient_id' not in a:\n",
360
+ " a['patient_id'] = patient_id\n",
361
+ "\n",
362
+ " try:\n",
363
+ " raw_reward, kars_result = run_episode(actions)\n",
364
+ " # Normalize: max raw reward is +15, min is -10\n",
365
+ " normalized = max(-1.0, min(1.0, raw_reward / 15.0))\n",
366
+ " scores.append(normalized)\n",
367
+ " print(f'[KARS] patient={patient_id} result={kars_result} raw={raw_reward:.1f} normalized={normalized:.2f}')\n",
368
+ " except Exception as e:\n",
369
+ " print(f'[KARS] error: {e}')\n",
370
+ " scores.append(-1.0)\n",
371
+ "\n",
372
+ " return scores"
373
+ ]
374
+ },
375
+ {
376
+ "cell_type": "markdown",
377
+ "metadata": {},
378
+ "source": [
379
+ "## 8. Dataset\n",
380
+ "\n",
381
+ "Rotate across all 3 patients so the model generalizes rather than memorizes."
382
+ ]
383
+ },
384
+ {
385
+ "cell_type": "code",
386
+ "execution_count": null,
387
+ "metadata": {},
388
+ "outputs": [],
389
+ "source": [
390
+ "from datasets import Dataset\n",
391
+ "\n",
392
+ "# 1000 episodes cycling through all patients\n",
393
+ "records = []\n",
394
+ "for i in range(1000):\n",
395
+ " pid = PATIENTS[i % len(PATIENTS)]\n",
396
+ " records.append({\n",
397
+ " 'prompt': make_prompt(pid),\n",
398
+ " 'answer': 0,\n",
399
+ " 'enable_thinking': False, # Qwen3.5 thinking flag (vs reasoning_effort in gpt-oss)\n",
400
+ " })\n",
401
+ "\n",
402
+ "dataset = Dataset.from_list(records)\n",
403
+ "\n",
404
+ "maximum_length = len(tokenizer.apply_chat_template(\n",
405
+ " make_prompt('P001'),\n",
406
+ " add_generation_prompt=True,\n",
407
+ "))\n",
408
+ "print(f'Prompt token length: {maximum_length}')\n",
409
+ "dataset[0]"
410
+ ]
411
+ },
412
+ {
413
+ "cell_type": "markdown",
414
+ "metadata": {},
415
+ "source": [
416
+ "## 9. GRPO Training"
417
+ ]
418
+ },
419
+ {
420
+ "cell_type": "code",
421
+ "execution_count": null,
422
+ "metadata": {},
423
+ "outputs": [],
424
+ "source": [
425
+ "max_prompt_length = maximum_length + 1\n",
426
+ "max_completion_length = max_seq_length - max_prompt_length\n",
427
+ "\n",
428
+ "from trl import GRPOConfig, GRPOTrainer\n",
429
+ "\n",
430
+ "training_args = GRPOConfig(\n",
431
+ " temperature = 1.0,\n",
432
+ " learning_rate = 5e-5,\n",
433
+ " weight_decay = 0.001,\n",
434
+ " warmup_ratio = 0.1,\n",
435
+ " lr_scheduler_type = 'linear',\n",
436
+ " optim = 'adamw_8bit',\n",
437
+ " logging_steps = 1,\n",
438
+ " per_device_train_batch_size = 1,\n",
439
+ " gradient_accumulation_steps = 4,\n",
440
+ " num_generations = 8, # Full GRPO batch — H100 80GB can handle this at 9B BF16\n",
441
+ " max_prompt_length = max_prompt_length,\n",
442
+ " max_completion_length = max_completion_length,\n",
443
+ " max_steps = 600,\n",
444
+ " save_steps = 100,\n",
445
+ " report_to = 'none',\n",
446
+ " output_dir = 'outputs',\n",
447
+ ")\n",
448
+ "\n",
449
+ "trainer = GRPOTrainer(\n",
450
+ " model = model,\n",
451
+ " processing_class = tokenizer,\n",
452
+ " reward_funcs = [\n",
453
+ " actions_parseable,\n",
454
+ " kars_reward,\n",
455
+ " ],\n",
456
+ " args = training_args,\n",
457
+ " train_dataset = dataset,\n",
458
+ ")"
459
+ ]
460
+ },
461
+ {
462
+ "cell_type": "code",
463
+ "execution_count": null,
464
+ "metadata": {},
465
+ "outputs": [],
466
+ "source": [
467
+ "trainer.train()"
468
+ ]
469
+ },
470
+ {
471
+ "cell_type": "markdown",
472
+ "metadata": {},
473
+ "source": [
474
+ "## 10. Inference — test the fine-tuned model"
475
+ ]
476
+ },
477
+ {
478
+ "cell_type": "code",
479
+ "execution_count": null,
480
+ "metadata": {},
481
+ "outputs": [],
482
+ "source": [
483
+ "from transformers import TextStreamer\n",
484
+ "\n",
485
+ "test_patient = 'P002' # has anomalous GFR/creatinine — hardest case\n",
486
+ "text = tokenizer.apply_chat_template(\n",
487
+ " make_prompt(test_patient),\n",
488
+ " tokenize=False,\n",
489
+ " add_generation_prompt=True,\n",
490
+ " enable_thinking=False,\n",
491
+ ")\n",
492
+ "\n",
493
+ "_ = model.generate(\n",
494
+ " **tokenizer(text, return_tensors='pt').to('cuda'),\n",
495
+ " temperature=1.0,\n",
496
+ " max_new_tokens=1024,\n",
497
+ " streamer=TextStreamer(tokenizer, skip_prompt=False),\n",
498
+ ")"
499
+ ]
500
+ },
501
+ {
502
+ "cell_type": "markdown",
503
+ "metadata": {},
504
+ "source": [
505
+ "## 11. Save model"
506
+ ]
507
+ },
508
+ {
509
+ "cell_type": "code",
510
+ "execution_count": null,
511
+ "metadata": {},
512
+ "outputs": [],
513
+ "source": [
514
+ "# Save LoRA adapters\n",
515
+ "model.save_pretrained('prana_qwen35_9b_lora')\n",
516
+ "tokenizer.save_pretrained('prana_qwen35_9b_lora')\n",
517
+ "\n",
518
+ "# Push to Hub (optional)\n",
519
+ "if False:\n",
520
+ " model.push_to_hub_merged(\n",
521
+ " 'pbanavara/prana-qwen35-9b-grpo',\n",
522
+ " tokenizer,\n",
523
+ " save_method='merged_16bit',\n",
524
+ " token='hf_...',\n",
525
+ " )"
526
+ ]
527
+ }
528
+ ],
529
+ "metadata": {
530
+ "kernelspec": {
531
+ "display_name": "Python 3",
532
+ "language": "python",
533
+ "name": "python3"
534
+ },
535
+ "language_info": {
536
+ "name": "python",
537
+ "version": "3.12.0"
538
+ }
539
+ },
540
+ "nbformat": 4,
541
+ "nbformat_minor": 4
542
+ }
server/requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- openenv[core]>=0.2.0
2
  fastapi>=0.115.0
3
  uvicorn>=0.24.0
4
 
 
1
+ openenv-core[core]>=0.2.0
2
  fastapi>=0.115.0
3
  uvicorn>=0.24.0
4
 
setup.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import setup
2
+ from setuptools.command.editable_wheel import editable_wheel
3
+
4
+
5
+ class CompatEditableWheel(editable_wheel):
6
+ def run(self):
7
+ self.mode = "compat"
8
+ super().run()
9
+
10
+
11
+ setup(cmdclass={"editable_wheel": CompatEditableWheel})