{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d243f81",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import os\n",
    "import random\n",
    "import shutil\n",
    "import subprocess\n",
    "\n",
    "import torch\n",
    "from huggingface_hub import login\n",
    "from IPython.display import clear_output, display\n",
    "from PIL import Image\n",
    "\n",
    "from utils import MAX_LENGTH, prepare_model_and_tokenizer\n",
    "from visual_utils.parser import CADparser, write_obj_sample"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b98812ed",
   "metadata": {},
   "source": [
    "### Initializing model and arguments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df625563",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint to load. Set the CADFUSION_PRETRAINED_PATH environment variable to\n",
    "# point at your own checkpoint instead of editing this cell.\n",
    "PRETRAINED_PATH = os.environ.get(\n",
    "    'CADFUSION_PRETRAINED_PATH',\n",
    "    '/home/v-wangruiyu/repos/CADFusion/exp/model_ckpt/CADFusion_v1_1',\n",
    ")\n",
    "# Passed through argparse below, hence a string; overrides the 0.9 default.\n",
    "SAMPLING_TEMPERATURE = '0.3'\n",
    "\n",
    "# The parsed namespace is handed to prepare_model_and_tokenizer and also supplies\n",
    "# the sampling options used by the inference cell.\n",
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument(\"--device-map\", type=str, default='auto')\n",
    "parser.add_argument(\"--lora-rank\", type=int, default=32)\n",
    "parser.add_argument(\"--lora-alpha\", type=int, default=32)\n",
    "parser.add_argument(\"--lora-dropout\", type=float, default=0.05)\n",
    "parser.add_argument(\"--pretrained-path\", type=str, required=True)\n",
    "parser.add_argument(\"--top-p\", type=float, default=0.9)\n",
    "parser.add_argument(\"--temperature\", type=float, default=0.9)\n",
    "\n",
    "# Parse an explicit list: inside a notebook sys.argv belongs to the kernel.\n",
    "arguments = ['--pretrained-path', PRETRAINED_PATH, '--temperature', SAMPLING_TEMPERATURE]\n",
    "args = parser.parse_args(arguments)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5624f320",
   "metadata": {},
   "outputs": [],
   "source": [
    "# login() asks for your Hugging Face token (needed to access Llama).\n",
    "# Enter it interactively; do not hardcode the token in the notebook.\n",
    "login()\n",
    "\n",
    "# Seed PyTorch as well as Python: sampling in model.generate uses PyTorch's\n",
    "# random generator, which random.seed does not touch.\n",
    "random.seed(0)\n",
    "torch.manual_seed(0)\n",
    "\n",
    "model, tokenizer = prepare_model_and_tokenizer(args)\n",
    "model.eval()\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86b9cb09",
   "metadata": {},
   "source": [
    "### Custom prompting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db06d560",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Interactive by default. For an unattended run, replace the input(...) call with\n",
    "# a fixed string such as 'The 3D shape is a cylinder.'\n",
    "description = input(\"Please input a description of a 3D shape: \")\n",
    "\n",
    "prompt = (\n",
    "    'Below is a description of a 3D shape:\\n'\n",
    "    + description\n",
    "    + '\\nGenerate a Computer-Aided Design(CAD) command sequence of the 3D shape:\\n'\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb16f861",
   "metadata": {},
   "source": [
    "### Inference and rendering"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59c5f38e",
   "metadata": {},
   "source": [
    "#### Model Inference\n",
    "The cell displays the generated command sequence. Example output from an earlier run:\n",
    "\n",
    "```\n",
    "circle,31,53,31,9,53,31,9,31 circle,31,51,31,11,51,31,11,31 circle,31,51,31,11,51,31,11,31 add,0,62,31,31,31,1,0,0,0,0,1,0,-1,0,7,31,31\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab5ff2e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "encoded = tokenizer(\n",
    "    prompt,\n",
    "    return_tensors=\"pt\",\n",
    ")\n",
    "batch = {k: v.cuda() for k, v in encoded.items()}\n",
    "# The generated ids start with the prompt tokens; remember how many there are.\n",
    "prompt_token_count = batch[\"input_ids\"].shape[1]\n",
    "\n",
    "generate_ids = model.generate(\n",
    "    **batch,\n",
    "    do_sample=True,\n",
    "    max_new_tokens=MAX_LENGTH,\n",
    "    temperature=args.temperature,\n",
    "    top_p=args.top_p,\n",
    "    repetition_penalty=1.3,\n",
    ")\n",
    "\n",
    "# Strip the prompt at token level before decoding. Cutting the decoded text at\n",
    "# len(prompt) is only correct when detokenization reproduces the prompt exactly.\n",
    "decoded = tokenizer.batch_decode(\n",
    "    generate_ids[:, prompt_token_count:],\n",
    "    skip_special_tokens=True,\n",
    "    clean_up_tokenization_spaces=False,\n",
    ")\n",
    "gen_strs = decoded[0]\n",
    "gen_strs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f56d6fcf",
   "metadata": {},
   "source": [
    "#### Render .obj file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95498ccb",
   "metadata": {},
   "outputs": [],
   "source": [
    "out_path = 'visual_cache/gen_obj'\n",
    "# Start from an empty directory so files from a previous run are not mixed in.\n",
    "if os.path.exists(out_path):\n",
    "    shutil.rmtree(out_path)\n",
    "os.makedirs(out_path, exist_ok=True)\n",
    "\n",
    "# Parse the generated command sequence and write it out as .obj data.\n",
    "cad_parser = CADparser(bit=6)\n",
    "parsed_data = cad_parser.perform(gen_strs)\n",
    "write_obj_sample(out_path, parsed_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79b5dfaf",
   "metadata": {},
   "source": [
    "#### Render .step, .stl, .ply files\n",
    "N.B. If the \"Statistics on Transfer\" logs do not show up in the output below, the model may not have produced renderable outputs. Re-run the inference cell or change your prompt to see if it gets better results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a49694f",
   "metadata": {},
   "outputs": [],
   "source": [
    "out_path = os.path.abspath(out_path)\n",
    "\n",
    "py_path = os.path.abspath('../rendering_utils/parser_visual.py')\n",
    "visual_result = subprocess.run(['python3', py_path, '--data_folder', out_path, '--single-file'])\n",
    "if visual_result.returncode != 0:\n",
    "    print(f\"parser_visual.py exited with code {visual_result.returncode}.\")\n",
    "\n",
    "py_path = os.path.abspath('../rendering_utils/ptl_sampler.py')\n",
    "sampler_result = subprocess.run(['python3', py_path, '--in_dir', out_path, '--out_dir', 'ptl', '--single-file'])\n",
    "if sampler_result.returncode != 0:\n",
    "    print(f\"ptl_sampler.py exited with code {sampler_result.returncode}.\")\n",
    "\n",
    "# The output is deliberately not cleared: the logs printed by the scripts are how\n",
    "# you tell whether the conversion worked (see the note above this cell)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e0f1fd1",
   "metadata": {},
   "source": [
    "#### Image rendering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "586f3a91",
   "metadata": {},
   "outputs": [],
   "source": [
    "visual_obj_path = 'visual_cache'\n",
    "output_figure_path = 'visual_cache/figures'\n",
    "# Recreate the figure directory so an image from an earlier run is never shown.\n",
    "if os.path.exists(output_figure_path):\n",
    "    shutil.rmtree(output_figure_path)\n",
    "os.makedirs(output_figure_path, exist_ok=True)\n",
    "py_path = os.path.abspath('../rendering_utils/img_renderer.py')\n",
    "\n",
    "# Start a virtual X server for the renderer. xvfb_process stays None when Xvfb\n",
    "# is not installed, so the cleanup below can tell whether there is anything to stop.\n",
    "xvfb_process = None\n",
    "try:\n",
    "    xvfb_process = subprocess.Popen(\n",
    "        [\"Xvfb\", \":99\", \"-screen\", \"0\", \"640x480x24\"],\n",
    "        stdout=subprocess.DEVNULL,\n",
    "        stderr=subprocess.DEVNULL\n",
    "    )\n",
    "    print(\"Xvfb started in the background.\")\n",
    "except FileNotFoundError:\n",
    "    print(\"Error: Xvfb not found. Please ensure it is installed and in your system's PATH.\")\n",
    "\n",
    "# Point the renderer at the virtual display only if it was started; otherwise\n",
    "# fall back to whatever DISPLAY the kernel already has.\n",
    "previous_display = os.environ.get('DISPLAY')\n",
    "if xvfb_process is not None:\n",
    "    os.environ['DISPLAY'] = ':99'\n",
    "try:\n",
    "    render_result = subprocess.run(\n",
    "        ['python3', py_path, '--input_dir', visual_obj_path, '--output_dir', output_figure_path]\n",
    "    )\n",
    "    if render_result.returncode == 0:\n",
    "        print(\"Rendering script completed successfully.\")\n",
    "finally:\n",
    "    # Restore DISPLAY to its previous state, even if rendering raised.\n",
    "    if previous_display is None:\n",
    "        os.environ.pop('DISPLAY', None)\n",
    "    else:\n",
    "        os.environ['DISPLAY'] = previous_display\n",
    "\n",
    "    if xvfb_process is not None:\n",
    "        if xvfb_process.poll() is None:  # Check if Xvfb is still running\n",
    "            xvfb_process.terminate()\n",
    "            xvfb_process.wait()  # reap the process so it does not linger as a zombie\n",
    "            print(\"Xvfb terminated.\")\n",
    "        else:\n",
    "            print(\"Xvfb already exited.\")\n",
    "\n",
    "if render_result.returncode == 0:\n",
    "    # Rendering worked: drop the progress messages so only the image remains.\n",
    "    clear_output()\n",
    "else:\n",
    "    # Keep the messages above visible to help diagnose the failure.\n",
    "    print(f\"Rendering script exited with code {render_result.returncode}.\")\n",
    "\n",
    "# NOTE(review): the expected file name is 'gen_ob.png' (not 'gen_obj.png');\n",
    "# presumably this matches how img_renderer.py names its output - confirm.\n",
    "input_image_path = os.path.join(output_figure_path, 'gen_ob.png')\n",
    "if os.path.exists(input_image_path):\n",
    "    # Show the image inline; an external viewer is not usable on a headless machine.\n",
    "    with Image.open(input_image_path) as img:\n",
    "        display(img)\n",
    "else:\n",
    "    print(f\"{input_image_path} does not exist.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c78fed0f",
   "metadata": {},
   "source": [
    "#### Files retrieval\n",
    "By default, the produced step, stl, obj and ply files are stored under the `visual_cache` folder. Copy them to a location of your own if you want to keep them. Do not keep them in the cache folder, as they will be deleted on the next run."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cdfs",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.23"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}