posix4e committed on
Commit
d36ebbf
·
1 Parent(s): ade1378

Switch to a library for managing openai access

Browse files

Simplify UUIDs so we don't get duplicates
Add tests

backend/backend.py CHANGED
@@ -2,12 +2,14 @@ import json
2
  import uuid
3
  from datetime import datetime
4
 
5
- import gradio as gr
6
  import mistune
7
  import openai
8
  from dotenv import load_dotenv
 
9
  from fastapi import FastAPI, HTTPException
 
10
  from fastapi.testclient import TestClient
 
11
  from pydantic import BaseModel
12
  from pygments import highlight
13
  from pygments.formatters import html
@@ -17,13 +19,10 @@ from sqlalchemy.ext.declarative import declarative_base
17
  from sqlalchemy.orm import Session, sessionmaker
18
  from sqlalchemy.sql import insert, select, text
19
  from uvicorn import Config, Server
20
- from fastapi.middleware.gzip import GZipMiddleware
21
 
22
  LANGS = [
23
- "text-davinci-002:100",
24
- "text-davinci-003:1500",
25
- "gpt-3.5-turbo:4000",
26
- "gpt-4:6000",
27
  ]
28
 
29
  Base = declarative_base()
@@ -34,7 +33,7 @@ class User(Base):
34
 
35
  id = Column(Integer, primary_key=True, autoincrement=True)
36
  uid = Column(String, nullable=False)
37
- openai_key = Column(String)
38
 
39
  def __repr__(self):
40
  return f"User(id={self.id}, uid={self.uid}"
@@ -46,7 +45,7 @@ class History(Base):
46
  id = Column(Integer, primary_key=True, autoincrement=True)
47
  uid = Column(String, nullable=False)
48
  question = Column(String, nullable=False)
49
- answer = Column(JSON, nullable=False)
50
 
51
  def __repr__(self):
52
  return f"History(id={self.id}, uid={self.uid}, question={self.question}, answer={self.answer}"
@@ -135,11 +134,15 @@ class RegisterItem(BaseModel):
135
  @app.post("/register")
136
  async def register(item: RegisterItem):
137
  db: Session = SessionLocal()
138
- new_user = User(uid=str(uuid.uuid4()), openai_key=item.openai_key)
139
- db.add(new_user)
140
- db.commit()
141
- db.refresh(new_user)
142
- return {"uid": new_user.uid}
 
 
 
 
143
 
144
 
145
  class AssistItem(BaseModel):
@@ -148,28 +151,6 @@ class AssistItem(BaseModel):
148
  version: str
149
 
150
 
151
- def generate_quick_completion(prompt, model):
152
- dropdrown = model.split(":")
153
- engine = dropdrown[0]
154
- max_tokens = int(dropdrown[1])
155
- if "gpt" in model:
156
- message = [{"role": "user", "content": prompt}]
157
- response = openai.ChatCompletion.create(
158
- model=engine,
159
- messages=message,
160
- temperature=0.2,
161
- max_tokens=max_tokens,
162
- frequency_penalty=0.0,
163
- )
164
- elif "davinci" in model:
165
- response = openai.Completion.create(
166
- engine=engine, prompt=prompt, max_tokens=max_tokens
167
- )
168
- else:
169
- raise Exception("Unknown model")
170
- return response
171
-
172
-
173
  @app.post("/assist")
174
  async def assist(item: AssistItem):
175
  db: Session = SessionLocal()
@@ -179,15 +160,13 @@ async def assist(item: AssistItem):
179
 
180
  # Call OpenAI
181
  openai.api_key = user.openai_key
182
- response = generate_quick_completion(item.prompt, item.version)
183
 
184
  # Update the last time assist was called
185
  user.last_assist = datetime.now()
186
 
187
  # Store the history
188
- new_history = History(
189
- uid=item.uid, question=item.prompt, answer=json.loads(str(response))
190
- )
191
 
192
  db.add(new_history)
193
  db.commit()
@@ -228,7 +207,7 @@ def assist_interface(uid, prompt, gpt_version):
228
  "/assist",
