izuemon commited on
Commit
1b28abf
·
verified ·
1 Parent(s): 0018802

Update turbowarp-server/gpt.py

Browse files
Files changed (1) hide show
  1. turbowarp-server/gpt.py +81 -190
turbowarp-server/gpt.py CHANGED
@@ -1,7 +1,6 @@
1
  import scratchcommunication
2
  import time
3
  import requests
4
- import traceback
5
 
6
  PROJECT_ID = "1290918780"
7
 
@@ -23,176 +22,100 @@ slots = [f"n{i}" for i in range(1, 10)]
23
  # ---------------------
24
 
25
  chars = []
26
-
27
  with open("turbowarp-server/n-chars.txt", encoding="utf8") as f:
28
  for line in f:
29
  chars.append(line.strip())
30
 
31
- # 高速化用マップ
32
- char_map = {c: i for i, c in enumerate(chars)}
33
-
34
- # ---------------------
35
- # encode / decode
36
- # ---------------------
37
-
38
  def encode(text):
39
-
40
  out = ""
41
-
42
  for c in text:
43
- try:
44
- if c in char_map:
45
- out += f"{char_map[c]:02d}"
46
- else:
47
- print("encode skip unknown char:", c)
48
- except Exception:
49
- print("ENCODE ERROR")
50
- print(traceback.format_exc())
51
-
52
  return out
53
 
54
-
55
  def decode(data):
 
 
56
 
57
- text = ""
 
58
 
59
- try:
60
-
61
- for i in range(0, len(data), 2):
62
-
63
- part = data[i:i+2]
64
-
65
- if len(part) < 2:
66
- print("decode broken pair:", part)
67
- continue
68
 
69
- num = int(part)
 
70
 
71
- if num == 99:
72
- text += "\n"
 
73
 
74
- elif num < len(chars):
75
- text += chars[num]
76
-
77
- else:
78
- print("decode invalid index:", num)
79
-
80
- except Exception:
81
- print("DECODE ERROR")
82
- print(traceback.format_exc())
83
-
84
- return text
85
 
 
86
 
87
  # ---------------------
88
- # TwCloudConnection
89
  # ---------------------
90
 
91
  def get_var(name):
92
-
93
  try:
94
- v = tw.get_variable(name=name, name_literal=False)
95
- return v
96
-
97
  except Exception:
98
- print("get_var error:", name)
99
- print(traceback.format_exc())
100
  return None
101
 
102
-
103
  def set_var(name, value):
104
-
105
  try:
106
- tw.set_variable(name=name, value=value, name_literal=False)
107
- return True
108
-
109
  except Exception:
110
- print("set_var error:", name)
111
- print(traceback.format_exc())
112
- return False
113
-
114
 
115
  # ---------------------
116
  # n0管理
117
  # ---------------------
118
 
119
  def get_used():
120
-
121
  v = get_var("n0")
122
-
123
  if not v:
124
  return []
125
-
126
  return list(v)
127
 
128
-
129
  def add_used(i):
130
-
131
  u = get_used()
132
-
133
  if str(i) not in u:
134
  u.append(str(i))
135
-
136
  set_var("n0", "".join(u))
137
 
138
-
139
  def remove_used(i):
140
-
141
  u = get_used()
142
-
143
  if str(i) in u:
144
  u.remove(str(i))
145
-
146
  set_var("n0", "".join(u))
147
 
148
-
149
  # ---------------------
150
  # API
151
  # ---------------------
152
 
153
  def ask_gpt(history):
 
154
 
155
- try:
156
-
157
- messages = [SYSTEM_PROMPT] + history
 
 
 
 
 
158
 
159
- print("API request start")
160
-
161
- r = requests.post(
162
- "https://izuemon-phi-3.hf.space/v1/chat/completions",
163
- json={
164
- "model": "gpt-3.5-turbo",
165
- "messages": messages
166
- },
167
- timeout=30
168
- )
169
-
170
- print("API status:", r.status_code)
171
-
172
- r.raise_for_status()
173
-
174
- data = r.json()
175
-
176
- reply = data["choices"][0]["message"]["content"]
177
-
178
- print("API reply length:", len(reply))
179
-
180
- return reply
181
-
182
- except requests.exceptions.Timeout:
183
- print("API timeout")
184
- return "APIタイムアウト"
185
-
186
- except requests.exceptions.RequestException:
187
- print("API request error")
188
- print(traceback.format_exc())
189
- return "API通信エラー"
190
-
191
- except Exception:
192
- print("API unknown error")
193
- print(traceback.format_exc())
194
- return "AIエラー"
195
 
 
196
 
