{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the model\n",
    "\n",
    "Build the `Transformer` and load the saved v4 weights onto `transformer.device`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using Device: cuda:1 | Name: NVIDIA RTX A6000\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import transformer\n",
    "from transformer import Transformer\n",
    "\n",
    "WEIGHTS_PATH = \"./single_headed_transformer_v4_weights.pth\"\n",
    "\n",
    "# .eval() switches the module to inference mode (e.g. dropout disabled);\n",
    "# it returns the module itself, so it can be chained here.\n",
    "model = Transformer(embedding_dim=64, learning_rate=1e-3).to(device=transformer.device).eval()\n",
    "\n",
    "# Map the checkpoint onto the same device the model was moved to, instead of a\n",
    "# hardcoded GPU index, so this cell also runs on machines without cuda:1.\n",
    "model.load_state_dict(\n",
    "    torch.load(WEIGHTS_PATH, map_location=transformer.device, weights_only=True)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inference\n",
    "\n",
    "Strip punctuation from a prompt, tokenize it, and let the model generate a reply."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' Symptoms include difficulty shifting gears slippage or a burning smell when proper cleans overheating '"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Named input_text (not `input`) to avoid shadowing the Python builtin.\n",
    "input_text = \"identify my car's clutch is going bad\"\n",
    "encoder_input_text = model.remove_punctuation(input_text)\n",
    "encoder_input_tokens = model.tokenize(encoder_input_text)\n",
    "\n",
    "model.predict_text(encoder_input_tokens)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}