posix4e commited on
Commit
dc7f995
·
1 Parent(s): 5eb95c5

Rework for uvicorn

Browse files

- Add backend test
- Add gradio interface

.github/workflows/backend_ci.yml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Backend CI
2
+ on:
3
+ push:
4
+ branches: [main]
5
+ pull_request:
6
+ workflow_dispatch:
7
+
8
+ jobs:
9
+ test:
10
+ defaults:
11
+ run:
12
+ working-directory: backend
13
+ runs-on: ubuntu-latest
14
+ steps:
15
+ - uses: actions/checkout@v3
16
+ - uses: actions/setup-python@v4
17
+ with:
18
+ python-version: '3.11.4'
19
+ cache: 'pip'
20
+ - name: pip Install
21
+ run: pip install -r requirements.txt
22
+ - name: Test
23
+ run: pytest
backend/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
backend/README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Backend for the puppet
2
+
3
+
4
+ This servers receives everything going on on the puppet. It can send requests to the puppet to do action. it uses openai as a llm currents
5
+
6
+ # dev
7
+
8
+ - (optional) Install virtualenv or equiv please
9
+ - pip install -r requirements.txt
10
+ - pytest
11
+
12
+ # you can also run the test server with
13
+ - uvicorn --host 0.0.0.0 --port 8000 backend:app --reload
14
+
backend/backend.py CHANGED
@@ -2,13 +2,19 @@ from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  from firebase_admin import credentials, messaging, initialize_app, auth
4
  import openai
5
- import uuid
6
  import asyncio
 
 
 
 
 
 
 
 
7
 
8
  app = FastAPI()
9
 
10
- # To store user information in memory.
11
- # WARNING: This information will be lost when the server stops/restarts
12
  user_data = {}
13
 
14
 
@@ -17,30 +23,26 @@ class RegisterItem(BaseModel):
17
  authDomain: str
18
  databaseURL: str
19
  storageBucket: str
20
- openai_key: str
21
 
22
 
23
  @app.post("/register")
24
  async def register(item: RegisterItem):
25
- # Generate a unique user id
26
- uid = str(uuid.uuid4())
27
-
28
- # Initialize the Firebase app
29
- user_data[uid] = {
30
- "firebase": credentials.Certificate(
31
- {
32
- "apiKey": item.apiKey,
33
- "authDomain": item.authDomain,
34
- "databaseURL": item.databaseURL,
35
- "storageBucket": item.storageBucket,
36
- }
37
- ),
38
- "openai_key": item.openai_key,
39
  }
40
-
41
- initialize_app(user_data[uid]["firebase"], name=uid)
42
-
43
- return {"uid": uid}
44
 
45
 
46
  class ProcessItem(BaseModel):
@@ -50,15 +52,13 @@ class ProcessItem(BaseModel):
50
 
51
  @app.post("/process_request")
52
  async def process_request(item: ProcessItem):
53
- # Get the user data using the provided uid
54
- user = user_data.get(item.uid)
 
55
 
56
- if not user:
57
  raise HTTPException(status_code=400, detail="Invalid uid")
58
 
59
- # Set the OpenAI key for this user
60
- openai.api_key = user["openai_key"]
61
-
62
  # Call OpenAI
63
  response = openai.Completion.create(
64
  engine="text-davinci-002", prompt=item.prompt, max_tokens=150
@@ -70,21 +70,75 @@ async def process_request(item: ProcessItem):
70
  "message": response.choices[0].text.strip(),
71
  },
72
  topic="updates",
 
73
  )
74
 
75
  # Send the message asynchronously
76
- asyncio.run(send_notification(message, item.uid))
77
 
78
  return {"message": "Notification sent"}
79
 
80
 
81
- def send_notification(message, uid):
82
  # Send a message to the devices subscribed to the provided topic.
83
- response = messaging.send(message, app=user_data[uid]["firebase"])
84
  print("Successfully sent message:", response)
85
 
86
 
87
- if __name__ == "__main__":
88
- import uvicorn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
2
  from pydantic import BaseModel
3
  from firebase_admin import credentials, messaging, initialize_app, auth
4
  import openai
 
5
  import asyncio
6
+ import gradio as gr
7
+ from fastapi.middleware.wsgi import WSGIMiddleware
8
+ from fastapi.staticfiles import StaticFiles
9
+ import uvicorn
10
+ from dotenv import load_dotenv
11
+
12
+ load_dotenv()
13
+
14
 
15
  app = FastAPI()
16
 
17
+ # User credentials will be stored in this dictionary
 
18
  user_data = {}
19
 
20
 
 
23
  authDomain: str
24
  databaseURL: str
25
  storageBucket: str
 
26
 
27
 
28
  @app.post("/register")
29
  async def register(item: RegisterItem):
30
+ # Firebase initialization with user-specific credentials
31
+ cred = credentials.Certificate(
32
+ {
33
+ "apiKey": item.apiKey,
34
+ "authDomain": item.authDomain,
35
+ "databaseURL": item.databaseURL,
36
+ "storageBucket": item.storageBucket,
37
+ }
38
+ )
39
+ firebase_app = initialize_app(cred, name=str(len(user_data)))
40
+ # Add the Firebase app and auth details to the user_data dictionary
41
+ user_data[str(len(user_data))] = {
42
+ "firebase_app": firebase_app,
43
+ "authDomain": item.authDomain,
44
  }
45
+ return {"uid": str(len(user_data) - 1)} # Return the user ID
 
 
 
46
 
47
 
48
  class ProcessItem(BaseModel):
 
52
 
53
  @app.post("/process_request")