197
  # ---------------------
198
  # 送信
@@ -200,141 +123,109 @@ def ask_gpt(history):
200
 
201
  def send(slot, text):
202
 
203
- print("send to", slot)
204
-
205
  encoded = encode(text)
206
 
207
- if not encoded:
208
- print("encoded text empty")
209
- return
210
-
211
  size = 99996
212
 
213
  packets = [encoded[i:i+size] for i in range(0, len(encoded), size)]
214
-
215
  total = len(packets)
216
 
217
- print("packet count:", total)
218
-
219
- for index, p in enumerate(packets):
220
 
221
  packet = f"1{total}0{p}"
222
 
223
- print("send packet", index+1, "/", total)
224
-
225
  start = time.time()
226
 
227
- if not set_var(slot, packet):
228
- print("packet send failed")
229
- return
230
 
231
  while True:
232
 
233
  v = get_var(slot)
234
 
235
  if v and len(v) > 2 and v[2] == "1":
236
- print("packet confirmed")
237
  break
238
 
239
  if time.time() - start > 10:
240
- print("packet timeout")
241
  return
242
 
243
  time.sleep(0.1)
244
 
245
-
246
  # ---------------------
247
  # メインループ
248
  # ---------------------
249
 
250
  buffers = {}
251
 
252
- print("SERVER STARTED")
253
-
254
  while True:
255
 
256
- try:
257
 
258
- for i, slot in enumerate(slots, 1):
259
 
260
- v = get_var(slot)
 
261
 
262
- if not v:
263
- continue
264
 
265
- if len(v) < 3:
266
- print("invalid packet:", v)
267
- continue
268
-
269
- if v[0] != "0":
270
- continue
271
-
272
- unread = v[2] == "0"
273
-
274
- if not unread:
275
- continue
276
-
277
- print("new message slot", slot)
278
 
279
- add_used(i)
280
 
281
- try:
282
- total = int(v[1])
283
- except Exception:
284
- print("invalid total packet count")
285
- remove_used(i)
286
- continue
287
 
288
- data = v[3:]
289
 
290
- newv = v[:2] + "1" + v[3:]
291
- set_var(slot, newv)
292
-
293
- if slot not in buffers:
294
- buffers[slot] = []
295
 
296
- buffers[slot].append(data)
297
 
298
- if len(buffers[slot]) < total:
299
- continue
300
 
301
- joined = "".join(buffers[slot])
 
302
 
303
- print("joined length:", len(joined))
304
 
305
- decoded = decode(joined)
 
306
 
307
- print("decoded message:", decoded)
308
 
309
- history = []
310
 
311
- parts = decoded.split("\n")
312
 
313
- for j in range(0, len(parts), 2):
314
 
315
- if parts[j].strip() == "":
316
- continue
317
 
318
- history.append({
319
- "role": "user",
320
- "content": parts[j]
321
- })
322
 
323
- try:
324
- reply = ask_gpt(history)
325
 
326
- except Exception:
327
- print("AI call failed")
328
- reply = "AIエラー"
 
329
 
330
- send(slot, reply)
 
 
 
331
 
332
- buffers[slot] = []
333
 
334
- remove_used(i)
335
 
336
- except Exception:
337
- print("MAIN LOOP ERROR")
338
- print(traceback.format_exc())
339
 
340
  time.sleep(0.2)
 
1
  import scratchcommunication
2
  import time
3
  import requests
 
4
 
5
  PROJECT_ID = "1290918780"
6
 
 
22
  # ---------------------
23
 
24
  chars = []
 
25
  with open("turbowarp-server/n-chars.txt", encoding="utf8") as f:
26
  for line in f:
27
  chars.append(line.strip())
28
 
 
 
 
 
 
 
 
29
  def encode(text):
 
30
  out = ""
 
31
  for c in text:
32
+ if c in chars:
33
+ i = chars.index(c)
34
+ out += f"{i:02d}"
35
+ elif c == "\n":
36
+ out += "98"
 
 
 
 
37
  return out
38
 
 
39
  def decode(data):
40
+ tokens = []
41
+ current = ""
42
 
43
+ for i in range(0, len(data), 2):
44
+ num = int(data[i:i+2])
45
 
46
+ if num == 99: # 履歴区切り
47
+ tokens.append(current)
48
+ current = ""
 
 
 
 
 
 
49
 
50
+ elif num == 98: # 改行
51
+ current += "\n"
52
 
53
+ else:
54
+ if num < len(chars):
55
+ current += chars[num]
56
 
