{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7a2c4e10",
   "metadata": {},
   "source": [
    "# CellDreamer: dataset and API sanity checks\n",
    "\n",
    "Scratch checks that (1) load the local training split and inspect one sample, (2) compute dataset statistics with `get_data_stats`, and (3) call the hosted CellDreamer Gradio API with two different random inputs and compare the returned trajectories."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6fc963a",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b8e2f4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from gradio_client import Client\n",
    "\n",
    "import celldreamer.data\n",
    "from celldreamer.data import get_data_stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f1d7c92",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset splits live here. Override with the CELLDREAMER_DATA_DIR env var; the default\n",
    "# assumes `celldreamer` is imported from its source tree (e.g. an editable install),\n",
    "# where `datasets/` sits next to `celldreamer/data/__init__.py`.\n",
    "DATA_DIR = Path(\n",
    "    os.environ.get(\"CELLDREAMER_DATA_DIR\", Path(celldreamer.data.__file__).parent / \"datasets\")\n",
    ")\n",
    "SPACE_ID = \"RobroKools/CellDreamer-API\"  # same Space the Flask app calls\n",
    "N_STEPS = 10  # trajectory steps requested from the API\n",
    "SEED = 0\n",
    "\n",
    "rng = np.random.default_rng(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d4b6a21",
   "metadata": {},
   "source": [
    "## Training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e29d1c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# weights_only=False unpickles arbitrary objects: only use it on trusted local files.\n",
    "train_ds = torch.load(DATA_DIR / \"train.pt\", weights_only=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebe6280f",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_t = train_ds[0][\"x_t\"]\n",
    "n_genes = x_t.shape[0]  # input length for the API below (2446 when last run)\n",
    "x_t.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9454346",
   "metadata": {},
   "outputs": [],
   "source": [
    "get_data_stats()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5b19c73",
   "metadata": {},
   "source": [
    "## Hosted API: do different inputs give different trajectories?\n",
    "\n",
    "Sends two random expression vectors of length `n_genes` to the Space and compares the returned trajectories.\n",
    "\n",
    "- In an earlier run with 10 steps, `trajectory_a - trajectory_b` was nearly the same in every row (it drifted by at most about 1e-6 across the 10 rows). Assuming one row per step, both trajectories took almost identical per-step updates despite different inputs.\n",
    "- **TODO**: confirm whether that is expected from the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3a9e5d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "client = Client(SPACE_ID)\n",
    "\n",
    "\n",
    "def dream(genes: np.ndarray, steps: int = N_STEPS) -> np.ndarray:\n",
    "    \"\"\"Request a trajectory for one expression vector from the hosted API.\n",
    "\n",
    "    `genes` is sent via `.tolist()` so the payload holds plain Python floats.\n",
    "    Returns the response's \"trajectory\" field as a NumPy array.\n",
    "    \"\"\"\n",
    "    result = client.predict(input_data={\"genes\": genes.tolist(), \"steps\": steps})\n",
    "    return np.asarray(result[\"trajectory\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c8ff06c",
   "metadata": {},
   "outputs": [],
   "source": [
    "traj_a = dream(rng.random(n_genes))\n",
    "traj_b = dream(rng.random(n_genes))\n",
    "diff = traj_a - traj_b\n",
    "\n",
    "# Summarise rather than dumping the whole array (10 x 50 when last run).\n",
    "{\n",
    "    \"trajectory_shape\": traj_a.shape,\n",
    "    \"max_abs_diff\": float(np.abs(diff).max()),\n",
    "    # How much the a-b gap changes across rows; ~0 means both runs took the same steps.\n",
    "    \"max_drift_across_rows\": float((diff.max(axis=0) - diff.min(axis=0)).max()),\n",
    "}"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "celldreamer",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}