{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a9d34036",
   "metadata": {},
   "source": [
    "# Self-Driving Lab Inference on H100 With Unsloth\n",
    "\n",
    "This notebook loads a quantized Unsloth model, builds the same self-driving lab observation prompt used during training, generates the next structured lab action, and steps the simulator in a short closed-loop rollout similar to `run_agent.py`, but with faster 4-bit inference on H100."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20b36e01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): pin package versions (e.g. `unsloth==x.y.z`) so that\n",
    "# Restart Kernel -> Run All remains reproducible across environments.\n",
    "%pip install -q -U torch transformers unsloth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcf24a2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "import torch\n",
    "\n",
    "from training_script import format_observation\n",
    "from training_unsloth import generate_action_with_model, load_model_artifacts\n",
    "from server.hackathon_environment import BioExperimentEnvironment\n",
    "\n",
    "print(\"CUDA available:\", torch.cuda.is_available())\n",
    "if torch.cuda.is_available():\n",
    "    print(\"GPU:\", torch.cuda.get_device_name(0))\n",
    "    print(\"bf16 supported:\", torch.cuda.is_bf16_supported())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c54f2cfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_PATH = \"artifacts/grpo-unsloth-output\"  # or a Hugging Face repo / base model id\n",
    "SCENARIO_NAME = \"cardiac_disease_de\"\n",
    "SEED = 42\n",
    "\n",
    "tokenizer, model = load_model_artifacts(\n",
    "    MODEL_PATH,\n",
    "    trust_remote_code=True,\n",
    "    max_seq_length=2048,\n",
    "    load_in_4bit=True,\n",
    "    prepare_for_inference=True,\n",
    ")\n",
    "\n",
    "env = BioExperimentEnvironment(scenario_name=SCENARIO_NAME, domain_randomise=False)\n",
    "obs = env.reset(seed=SEED)\n",
    "print(format_observation(obs)[:3000])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9b25208",
   "metadata": {},
   "outputs": [],
   "source": [
    "result = generate_action_with_model(\n",
    "    model,\n",
    "    tokenizer,\n",
    "    obs,\n",
    "    max_new_tokens=160,\n",
    "    temperature=0.2,\n",
    "    top_p=0.9,\n",
    "    do_sample=True,\n",
    ")\n",
    "\n",
    "print(\"Model response:\\n\")\n",
    "print(result[\"response_text\"])\n",
    "print(\"\\nParsed action:\\n\")\n",
    "result[\"action\"].model_dump() if result[\"action\"] is not None else None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2408f52",
   "metadata": {},
   "outputs": [],
   "source": [
    "if result[\"action\"] is not None:\n",
    "    next_obs = env.step(result[\"action\"])\n",
    "    print(\"Reward:\", next_obs.reward)\n",
    "    print(\"Done:\", next_obs.done)\n",
    "    print(\"Violations:\", next_obs.rule_violations)\n",
    "    print(\"Markers:\", next_obs.discovered_markers[:5])\n",
    "    print(\"Mechanisms:\", next_obs.candidate_mechanisms[:5])\n",
    "    if next_obs.latest_output is not None:\n",
    "        print(\"Summary:\", next_obs.latest_output.summary)\n",
    "        print(\"Latest data preview:\")\n",
    "        print(json.dumps(next_obs.latest_output.data, indent=2)[:1200])\n",
    "else:\n",
    "    print(\"Model output did not parse into an ExperimentAction.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8af34f32",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional short closed-loop rollout.\n",
    "ROLLOUT_SEED = 7\n",
    "MAX_ROLLOUT_STEPS = 8\n",
    "\n",
    "obs = env.reset(seed=ROLLOUT_SEED)\n",
    "trajectory = []\n",
    "\n",
    "for step_idx in range(MAX_ROLLOUT_STEPS):\n",
    "    # Use the same sampling settings as the single-step cell above so the\n",
    "    # rollout and the one-shot demo are directly comparable.\n",
    "    result = generate_action_with_model(\n",
    "        model,\n",
    "        tokenizer,\n",
    "        obs,\n",
    "        max_new_tokens=160,\n",
    "        temperature=0.2,\n",
    "        top_p=0.9,\n",
    "        do_sample=True,\n",
    "    )\n",
    "    action = result[\"action\"]\n",
    "    record = {\n",
    "        \"step\": step_idx + 1,\n",
    "        \"response_text\": result[\"response_text\"],\n",
    "        \"action\": action.model_dump() if action is not None else None,\n",
    "    }\n",
    "    trajectory.append(record)\n",
    "    if action is None:\n",
    "        break\n",
    "\n",
    "    next_obs = env.step(action)\n",
    "    record.update({\n",
    "        \"reward\": next_obs.reward,\n",
    "        \"done\": next_obs.done,\n",
    "        \"violations\": list(next_obs.rule_violations),\n",
    "        \"latest_summary\": next_obs.latest_output.summary if next_obs.latest_output is not None else None,\n",
    "        \"discovered_markers\": list(next_obs.discovered_markers[:5]),\n",
    "        \"candidate_mechanisms\": list(next_obs.candidate_mechanisms[:5]),\n",
    "    })\n",
    "    obs = next_obs\n",
    "    if obs.done:\n",
    "        break\n",
    "\n",
    "trajectory"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}