57
+ tokens.append(current)
 
 
 
 
 
 
 
 
 
 
58
 
59
+ return tokens
60
 
61
  # ---------------------
62
+ # TwCloudConnection ラッパー
63
  # ---------------------
64
 
65
  def get_var(name):
 
66
  try:
67
+ return tw.get_variable(name=name, name_literal=False)
 
 
68
  except Exception:
 
 
69
  return None
70
 
 
71
  def set_var(name, value):
 
72
  try:
73
+ return tw.set_variable(name=name, value=value, name_literal=False)
 
 
74
  except Exception:
75
+ return None
 
 
 
76
 
77
  # ---------------------
78
  # n0管理
79
  # ---------------------
80
 
81
  def get_used():
 
82
  v = get_var("n0")
 
83
  if not v:
84
  return []
 
85
  return list(v)
86
 
 
87
  def add_used(i):
 
88
  u = get_used()
 
89
  if str(i) not in u:
90
  u.append(str(i))
 
91
  set_var("n0", "".join(u))
92
 
 
93
  def remove_used(i):
 
94
  u = get_used()
 
95
  if str(i) in u:
96
  u.remove(str(i))
 
97
  set_var("n0", "".join(u))
98
 
 
99
  # ---------------------
100
  # API
101
  # ---------------------
102
 
103
  def ask_gpt(history):
104
+ messages = [SYSTEM_PROMPT] + history
105
 
106
+ r = requests.post(
107
+ "https://izuemon-phi-3.hf.space/v1/chat/completions",
108
+ json={
109
+ "model": "gpt-3.5-turbo",
110
+ "messages": messages
111
+ },
112
+ timeout=30
113
+ )
114
 
115
+ r.raise_for_status()
116
+ data = r.json()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
+ return data["choices"][0]["message"]["content"]
119
 
120
  # ---------------------
121
  # 送信
 
123
 
124
  def send(slot, text):
125
 
 
 
126
  encoded = encode(text)
127
 
 
 
 
 
128
  size = 99996
129
 
130
  packets = [encoded[i:i+size] for i in range(0, len(encoded), size)]
 
131
  total = len(packets)
132
 
133
+ for p in packets:
 
 
134
 
135
  packet = f"1{total}0{p}"
136
 
 
 
137
  start = time.time()
138
 
139
+ set_var(slot, packet)
 
 
140
 
141
  while True:
142
 
143
  v = get_var(slot)
144
 
145
  if v and len(v) > 2 and v[2] == "1":
 
146
  break
147
 
148
  if time.time() - start > 10:
 
149
  return
150
 
151
  time.sleep(0.1)
152
 
 
153
  # ---------------------
154
  # メインループ
155
  # ---------------------
156
 
157
  buffers = {}
158
 
 
 
159
  while True:
160
 
161
+ for i, slot in enumerate(slots, 1):
162
 
163
+ v = get_var(slot)
164
 
165
+ if not v:
166
+ continue
167
 
168
+ if len(v) < 3:
169
+ continue
170
 
171
+ if v[0] != "0":
172
+ continue
 
 
 
 
 
 
 
 
 
 
 
173
 
174
+ unread = v[2] == "0"
175
 
176
+ if not unread:
177
+ continue
 
 
 
 
178
 
179
+ add_used(i)
180
 
181
+ try:
182
+ total = int(v[1])
183
+ except Exception:
184
+ remove_used(i)
185
+ continue
186
 
187
+ data = v[3:]
188
 
189
+ newv = v[:2] + "1" + v[3:]
190
+ set_var(slot, newv)
191
 
192
+ if slot not in buffers:
193
+ buffers[slot] = []
194
 
195
+ buffers[slot].append(data)
196
 
197
+ if len(buffers[slot]) < total:
198
+ continue
199
 
200
+ joined = "".join(buffers[slot])
201
 
202
+ messages = decode(joined)
203
 
204
+ history = []
205
 
206
+ for j, msg in enumerate(messages):
207
 
208
+ msg = msg.strip()
 
209
 
210
+ if msg == "":
211
+ continue
 
 
212
 
213
+ role = "user" if j % 2 == 0 else "assistant"
 
214
 
215
+ history.append({
216
+ "role": role,
217
+ "content": msg
218
+ })
219
 
220
+ try:
221
+ reply = ask_gpt(history)
222
+ except Exception:
223
+ reply = "エラーが発生しました。"
224
 
225
+ send(slot, reply)
226
 
227
+ buffers[slot] = []
228
 
229
+ remove_used(i)
 
 
230
 
231
  time.sleep(0.2)