File size: 5,636 Bytes
db03c40
 
 
 
ad39f2a
db03c40
 
 
 
 
ad39f2a
db03c40
 
 
ad39f2a
 
db03c40
ad39f2a
db03c40
 
ad39f2a
db03c40
 
 
ad39f2a
 
db03c40
ad39f2a
db03c40
 
 
 
 
 
 
 
 
 
 
 
 
ad39f2a
db03c40
 
 
ad39f2a
 
db03c40
ad39f2a
db03c40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad39f2a
db03c40
 
 
ad39f2a
 
db03c40
ad39f2a
db03c40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad39f2a
db03c40
 
 
ad39f2a
 
db03c40
ad39f2a
db03c40
 
 
 
 
 
 
 
 
 
 
 
 
 
ad39f2a
db03c40
 
 
ad39f2a
 
db03c40
ad39f2a
db03c40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad39f2a
db03c40
 
 
 
 
 
 
 
 
 
 
 
 
 
ad39f2a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "a9d34036",
      "metadata": {},
      "source": [
        "# Self-Driving Lab Inference on H100 With Unsloth\n",
        "\n",
        "This notebook loads a quantized Unsloth model, builds the same self-driving lab observation prompt used during training, generates the next structured lab action, and steps the simulator in a short closed-loop rollout similar to `run_agent.py`, but with faster 4-bit inference on H100."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "20b36e01",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Environment setup (own top cell so a fresh kernel installs before imports).\n",
        "# NOTE(review): packages are unpinned; pin exact versions (e.g. torch==2.x,\n",
        "# transformers==4.x) for reproducible environments.\n",
        "%pip install -q -U torch transformers unsloth"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "bcf24a2e",
      "metadata": {},
      "outputs": [],
      "source": [
        "# All imports in one cell, plus a quick GPU sanity check before loading the model.\n",
        "import json\n",
        "\n",
        "import torch\n",
        "\n",
        "# Project-local helpers: observation -> prompt formatting, model loading /\n",
        "# generation, and the simulated lab environment.\n",
        "from training_script import format_observation\n",
        "from training_unsloth import generate_action_with_model, load_model_artifacts\n",
        "from server.hackathon_environment import BioExperimentEnvironment\n",
        "\n",
        "print(\"CUDA available:\", torch.cuda.is_available())\n",
        "if torch.cuda.is_available():\n",
        "    print(\"GPU:\", torch.cuda.get_device_name(0))\n",
        "    print(\"bf16 supported:\", torch.cuda.is_bf16_supported())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "c54f2cfd",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Configuration: checkpoint path, scenario, and seed for a reproducible reset.\n",
        "MODEL_PATH = \"artifacts/grpo-unsloth-output\"  # or a Hugging Face repo / base model id\n",
        "SCENARIO_NAME = \"cardiac_disease_de\"\n",
        "SEED = 42\n",
        "\n",
        "# Load tokenizer + model in 4-bit with inference preparation enabled.\n",
        "# NOTE(review): assumes load_model_artifacts forwards these kwargs to the\n",
        "# underlying Unsloth loader - confirm against training_unsloth.\n",
        "tokenizer, model = load_model_artifacts(\n",
        "    MODEL_PATH,\n",
        "    trust_remote_code=True,\n",
        "    max_seq_length=2048,\n",
        "    load_in_4bit=True,\n",
        "    prepare_for_inference=True,\n",
        ")\n",
        "\n",
        "# Fixed scenario without domain randomisation so repeated runs are comparable.\n",
        "env = BioExperimentEnvironment(scenario_name=SCENARIO_NAME, domain_randomise=False)\n",
        "obs = env.reset(seed=SEED)\n",
        "# Preview the observation prompt, truncated to the first 3000 characters.\n",
        "print(format_observation(obs)[:3000])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "f9b25208",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Generate a single structured action from the current observation.\n",
        "# Low temperature keeps the sampled output close to greedy decoding.\n",
        "result = generate_action_with_model(\n",
        "    model,\n",
        "    tokenizer,\n",
        "    obs,\n",
        "    max_new_tokens=160,\n",
        "    temperature=0.2,\n",
        "    top_p=0.9,\n",
        "    do_sample=True,\n",
        ")\n",
        "\n",
        "print(\"Model response:\\n\")\n",
        "print(result[\"response_text\"])\n",
        "print(\"\\nParsed action:\\n\")\n",
        "# Make parse failure visible: the previous trailing conditional expression\n",
        "# evaluated to None on failure, so the header was followed by no output at\n",
        "# all, which was easy to misread as success.\n",
        "if result[\"action\"] is not None:\n",
        "    display(result[\"action\"].model_dump())\n",
        "else:\n",
        "    print(\"<no valid action parsed>\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "c2408f52",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Step the environment with the parsed action and inspect the outcome.\n",
        "# Depends on `result` and `env` from the cells above - run in order.\n",
        "if result[\"action\"] is not None:\n",
        "    next_obs = env.step(result[\"action\"])\n",
        "    print(\"Reward:\", next_obs.reward)\n",
        "    print(\"Done:\", next_obs.done)\n",
        "    print(\"Violations:\", next_obs.rule_violations)\n",
        "    # Show only the first few entries to keep the cell output small.\n",
        "    print(\"Markers:\", next_obs.discovered_markers[:5])\n",
        "    print(\"Mechanisms:\", next_obs.candidate_mechanisms[:5])\n",
        "    # latest_output is presumably None when the action produced no data -\n",
        "    # guard before dereferencing. TODO confirm against the environment.\n",
        "    if next_obs.latest_output is not None:\n",
        "        print(\"Summary:\", next_obs.latest_output.summary)\n",
        "        print(\"Latest data preview:\")\n",
        "        print(json.dumps(next_obs.latest_output.data, indent=2)[:1200])\n",
        "else:\n",
        "    print(\"Model output did not parse into an ExperimentAction.\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8af34f32",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Optional short closed-loop rollout: generate -> step until done or failure.\n",
        "MAX_STEPS = 8  # keep the rollout short; each step costs one model generation\n",
        "\n",
        "obs = env.reset(seed=7)\n",
        "trajectory = []\n",
        "\n",
        "for step_idx in range(MAX_STEPS):\n",
        "    # Use the same sampling settings as the single-step generation cell so\n",
        "    # the rollout behaves consistently (previously this call omitted them\n",
        "    # and silently fell back to library defaults).\n",
        "    result = generate_action_with_model(\n",
        "        model,\n",
        "        tokenizer,\n",
        "        obs,\n",
        "        max_new_tokens=160,\n",
        "        temperature=0.2,\n",
        "        top_p=0.9,\n",
        "        do_sample=True,\n",
        "    )\n",
        "    action = result[\"action\"]\n",
        "    record = {\n",
        "        \"step\": step_idx + 1,\n",
        "        \"response_text\": result[\"response_text\"],\n",
        "        \"action\": action.model_dump() if action is not None else None,\n",
        "    }\n",
        "    trajectory.append(record)\n",
        "    if action is None:\n",
        "        # Unparseable output: stop early rather than stepping with no action.\n",
        "        break\n",
        "\n",
        "    next_obs = env.step(action)\n",
        "    record.update({\n",
        "        \"reward\": next_obs.reward,\n",
        "        \"done\": next_obs.done,\n",
        "        \"violations\": list(next_obs.rule_violations),\n",
        "        \"latest_summary\": next_obs.latest_output.summary if next_obs.latest_output is not None else None,\n",
        "        \"discovered_markers\": list(next_obs.discovered_markers[:5]),\n",
        "        \"candidate_mechanisms\": list(next_obs.candidate_mechanisms[:5]),\n",
        "    })\n",
        "    obs = next_obs\n",
        "    if obs.done:\n",
        "        break\n",
        "\n",
        "# Bare final expression: rich display of the collected trajectory records.\n",
        "trajectory"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}