229
  json={"uid": uid, "prompt": prompt, "version": gpt_version},
230
  )
231
- return gradio_user_output_helper(response.text)
232
 
233
 
234
  def get_user_interface(uid):
@@ -248,51 +227,31 @@ class HighlightRenderer(mistune.HTMLRenderer):
248
  return "<pre><code>" + mistune.escape(code) + "</code></pre>"
249
 
250
 
251
- def gradio_user_output_helper(data):
252
  r"""
253
  This is used by the gradio to extract all of the user
254
 data and write it out as a giant json blob that can be easily displayed.
255
- >>> choices = [{'message': {'content': 'This is a test'}}]
256
- >>> data = { 'id': '1', 'object': 'user', 'created': '2021-09-01', 'model': 'gpt-3', 'choices': choices}
257
- >>> gradio_user_output_helper(json.dumps(data))
258
- '<html><h2>ID: 1</h2><p>Object: user</p><p>Created: 2021-09-01</p><p>Model: gpt-3</p><h3>Choices:</h3><p>Text: <p>This is a test</p>\n</p></html>'
259
  """
260
- html_output = "<html>"
261
- json_data = json.loads(data)
262
-
263
- id = json_data["id"]
264
- object = json_data["object"]
265
- created = json_data["created"]
266
- model = json_data["model"]
267
- choices = json_data["choices"]
268
-
269
- html_output += f"<h2>ID: {id}</h2>"
270
- html_output += f"<p>Object: {object}</p>"
271
- html_output += f"<p>Created: {created}</p>"
272
- html_output += f"<p>Model: {model}</p>"
273
-
274
- html_output += "<h3>Choices:</h3>"
275
- if "davinci" in model:
276
- for choice in choices:
277
- text = choice["text"]
278
- html_output += f"<p>Text: {text}</p>"
279
- elif "gpt" in model:
280
- for choice in choices:
281
- markdown = mistune.create_markdown(renderer=HighlightRenderer())
282
- text = markdown(choice["message"]["content"])
283
- html_output += f"<p>Text: {text}</p>"
284
- html_output += "</html>"
285
- return html_output
286
 
287
 
288
  def get_assist_interface():
289
- gpt_version_dropdown = gr.components.Dropdown(label="GPT Version", choices=LANGS)
290
 