54
  async def process_request(item: ProcessItem):
55
+ # Get the user's Firebase app from the user_data dictionary
56
+ firebase_app = user_data.get(item.uid, {}).get("firebase_app", None)
57
+ authDomain = user_data.get(item.uid, {}).get("authDomain", None)
58
 
59
+ if not firebase_app or not authDomain:
60
  raise HTTPException(status_code=400, detail="Invalid uid")
61
 
 
 
 
62
  # Call OpenAI
63
  response = openai.Completion.create(
64
  engine="text-davinci-002", prompt=item.prompt, max_tokens=150
 
70
  "message": response.choices[0].text.strip(),
71
  },
72
  topic="updates",
73
+ app=firebase_app, # Use the user-specific Firebase app
74
  )
75
 
76
  # Send the message asynchronously
77
+ asyncio.run(send_notification(message))
78
 
79
  return {"message": "Notification sent"}
80
 
81
 
82
+ def send_notification(message):
83
  # Send a message to the devices subscribed to the provided topic.
84
+ response = messaging.send(message)
85
  print("Successfully sent message:", response)
86
 
87
 
88
+ def gradio_interface():
89
+ def register(apiKey, authDomain, databaseURL, storageBucket):
90
+ response = requests.post(
91
+ "http://localhost:8000/register",
92
+ json={
93
+ "apiKey": apiKey,
94
+ "authDomain": authDomain,
95
+ "databaseURL": databaseURL,
96
+ "storageBucket": storageBucket,
97
+ },
98
+ )
99
+ return response.json()
100
+
101
+ def process_request(uid, prompt):
102
+ response = requests.post(
103
+ "http://localhost:8000/process_request", json={"uid": uid, "prompt": prompt}
104
+ )
105
+ return response.json()
106
+
107
+ demo = gr.Interface(
108
+ fn=[register, process_request],
109
+ inputs=[
110
+ [
111
+ gr.inputs.Textbox(label="apiKey"),
112
+ gr.inputs.Textbox(label="authDomain"),
113
+ gr.inputs.Textbox(label="databaseURL"),
114
+ gr.inputs.Textbox(label="storageBucket"),
115
+ ],
116
+ [gr.inputs.Textbox(label="uid"), gr.inputs.Textbox(label="prompt")],
117
+ ],
118
+ outputs="json",
119
+ title="API Explorer",
120
+ description="Use this tool to make requests to the Register and Process Request APIs",
121
+ )
122
+ return demo
123
+
124
+
125
+ def process_request_interface(uid, prompt):
126
+ item = ProcessItem(uid=uid, prompt=prompt)
127
+ response = process_request(item)
128
+ return response
129
+
130
+
131
+ def get_gradle_interface():
132
+ return gr.Interface(
133
+ fn=process_request_interface,
134
+ inputs=[
135
+ gr.inputs.Textbox(label="UID", type="text"),
136
+ gr.inputs.Textbox(label="Prompt", type="text"),
137
+ ],
138
+ outputs="text",
139
+ title="OpenAI Text Generation",
140
+ description="Generate text using OpenAI's GPT-3 model.",
141
+ )
142
+
143
 
144
+ app = gr.mount_gradio_app(app, get_gradle_interface(), path="/")
backend/crappy_test.py CHANGED
@@ -1,8 +1,15 @@
1
- import requests
 
 
2
  import json
 
 
3
 
4
 
5
- def test_register():
 
 
 
6
  data = {
7
  "apiKey": "test-api-key",
8
  "authDomain": "test-auth-domain",
@@ -10,24 +17,5 @@ def test_register():
10
  "storageBucket": "test-storage-bucket",
11
  "openai_key": "test-openai-key",
12
  }
13
- response = requests.post(
14
- "http://localhost:8000/register",
15
- data=json.dumps(data),
16
- headers={"Content-Type": "application/json"},
17
- )
18
- assert response.status_code == 200
19
- uid = response.json().get("uid")
20
- assert uid is not None
21
-
22
- # Let's use this uid to send a process request
23
- data = {
24
- "uid": uid,
25
- "prompt": "Hello, OpenAI!",
26
- }
27
- response = requests.post(
28
- "http://localhost:8000/process_request",
29
- data=json.dumps(data),
30
- headers={"Content-Type": "application/json"},
31
- )
32
- assert response.status_code == 200
33
- assert response.json().get("message") == "Notification sent"
 
1
+ import time
2
+ from fastapi.testclient import TestClient
3
+ import pytest
4
  import json
5
+ import backend
6
+ import httpx
7
 
8
 
9
+ @pytest.mark.asyncio
10
+ async def test_register():
11
+ client = TestClient(backend.app)
12
+
13
  data = {
14
  "apiKey": "test-api-key",
15
  "authDomain": "test-auth-domain",
 
17
  "storageBucket": "test-storage-bucket",
18
  "openai_key": "test-openai-key",
19
  }
20
+ with pytest.raises(ValueError):
21
+ client.post("/register", json=data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
backend/requirements.txt CHANGED
@@ -1,10 +1,15 @@
1
- pytest
2
- requests
3
  fastapi
4
  firebase-admin
 
5
  gradio
 
6
  openai
7
  pinecone-io
8
  pydantic
9
  python-dotenv
10
  swagger-ui-bundle
 
 
 
 
 
 
 
 
1
  fastapi
2
  firebase-admin
3
+ flask
4
  gradio
5
+ httpx
6
  openai
7
  pinecone-io
8
  pydantic
9
  python-dotenv
10
  swagger-ui-bundle
11
+
12
+ pytest-asyncio
13
+ pytest
14
+ requests
15
+