tyraepaul committed
Commit 6d1883e · verified · 1 Parent(s): 08d1e82

Upload 3 files

stok-0.4 (no stok-0.4-large... yet.)

Files changed (4)
  1. .gitattributes +2 -0
  2. run_stok.py +144 -1
  3. stok-0.4-mini.json +3 -0
  4. stok-0.4.json +3 -0
.gitattributes CHANGED
@@ -38,3 +38,5 @@ stok-0.3.json filter=lfs diff=lfs merge=lfs -text
stok-0.2.json filter=lfs diff=lfs merge=lfs -text
stok-0.3-125m.json filter=lfs diff=lfs merge=lfs -text
stok-0.3.1.json filter=lfs diff=lfs merge=lfs -text
+ stok-0.4-mini.json filter=lfs diff=lfs merge=lfs -text
+ stok-0.4.json filter=lfs diff=lfs merge=lfs -text
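
The two added attribute lines route stok-0.4-mini.json and stok-0.4.json through Git LFS, so a clone made without git-lfs contains small pointer stubs instead of the actual model JSON, and load_model would then fail inside json.loads. A minimal, hypothetical guard (the helper name is mine, not part of run_stok.py) that detects an unfetched pointer before loading:

# Hypothetical helper (assumption, not in this repo): detect an unfetched Git LFS pointer.
# A pointer file starts with the LFS spec line rather than JSON data.
def is_lfs_pointer(path: str) -> bool:
    with open(path, "r", errors="ignore") as f:
        return f.readline().startswith("version https://git-lfs.github.com/spec/v1")

if is_lfs_pointer("stok-0.4.json"):
    raise RuntimeError("stok-0.4.json is only an LFS pointer; fetch the real file first (e.g. git lfs pull)")
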
run_stok.py CHANGED
@@ -28,9 +28,146 @@ def strip_text(prompt): # kinda wacky overall
    return newprompt

model = {"model_data": {}}
+
def load_model(filename: str):
    model["model_data"] = json.loads(open(filename, "r").read())

+def symbolize_prompt(prompt): # checks if prompt can be contextualized based on a symbol (currently only math)
+    symbols = ["+", "-", "/", "*"]
+    numbers = []
+    prompt_left = []
+    prompt_right = []
+    for x in range(0, 10):
+        numbers.append(str(x))
+    prompt = "".join(prompt.split(sep=None)) # remove whitespace
+    for symbol in symbols:
+        if symbol in prompt:
+            listed_prompt = list(prompt)
+            sym_index = listed_prompt.index(symbol)
+            i = sym_index
+            nochar = True
+            while nochar:
+                i += 1
+                try:
+                    if listed_prompt[i] in numbers or listed_prompt[i] == ".":
+                        prompt_right.append(listed_prompt[i])
+                    else:
+                        nochar = False
+                except IndexError:
+                    nochar = False
+            i = sym_index
+            nochar = True
+            while nochar:
+                i -= 1
+                try:
+                    if listed_prompt[i] in numbers or listed_prompt[i] == ".":
+                        prompt_left.append(listed_prompt[i])
+                    else:
+                        nochar = False
+                except IndexError:
+                    nochar = False
+            new_prompt = f"{''.join(prompt_left)}{symbol}{''.join(prompt_right)}"
+            return new_prompt
+    return None
+
+
+def version_04_inference(prompt: str, max_tokens: int=None, repetition_penalty: int=2):
+    tokens_generated = 0
+    split_prompt = strip_prompt(prompt).split(sep=None)
+    model_data = model["model_data"]
+    outputs = model_data["outputs"]
+    raw_outputs = model_data["raw_outputs"]
+    prompts = model_data["prompts"]
+    ends = model_data["ends"]
+    start = ""
+    topic = None
+    for token in split_prompt:
+        if token in prompts:
+            start = max(prompts[token], key=prompts[token].get)
+            topic = token
+            break
+    if topic == None: # use raw outputs
+        save_prompt = symbolize_prompt(prompt)
+        if save_prompt != None:
+            token_now = False
+            for token in save_prompt.split(sep=None):
+                if token in prompts:
+                    token_now = True
+                    break
+            if token_now:
+                for chunk in version_04_inference(prompt=save_prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty):
+                    yield chunk
+            else:
+                outputs = raw_outputs
+                topic = None
+                start = split_prompt[-1]
+                tokens_generated += 1
+                running = True
+                current_token = [start]
+                while running:
+                    token = current_token[0]
+                    yield f"{token} "
+                    if token in outputs:
+                        next_token = max(outputs[token], key=outputs[token].get)
+                        outputs[token][next_token] -= repetition_penalty
+                    else:
+                        next_token = random.choice(list(outputs.keys()))
+                    current_token[0] = next_token
+                    tokens_generated += 1
+                    if max_tokens != None:
+                        if tokens_generated >= max_tokens:
+                            running = False
+                    if topic:
+                        if token in ends[topic]:
+                            running = False
+        else:
+            outputs = raw_outputs
+            topic = None
+            start = split_prompt[-1]
+            tokens_generated += 1
+            running = True
+            current_token = [start]
+            while running:
+                token = current_token[0]
+                yield f"{token} "
+                if token in outputs:
+                    next_token = max(outputs[token], key=outputs[token].get)
+                    outputs[token][next_token] -= repetition_penalty
+                else:
+                    next_token = random.choice(list(outputs.keys()))
+                current_token[0] = next_token
+                tokens_generated += 1
+                if max_tokens != None:
+                    if tokens_generated >= max_tokens:
+                        running = False
+                if topic:
+                    if token in ends[topic]:
+                        running = False
+    else:
+
+        tokens_generated += 1
+        running = True
+        current_token = [start]
+        while running:
+            token = current_token[0]
+            yield f"{token} "
+            if outputs.get(topic) != None:
+                if token in outputs[topic]:
+                    next_token = max(outputs[topic][token], key=outputs[topic][token].get)
+                    outputs[topic][token][next_token] -= repetition_penalty
+                else:
+                    next_token = random.choice(list(outputs.keys()))
+                current_token[0] = next_token
+                tokens_generated += 1
+                if max_tokens != None:
+                    if tokens_generated >= max_tokens:
+                        running = False
+                if topic:
+                    if token in ends[topic]:
+                        running = False
+            else:
+                running = False # this is because single token responses seem to break things
+
def version_03_inference(prompt: str, max_tokens: int=None, repetition_penalty: int=2):
    tokens_generated = 0
    split_prompt = strip_prompt(prompt).split(sep=None)

@@ -157,7 +294,7 @@ def version_01_inference(prompt: str, max_tokens: int=None, repetition_penalty:
                running = False

def run_model(prompt: str, max_tokens: int=None, repetition_penalty: int=1, temperature: float=0):
-    # (temperature does not work on versions below 0.3)
+    # (temperature does not work on versions below 0.5)
    model_data = model["model_data"]
    model_format = model_data["format"]
    if model_data["format"] == "v0.1":

@@ -174,3 +311,9 @@ def run_model(prompt: str, max_tokens: int=None, repetition_penalty: int=1, temp
        response = version_03_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
        for chunk in response:
            yield chunk
+
+    if model_data["format"] == "v0.4":
+        response = version_04_inference(prompt, max_tokens=max_tokens, repetition_penalty=repetition_penalty)
+        for chunk in response:
+            yield chunk
+
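
As a quick orientation for the new code path: load_model reads one of the stok JSON files into the module-level model dict, run_model checks model_data["format"], and for "v0.4" it streams tokens from version_04_inference, which falls back to symbolize_prompt when no known topic token is found in the prompt. A minimal usage sketch, assuming run_stok.py imports cleanly as a module and a v0.4 model file is in the working directory (the prompt and parameter values are illustrative, not from this commit):

# Usage sketch under the stated assumptions; filenames and values are illustrative.
from run_stok import load_model, run_model

load_model("stok-0.4-mini.json")  # fills model["model_data"]; its "format" field should be "v0.4"
for chunk in run_model("what is 2 + 2", max_tokens=32, repetition_penalty=2):
    print(chunk, end="")          # run_model yields space-terminated tokens as they are generated
print()

Because run_model dispatches on the "format" field, the same call also works with the older v0.1-v0.3 model files.
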
stok-0.4-mini.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cd0994c49ccb970a17d621365c0ecf26f5d0d830f039265a65a9835a32ea12c7
+ size 15207518
stok-0.4.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:da731f68c853242c03b4334da1f8892126ae6b515596fae78a38286e01e5cfc4
+ size 106979287
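
Both new JSON files are stored as Git LFS objects, so only these three-line pointer stubs live in the repository history; the oid and size fields describe the real payloads (roughly 15 MB for the mini model and 107 MB for the full one). A hedged sketch of fetching stok-0.4-mini.json with huggingface_hub and checking it against the pointer's sha256 oid (the repo_id is a placeholder, not taken from this commit):

# Sketch under stated assumptions: repo_id is a placeholder to be replaced with the actual repository.
import hashlib
from huggingface_hub import hf_hub_download

path = hf_hub_download(repo_id="<user>/<repo>", filename="stok-0.4-mini.json")
digest = hashlib.sha256(open(path, "rb").read()).hexdigest()
# The expected digest is the oid recorded in the pointer file above.
assert digest == "cd0994c49ccb970a17d621365c0ecf26f5d0d830f039265a65a9835a32ea12c7"
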