Spaces:
Runtime error
Runtime error
| import boto3 | |
| import os | |
| import re | |
| import json | |
| from pathlib import Path | |
| import sqlite3 | |
| from huggingface_hub import Repository, HfFolder | |
| import tqdm | |
| import subprocess | |
| from fastapi import FastAPI | |
| from fastapi_utils.tasks import repeat_every | |
# --- S3 access -------------------------------------------------------------
# Credentials and bucket name are read from the environment; os.getenv
# returns None for any variable that is not set.
AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
AWS_SECRET_KEY = os.getenv('AWS_SECRET_KEY')
AWS_S3_BUCKET_NAME = os.getenv('AWS_S3_BUCKET_NAME')

s3 = boto3.client(service_name='s3',
                  aws_access_key_id=AWS_ACCESS_KEY_ID,
                  aws_secret_access_key=AWS_SECRET_KEY)
# Shared paginator used to walk object listings page by page.
paginator = s3.get_paginator('list_objects_v2')

# --- Local checkout of the dataset repository --------------------------------
S3_DATA_FOLDER = Path("sd-multiplayer-data")
ROOMS_DATA_DB = S3_DATA_FOLDER / "rooms_data.db"

repo = Repository(
    local_dir=S3_DATA_FOLDER,
    repo_type="dataset",
    clone_from="huggingface-projects/sd-multiplayer-data",
    use_auth_token=True,
)
repo.git_pull()

# --- First-run database bootstrap --------------------------------------------
# When the checkout does not already contain the SQLite file, create it from
# schema.sql (resolved relative to the current working directory).
if not ROOMS_DATA_DB.exists():
    print("Creating database")
    print("ROOMS_DATA_DB", ROOMS_DATA_DB)
    db = sqlite3.connect(ROOMS_DATA_DB)
    try:
        with open(Path("schema.sql"), "r") as f:
            db.executescript(f.read())
        db.commit()
    finally:
        # Close the connection even when schema.sql is missing or the
        # script fails, so the handle is not leaked on a failed bootstrap.
        db.close()
def get_db(db_path):
    """Yield a single SQLite connection to *db_path* and close it afterwards.

    The connection is opened with ``check_same_thread=False`` and its
    ``row_factory`` is set to ``sqlite3.Row`` so rows can be indexed by
    column name.

    Args:
        db_path: Path (or ``":memory:"``) passed to ``sqlite3.connect``.

    Yields:
        The open ``sqlite3.Connection``; exactly one value is produced.

    Raises:
        Exception: any exception thrown into the generator at the ``yield``
            is re-raised after the pending transaction has been rolled back.
    """
    db = sqlite3.connect(db_path, check_same_thread=False)
    db.row_factory = sqlite3.Row
    try:
        yield db
    except Exception:
        # Undo uncommitted work, then let the caller see the failure instead
        # of silently swallowing it.
        db.rollback()
        raise
    finally:
        db.close()
def _parse_object_key(key):
    """Split an S3 object key into ``(uuid, x, y, prompt)``.

    The key is split on ``-``, ``/`` and ``.``. Field 3 is taken as the uuid,
    field 4 must be two ``_``-separated integers (the x and y values), and
    fields 4 onward joined with spaces form the prompt.

    Raises:
        IndexError: the key has fewer than five fields.
        ValueError: field 4 is not exactly two ``_``-separated integers.
    """
    fields = re.split(r'[-/.]', key)
    object_uuid = fields[3]
    x, y = [int(n) for n in re.split(r'[_]', fields[4])]
    prompt = ' '.join(fields[4:])
    return object_uuid, x, y, prompt


def sync_rooms_to_dataset():
    """Mirror the S3 objects of every room into SQLite, JSON files and git.

    For each row of the ``rooms`` table this lists the bucket objects under
    the ``<room_id>/`` prefix, inserts one ``rooms_data`` row per object whose
    key can be parsed, commits, and rewrites ``<room_id>.json`` in the data
    folder with all ``rooms_data`` rows of that room. After all rooms are
    processed a git add / amend-commit / force-push pipeline is started in
    the data folder.

    Objects that fail to parse or insert are printed and skipped; they do
    not abort the sync.
    """
    for room_data_db in get_db(ROOMS_DATA_DB):
        rooms = room_data_db.execute("SELECT * FROM rooms").fetchall()
        cursor = room_data_db.cursor()
        for row in tqdm.tqdm(rooms):
            room_id = row["room_id"]
            print("syncing room data: ", room_id)
            pages = paginator.paginate(
                Bucket=AWS_S3_BUCKET_NAME, Prefix=f'{room_id}/', Delimiter='/')
            for page in pages:
                # A listing page with no matching objects has no 'Contents'
                # entry; treat that as an empty page instead of iterating
                # over None.
                for obj in page.get('Contents') or []:
                    try:
                        key = obj.get('Key')
                        modified = obj.get('LastModified').isoformat()
                        object_uuid, x, y, prompt = _parse_object_key(key)
                        # First column is passed as NULL; presumably an
                        # auto-assigned row id -- confirm against schema.sql.
                        cursor.execute(
                            'INSERT INTO rooms_data VALUES (NULL, ?, ?, ?, ?, ?, ?, ?)',
                            (room_id, object_uuid, x, y, prompt, modified, key))
                    except Exception as e:
                        # Best effort: report the failing object and move on.
                        print(e)
                        continue
            room_data_db.commit()

            all_rows = [dict(data_row) for data_row in room_data_db.execute(
                "SELECT * FROM rooms_data WHERE room_id = ?", (room_id,)).fetchall()]
            data_path = S3_DATA_FOLDER / f"{room_id}.json"
            with open(data_path, 'w') as f:
                json.dump(all_rows, f, separators=(',', ':'))

        print("Updating repository")
        # Started without waiting: the function returns while git may still
        # be running.
        subprocess.Popen(
            "git add . && git commit --amend -m 'update' && git push --force",
            cwd=S3_DATA_FOLDER, shell=True)
# Module-level FastAPI application instance.
app = FastAPI()
def read_root():
    """Return a fixed one-line description of this service.

    NOTE(review): no route decorator is visible on this function in the
    reviewed text -- confirm how it is registered with the app.
    """
    description = "Just a bot to sync data to huggingface datasets and tweet tha latest data"
    return description
def sync():
    """Run one full rooms-to-dataset synchronisation and report success.

    Returns:
        A fixed confirmation string, returned once
        ``sync_rooms_to_dataset`` has finished.
    """
    sync_rooms_to_dataset()
    confirmation = "Synced data to huggingface datasets"
    return confirmation
def repeat_sync():
    """Run the rooms-to-dataset synchronisation and return a status string.

    NOTE(review): the name and the ``repeat_every`` import suggest this is
    meant to run on a schedule, but no decorator is visible in the reviewed
    text -- confirm.

    Returns:
        A fixed confirmation string, returned once
        ``sync_rooms_to_dataset`` has finished.
    """
    sync_rooms_to_dataset()
    outcome = "Synced data to huggingface datasets"
    return outcome