Spaces:
Runtime error
Runtime error
api to get rooms data
Browse files- stablediffusion-infinity/app.py +58 -33
stablediffusion-infinity/app.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import io
|
| 2 |
import os
|
|
|
|
| 3 |
|
| 4 |
from pathlib import Path
|
| 5 |
import uvicorn
|
|
@@ -26,6 +27,7 @@ import requests
|
|
| 26 |
import shortuuid
|
| 27 |
import re
|
| 28 |
import time
|
|
|
|
| 29 |
|
| 30 |
AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
|
| 31 |
AWS_SECRET_KEY = os.getenv('AWS_SECRET_KEY')
|
|
@@ -37,22 +39,35 @@ FILE_TYPES = {
|
|
| 37 |
'image/png': 'png',
|
| 38 |
'image/jpeg': 'jpg',
|
| 39 |
}
|
| 40 |
-
|
|
|
|
|
|
|
| 41 |
|
| 42 |
app = FastAPI()
|
| 43 |
|
| 44 |
-
if not
|
| 45 |
print("Creating database")
|
| 46 |
-
print("
|
| 47 |
-
db = sqlite3.connect(
|
| 48 |
with open(Path("schema.sql"), "r") as f:
|
| 49 |
db.executescript(f.read())
|
| 50 |
db.commit()
|
| 51 |
db.close()
|
| 52 |
|
| 53 |
|
| 54 |
-
def
|
| 55 |
-
db = sqlite3.connect(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
db.row_factory = sqlite3.Row
|
| 57 |
try:
|
| 58 |
yield db
|
|
@@ -77,6 +92,11 @@ model = {}
|
|
| 77 |
STATIC_MASK = Image.open("mask.png")
|
| 78 |
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
def get_model():
|
| 81 |
if "inpaint" not in model:
|
| 82 |
vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema")
|
|
@@ -86,31 +106,9 @@ def get_model():
|
|
| 86 |
torch_dtype=torch.float16,
|
| 87 |
vae=vae,
|
| 88 |
).to("cuda")
|
| 89 |
-
# lms = LMSDiscreteScheduler(
|
| 90 |
-
# beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
|
| 91 |
-
|
| 92 |
-
# img2img = StableDiffusionImg2ImgPipeline(
|
| 93 |
-
# vae=text2img.vae,
|
| 94 |
-
# text_encoder=text2img.text_encoder,
|
| 95 |
-
# tokenizer=text2img.tokenizer,
|
| 96 |
-
# unet=text2img.unet,
|
| 97 |
-
# scheduler=lms,
|
| 98 |
-
# safety_checker=text2img.safety_checker,
|
| 99 |
-
# feature_extractor=text2img.feature_extractor,
|
| 100 |
-
# ).to("cuda")
|
| 101 |
-
# try:
|
| 102 |
-
# total_memory = torch.cuda.get_device_properties(0).total_memory // (
|
| 103 |
-
# 1024 ** 3
|
| 104 |
-
# )
|
| 105 |
-
# if total_memory <= 5:
|
| 106 |
-
# inpaint.enable_attention_slicing()
|
| 107 |
-
# except:
|
| 108 |
-
# pass
|
| 109 |
model["inpaint"] = inpaint
|
| 110 |
-
# model["img2img"] = img2img
|
| 111 |
|
| 112 |
return model["inpaint"]
|
| 113 |
-
# model["img2img"]
|
| 114 |
|
| 115 |
|
| 116 |
# init model on startup
|
|
@@ -274,10 +272,10 @@ def get_room_count(room_id: str):
|
|
| 274 |
|
| 275 |
@ app.on_event("startup")
|
| 276 |
@ repeat_every(seconds=100)
|
| 277 |
-
|
|
|
|
| 278 |
try:
|
| 279 |
-
|
| 280 |
-
for db in get_db():
|
| 281 |
rooms = db.execute("SELECT * FROM rooms").fetchall()
|
| 282 |
for row in rooms:
|
| 283 |
room_id = row["room_id"]
|
|
@@ -291,14 +289,41 @@ async def sync_rooms():
|
|
| 291 |
print("Rooms update failed")
|
| 292 |
|
| 293 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
@ app.get('/api/rooms')
|
| 295 |
-
async def get_rooms(db: sqlite3.Connection = Depends(
|
|
|
|
| 296 |
rooms = db.execute("SELECT * FROM rooms").fetchall()
|
| 297 |
return rooms
|
| 298 |
|
| 299 |
|
| 300 |
@ app.post('/api/auth')
|
| 301 |
-
async def autorize(request: Request
|
| 302 |
data = await request.json()
|
| 303 |
room = data["room"]
|
| 304 |
payload = {
|
|
|
|
| 1 |
import io
|
| 2 |
import os
|
| 3 |
+
from typing import Union
|
| 4 |
|
| 5 |
from pathlib import Path
|
| 6 |
import uvicorn
|
|
|
|
| 27 |
import shortuuid
|
| 28 |
import re
|
| 29 |
import time
|
| 30 |
+
import subprocess
|
| 31 |
|
| 32 |
AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
|
| 33 |
AWS_SECRET_KEY = os.getenv('AWS_SECRET_KEY')
|
|
|
|
| 39 |
'image/png': 'png',
|
| 40 |
'image/jpeg': 'jpg',
|
| 41 |
}
|
| 42 |
+
S3_DATA_FOLDER = Path("sd-multiplayer-data")
|
| 43 |
+
ROOMS_DATA_DB = S3_DATA_FOLDER / "rooms_data.db"
|
| 44 |
+
ROOM_DB = Path("rooms.db")
|
| 45 |
|
| 46 |
app = FastAPI()
|
| 47 |
|
| 48 |
+
if not ROOM_DB.exists():
|
| 49 |
print("Creating database")
|
| 50 |
+
print("ROOM_DB", ROOM_DB)
|
| 51 |
+
db = sqlite3.connect(ROOM_DB)
|
| 52 |
with open(Path("schema.sql"), "r") as f:
|
| 53 |
db.executescript(f.read())
|
| 54 |
db.commit()
|
| 55 |
db.close()
|
| 56 |
|
| 57 |
|
| 58 |
+
def get_room_db():
|
| 59 |
+
db = sqlite3.connect(ROOM_DB, check_same_thread=False)
|
| 60 |
+
db.row_factory = sqlite3.Row
|
| 61 |
+
try:
|
| 62 |
+
yield db
|
| 63 |
+
except Exception:
|
| 64 |
+
db.rollback()
|
| 65 |
+
finally:
|
| 66 |
+
db.close()
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def get_room_data_db():
|
| 70 |
+
db = sqlite3.connect(ROOMS_DATA_DB, check_same_thread=False)
|
| 71 |
db.row_factory = sqlite3.Row
|
| 72 |
try:
|
| 73 |
yield db
|
|
|
|
| 92 |
STATIC_MASK = Image.open("mask.png")
|
| 93 |
|
| 94 |
|
| 95 |
+
def sync_rooms_data_repo():
|
| 96 |
+
subprocess.Popen("git fetch && git reset --hard origin/main",
|
| 97 |
+
cwd=S3_DATA_FOLDER, shell=True)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
def get_model():
|
| 101 |
if "inpaint" not in model:
|
| 102 |
vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema")
|
|
|
|
| 106 |
torch_dtype=torch.float16,
|
| 107 |
vae=vae,
|
| 108 |
).to("cuda")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
model["inpaint"] = inpaint
|
|
|
|
| 110 |
|
| 111 |
return model["inpaint"]
|
|
|
|
| 112 |
|
| 113 |
|
| 114 |
# init model on startup
|
|
|
|
| 272 |
|
| 273 |
@ app.on_event("startup")
|
| 274 |
@ repeat_every(seconds=100)
|
| 275 |
+
def sync_rooms():
|
| 276 |
+
print("Syncing rooms active users")
|
| 277 |
try:
|
| 278 |
+
for db in get_room_db():
|
|
|
|
| 279 |
rooms = db.execute("SELECT * FROM rooms").fetchall()
|
| 280 |
for row in rooms:
|
| 281 |
room_id = row["room_id"]
|
|
|
|
| 289 |
print("Rooms update failed")
|
| 290 |
|
| 291 |
|
| 292 |
+
@ app.on_event("startup")
|
| 293 |
+
@ repeat_every(seconds=300)
|
| 294 |
+
def sync_room_datq():
|
| 295 |
+
print("Sync rooms data")
|
| 296 |
+
sync_rooms_data_repo()
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
@ app.get('/api/room_data/{room_id}')
|
| 300 |
+
async def get_rooms(room_id: str, start: str = None, end: str = None, db: sqlite3.Connection = Depends(get_room_data_db)):
|
| 301 |
+
print("Getting rooms data", room_id, start, end)
|
| 302 |
+
|
| 303 |
+
if start is None and end is None:
|
| 304 |
+
rooms_rows = db.execute(
|
| 305 |
+
"SELECT key, prompt, time, x, y FROM rooms_data WHERE room_id = ? ORDER BY time", (room_id,)).fetchall()
|
| 306 |
+
elif end is None:
|
| 307 |
+
rooms_rows = db.execute("SELECT key, prompt, time, x, y FROM rooms_data WHERE room_id = ? AND time >= ? ORDER BY time",
|
| 308 |
+
(room_id, start)).fetchall()
|
| 309 |
+
elif start is None:
|
| 310 |
+
rooms_rows = db.execute("SELECT key, prompt, time, x, y FROM rooms_data WHERE room_id = ? AND time <= ? ORDER BY time",
|
| 311 |
+
(room_id, end)).fetchall()
|
| 312 |
+
else:
|
| 313 |
+
rooms_rows = db.execute("SELECT key, prompt, time, x, y FROM rooms_data WHERE room_id = ? AND time >= ? AND time <= ? ORDER BY time",
|
| 314 |
+
(room_id, start, end)).fetchall()
|
| 315 |
+
return rooms_rows
|
| 316 |
+
|
| 317 |
+
|
| 318 |
@ app.get('/api/rooms')
|
| 319 |
+
async def get_rooms(db: sqlite3.Connection = Depends(get_room_db)):
|
| 320 |
+
print("Getting rooms")
|
| 321 |
rooms = db.execute("SELECT * FROM rooms").fetchall()
|
| 322 |
return rooms
|
| 323 |
|
| 324 |
|
| 325 |
@ app.post('/api/auth')
|
| 326 |
+
async def autorize(request: Request):
|
| 327 |
data = await request.json()
|
| 328 |
room = data["room"]
|
| 329 |
payload = {
|