291
- return gr.Interface(
292
  fn=assist_interface,
293
  inputs=[
294
- gr.components.Textbox(label="UID", type="text"),
295
- gr.components.Textbox(label="Prompt", type="text"),
296
  gpt_version_dropdown,
297
  ],
298
  outputs="html",
@@ -302,7 +261,7 @@ def get_assist_interface():
302
 
303
 
304
  def get_db_interface():
305
- return gr.Interface(
306
  fn=get_user_interface,
307
  inputs="text",
308
  outputs="text",
@@ -321,9 +280,9 @@ def register_interface(openai_key):
321
 
322
 
323
  def get_register_interface():
324
- return gr.Interface(
325
  fn=register_interface,
326
- inputs=[gr.components.Textbox(label="OpenAI Key", type="text")],
327
  outputs="json",
328
  title="Register New User",
329
  description="Register a new user by entering an OpenAI key.",
@@ -337,9 +296,9 @@ def get_history_interface(uid):
337
 
338
 
339
  def get_history_gradio_interface():
340
- return gr.Interface(
341
  fn=get_history_interface,
342
- inputs=[gr.components.Textbox(label="UID", type="text")],
343
  outputs="json",
344
  title="Get User History",
345
  description="Get the history of questions and answers for a given user.",
@@ -356,11 +315,11 @@ def add_command_interface(uid, command):
356
 
357
 
358
  def get_add_command_interface():
359
- return gr.Interface(
360
  fn=add_command_interface,
361
  inputs=[
362
- gr.components.Textbox(label="UID", type="text"),
363
- gr.components.Textbox(label="Command", type="text"),
364
  ],
365
  outputs="json",
366
  title="Add Command",
@@ -368,9 +327,9 @@ def get_add_command_interface():
368
  )
369
 
370
 
371
- app = gr.mount_gradio_app(
372
  app,
373
- gr.TabbedInterface(
374
  [
375
  get_assist_interface(),
376
  get_db_interface(),
 
2
  import uuid
3
  from datetime import datetime
4
 
 
5
  import mistune
6
  import openai
7
  from dotenv import load_dotenv
8
+ from easycompletion import openai_text_call
9
  from fastapi import FastAPI, HTTPException
10
+ from fastapi.middleware.gzip import GZipMiddleware
11
  from fastapi.testclient import TestClient
12
+ from gradio import Interface, TabbedInterface, components, mount_gradio_app
13
  from pydantic import BaseModel
14
  from pygments import highlight
15
  from pygments.formatters import html
 
19
  from sqlalchemy.orm import Session, sessionmaker
20
  from sqlalchemy.sql import insert, select, text
21
  from uvicorn import Config, Server
 
22
 
23
  LANGS = [
24
+ "gpt-3.5-turbo",
25
+ "gpt-4",
 
 
26
  ]
27
 
28
  Base = declarative_base()
 
33
 
34
  id = Column(Integer, primary_key=True, autoincrement=True)
35
  uid = Column(String, nullable=False)
36
+ openai_key = Column(String, unique=True, nullable=False)
37
 
38
  def __repr__(self):
39
  return f"User(id={self.id}, uid={self.uid}"
 
45
  id = Column(Integer, primary_key=True, autoincrement=True)
46
  uid = Column(String, nullable=False)
47
  question = Column(String, nullable=False)
48
+ answer = Column(String, nullable=False)
49
 
50
  def __repr__(self):
51
  return f"History(id={self.id}, uid={self.uid}, question={self.question}, answer={self.answer}"
 
134
  @app.post("/register")
135
  async def register(item: RegisterItem):
136
  db: Session = SessionLocal()
137
+ existing_user = db.query(User).filter(User.openai_key == item.openai_key).first()
138
+ if existing_user:
139
+ return {"uid": existing_user.uid} # return existing UUID
140
+ else:
141
+ new_user = User(uid=str(uuid.uuid4()), openai_key=item.openai_key)
142
+ db.add(new_user)
143
+ db.commit()
144
+ db.refresh(new_user)
145
+ return {"uid": new_user.uid}
146
 
147
 
148
  class AssistItem(BaseModel):
 
151
  version: str
152
 
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  @app.post("/assist")
155
  async def assist(item: AssistItem):
156
  db: Session = SessionLocal()
 
160
 
161
  # Call OpenAI
162
  openai.api_key = user.openai_key
163
+ response = openai_text_call(item.prompt, model=item.version)
164
 
165
  # Update the last time assist was called
166
  user.last_assist = datetime.now()
167
 
168
  # Store the history
169
+ new_history = History(uid=item.uid, question=item.prompt, answer=response["text"])
 
 
170
 
171
  db.add(new_history)
172
  db.commit()
 
207
  "/assist",
208
  json={"uid": uid, "prompt": prompt, "version": gpt_version},
209
  )
210
+ return generate_html_response_from_openai(response.text)
211
 
212
 
213
  def get_user_interface(uid):
 
227
  return "<pre><code>" + mistune.escape(code) + "</code></pre>"
228
 
229
 
230
+ def generate_html_response_from_openai(openai_response):
231
  r"""
232
  This is used by the gradio to extract all of the user
233
 data and write it out as a giant json blob that can be easily displayed.
234
+ >>>
235
+ >>> data = {'text': 'This is a test'}
236
+ >>> generate_html_response_from_openai(json.dumps(data))
237
+ '<html><p>This is a test</p>\n</html>'
238
  """
239
+
240
+ openai_response = json.loads(openai_response)
241
+ openai_response = openai_response["text"]
242
+ markdown = mistune.create_markdown(renderer=HighlightRenderer())
243
+ openai_response = markdown(openai_response)
244
+ return f"<html>{openai_response}</html>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
 
247
  def get_assist_interface():
248
+ gpt_version_dropdown = components.Dropdown(label="GPT Version", choices=LANGS)
249
 
250
+ return Interface(
251
  fn=assist_interface,
252
  inputs=[
253
+ components.Textbox(label="UID", type="text"),
254
+ components.Textbox(label="Prompt", type="text"),
255
  gpt_version_dropdown,
256
  ],
257
  outputs="html",
 
261
 
262
 
263
  def get_db_interface():
264
+ return Interface(
265
  fn=get_user_interface,
266
  inputs="text",
267
  outputs="text",
 
280
 
281
 
282
  def get_register_interface():
283
+ return Interface(
284
  fn=register_interface,
285
+ inputs=[components.Textbox(label="OpenAI Key", type="text")],
286
  outputs="json",
287
  title="Register New User",
288
  description="Register a new user by entering an OpenAI key.",
 
296
 
297
 
298
  def get_history_gradio_interface():
299
+ return Interface(
300
  fn=get_history_interface,
301
+ inputs=[components.Textbox(label="UID", type="text")],
302
  outputs="json",
303
  title="Get User History",
304
  description="Get the history of questions and answers for a given user.",
 
315
 
316
 
317
  def get_add_command_interface():
318
+ return Interface(
319
  fn=add_command_interface,
320
  inputs=[
321
+ components.Textbox(label="UID", type="text"),
322
+ components.Textbox(label="Command", type="text"),
323
  ],
324
  outputs="json",
325
  title="Add Command",
 
327
  )
328
 
329
 
330
+ app = mount_gradio_app(
331
  app,
332
+ TabbedInterface(
333
  [
334
  get_assist_interface(),
335
  get_db_interface(),
backend/crappy_test.py CHANGED
@@ -9,11 +9,33 @@ client = TestClient(backend.app)
9
 
10
 
11
  @pytest.mark.asyncio
12
- async def test_register():
13
- data = {
14
- "openai_key": "garbage",
15
- }
16
- client.post("/register", json=data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
 
19
  @pytest.mark.asyncio
 
9
 
10
 
11
  @pytest.mark.asyncio
12
+ async def test_register_new_key():
13
+ data = {"openai_key": "new_key"}
14
+ response = client.post("/register", json=data)
15
+ assert response.status_code == 200
16
+ assert "uid" in response.json()
17
+ uid = response.json()["uid"]
18
+
19
+ data = {"openai_key": "new_key"}
20
+ response = client.post("/register", json=data)
21
+ assert response.status_code == 200
22
+ assert "uid" in response.json()
23
+ assert response.json()["uid"] == uid
24
+
25
+
26
+ @pytest.mark.asyncio
27
+ async def test_register_existing_key():
28
+ data = {"openai_key": "existing_key"}
29
+ response = client.post("/register", json=data)
30
+ assert response.status_code == 200
31
+ assert "uid" in response.json()
32
+ uid = response.json()["uid"]
33
+
34
+ data = {"openai_key": "existing_key"}
35
+ response = client.post("/register", json=data)
36
+ assert response.status_code == 200
37
+ assert "uid" in response.json()
38
+ assert response.json()["uid"] == uid
39
 
40
 
41
  @pytest.mark.asyncio
backend/requirements.txt CHANGED
@@ -1,10 +1,7 @@
1
- fastapi
2
- firebase-admin
3
- flask
4
- gradio >= 3.36.1
5
- httpx
6
- openai
7
- pinecone-io
8
  pydantic
9
  pygments
10
  python-dotenv
@@ -12,7 +9,7 @@ mistune
12
  sqlalchemy
13
  swagger-ui-bundle
14
 
 
15
  pytest-asyncio
16
  pytest
17
- requests
18
-
 
1
+ fastapi == 0.99.1
2
+ easycompletion
3
+ gradio
4
+ openai >= 0.27.8
 
 
 
5
  pydantic
6
  pygments
7
  python-dotenv
 
9
  sqlalchemy
10
  swagger-ui-bundle
11
 
12
+ black
13
  pytest-asyncio
14
  pytest
15
+ requests