Show that model did train
Browse files
GPT2_Linear_4bit_training.ipynb
CHANGED
|
@@ -822,6 +822,48 @@
|
|
| 822 |
}
|
| 823 |
]
|
| 824 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 825 |
{
|
| 826 |
"cell_type": "markdown",
|
| 827 |
"source": [
|
|
|
|
| 822 |
}
|
| 823 |
]
|
| 824 |
},
|
| 825 |
+
{
|
| 826 |
+
"cell_type": "code",
|
| 827 |
+
"source": [
|
| 828 |
+
"inputs = {k:v.cuda() for k,v in tokenizer(\"\"\"\n",
|
| 829 |
+
"You are an AI assistant. You will be given a question. You must generate a short and factual answer.\n",
|
| 830 |
+
"What is the capital city of France?\n",
|
| 831 |
+
"\"\"\", return_tensors='pt').items()}\n",
|
| 832 |
+
"outputs = model.generate(**inputs, max_new_tokens=16, temperature=0.5, do_sample=True)\n",
|
| 833 |
+
"print(tokenizer.decode(outputs[0]), \"...\")"
|
| 834 |
+
],
|
| 835 |
+
"metadata": {
|
| 836 |
+
"colab": {
|
| 837 |
+
"base_uri": "https://localhost:8080/"
|
| 838 |
+
},
|
| 839 |
+
"id": "wr6bfZ0wyk3c",
|
| 840 |
+
"outputId": "ca4ede1b-7456-43ea-ce52-961c1383dff8"
|
| 841 |
+
},
|
| 842 |
+
"execution_count": 16,
|
| 843 |
+
"outputs": [
|
| 844 |
+
{
|
| 845 |
+
"output_type": "stream",
|
| 846 |
+
"name": "stderr",
|
| 847 |
+
"text": [
|
| 848 |
+
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
|
| 849 |
+
]
|
| 850 |
+
},
|
| 851 |
+
{
|
| 852 |
+
"output_type": "stream",
|
| 853 |
+
"name": "stdout",
|
| 854 |
+
"text": [
|
| 855 |
+
"\n",
|
| 856 |
+
"You are an AI assistant. You will be given a question. You must generate a short and factual answer.\n",
|
| 857 |
+
"What is the capital city of France?\n",
|
| 858 |
+
"\n",
|
| 859 |
+
"\n",
|
| 860 |
+
"Paris\n",
|
| 861 |
+
"\n",
|
| 862 |
+
"Paris is the capital of France. The city is located ...\n"
|
| 863 |
+
]
|
| 864 |
+
}
|
| 865 |
+
]
|
| 866 |
+
},
|
| 867 |
{
|
| 868 |
"cell_type": "markdown",
|
| 869 |
"source": [
|