{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6948d41",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import torch\n",
    "\n",
    "# I'm not adding the weights to this repo, but you can pull them from the Hugging Face Hub.\n",
    "# Replace each <<add your path>> placeholder with your local path before running the cells below.\n",
    "\n",
    "HF_PRETRAIN_DIR = Path(\"<<add your path>>/hf_pretrained\")\n",
    "HF_SFT_DIR      = Path(\"<<add your path>>/hf_sft_merged\")\n",
    "GGUF_Q4KM       = Path(\"<<add your path>>/model-Q4_K_M.gguf\")\n",
    "\n",
    "# Use Apple's Metal (MPS) backend when available, otherwise fall back to the CPU.\n",
    "device = \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n",
    "device\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "405e0f18",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hello\n",
      "- If you want a simple introduction to the basics of a language, you can get it here.\n",
      "- Learn the basics of language with this step-by-step guide.\n",
      "- This is a great course for all those who want a basic introduction to the basics of a language.\n",
      "- You can get the entire lesson in the book.\n",
      "- You will get a complete reference on all the parts of a language.\n",
      "- It is an easy to read, easy to understand course.\n",
      "- You will learn the basics of English from a very beginner level.\n",
      "- It is an excellent course for all those who want to learn basic English.\n",
      "- It is very comprehensive.\n",
      "- This is a great course for those who want to learn to read and write English.\n",
      "- It is very good for beginners, but beginners can also get it with a simple explanation.\n",
      "- This course is very comprehensive, and it is a great course for all those who want to learn to\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "import torch\n",
    "\n",
    "tok_pre = AutoTokenizer.from_pretrained(HF_PRETRAIN_DIR, use_fast=True)\n",
    "\n",
    "# Reuse EOS as the pad token if the tokenizer does not define one.\n",
    "if tok_pre.pad_token_id is None:\n",
    "    tok_pre.pad_token = tok_pre.eos_token\n",
    "\n",
    "dtype = torch.float16 if device == \"mps\" else torch.float32\n",
    "model_pre = AutoModelForCausalLM.from_pretrained(HF_PRETRAIN_DIR, torch_dtype=dtype)\n",
    "model_pre.to(device)\n",
    "model_pre.eval()\n",
    "\n",
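    "def gen_hf(model, tok, prompt: str, max_new_tokens=200, temperature=0.8, top_p=0.95):\n",
    "    \"\"\"Sample up to max_new_tokens with temperature + nucleus (top-p) sampling; return the decoded prompt plus continuation.\"\"\"\n",
    "    inputs = tok(prompt, return_tensors=\"pt\")\n",
    "    # Move the input tensors to whatever device the model lives on.\n",
    "    inputs = {k: v.to(model.device) for k, v in inputs.items()}\n",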
    "    with torch.no_grad():\n",
    "        out = model.generate(\n",
    "            **inputs,\n",
    "            max_new_tokens=max_new_tokens,\n",
    "            do_sample=True,\n",
    "            temperature=temperature,\n",
    "            top_p=top_p,\n",
    "            pad_token_id=tok.eos_token_id,\n",
    "            eos_token_id=tok.eos_token_id,\n",
    "        )\n",
    "    return tok.decode(out[0], skip_special_tokens=True)\n",
    "\n",
    "prompt = \"Hello \\n\"\n",
    "print(gen_hf(model_pre, tok_pre, prompt))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "745ee60e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hello. \n",
      "The function to calculate and display the results of an application with built-in functions. Here are some steps to help you get started:\n",
      "\n",
      "1. Define the function to be used in the application.\n",
      "2. Define the function to be used in the application.\n",
      "3. Define the function to be used in the application.\n",
      "4. Create a function to display the results of the application.\n",
      "5. Use the function to get a list of the results of the application.\n",
      "6. Use the function to calculate the value of the function based on the input.\n",
      "7. Use the function to display the results of the application based on the input.\n",
      "8. Use the function to evaluate the function to determine if it is working correctly.\n",
      "9. Use the function to perform a calculation that calculates the value of the function.\n",
      "10. Use the function to display the results of the application based on the input.\n",
      "\n",
      "With these steps, you can start\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "import torch\n",
    "\n",
    "tok_sft = AutoTokenizer.from_pretrained(HF_SFT_DIR, use_fast=True)\n",
    "if tok_sft.pad_token_id is None:\n",
    "    tok_sft.pad_token = tok_sft.eos_token\n",
    "\n",
    "dtype = torch.float16 if device == \"mps\" else torch.float32\n",
    "model_sft = AutoModelForCausalLM.from_pretrained(HF_SFT_DIR, torch_dtype=dtype)\n",
    "model_sft.to(device)\n",
    "model_sft.eval()\n",
    "\n",
    "prompt = \"Hello. \\n\"\n",
    "print(gen_hf(model_sft, tok_sft, prompt))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}