Spaces:
Runtime error
Runtime error
| import os | |
| import re | |
| import aiohttp | |
| import requests | |
| import json | |
| import subprocess | |
| import asyncio | |
| from io import BytesIO | |
| import uuid | |
| from math import ceil | |
| from tqdm import tqdm | |
| from pathlib import Path | |
| from huggingface_hub import Repository | |
| from PIL import Image, ImageOps | |
| from fastapi import FastAPI, BackgroundTasks | |
| from fastapi_utils.tasks import repeat_every | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import boto3 | |
| from db import Database | |
# --- Configuration read from the environment ---------------------------------
# AWS credentials and bucket for the module-level S3 client below
# (os.getenv returns None when a variable is unset).
AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
AWS_SECRET_KEY = os.getenv('AWS_SECRET_KEY')
AWS_S3_BUCKET_NAME = os.getenv('AWS_S3_BUCKET_NAME')
# NOTE(review): HF_TOKEN is read but not passed anywhere in this section; the
# Repository below authenticates via use_auth_token=True — confirm how the
# token reaches huggingface_hub.
HF_TOKEN = os.environ.get("HF_TOKEN")
# NOTE(review): S3_DATA_FOLDER is not referenced in the visible code.
S3_DATA_FOLDER = Path("sd-multiplayer-data")
# Local checkout of the dataset repo; holds the database and models.json.
DB_FOLDER = Path("diffusers-gallery-data")
# Inference endpoint of the classifier Space (presumably a Gradio
# /run/<api_name> route, given the {"data": [...]} payload run_classifier sends).
CLASSIFIER_URL = "https://radames-aesthetic-style-nsfw-classifier.hf.space/run/inference"
# Public base URL for uploaded thumbnails; run_classifier appends the uploaded
# file name to it (presumably a CDN in front of the S3 "diffusers-gallery/" prefix).
ASSETS_URL = "https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/"
# S3 client used by upload_resize_image_url to store resized thumbnails.
s3 = boto3.client(service_name='s3',
                  aws_access_key_id=AWS_ACCESS_KEY_ID,
                  aws_secret_access_key=AWS_SECRET_KEY)
# Clone the gallery dataset repo into DB_FOLDER at import time and pull the
# latest revision so the local database starts up to date.
# NOTE(review): huggingface_hub's Repository class and its use_auth_token
# argument are deprecated in recent releases — verify against the pinned version.
repo = Repository(
    local_dir=DB_FOLDER,
    repo_type="dataset",
    clone_from="huggingface-projects/diffusers-gallery-data",
    use_auth_token=True,
)
repo.git_pull()
# Database wrapper over the checkout; get_db() is used as a context manager
# by the functions below.
database = Database(DB_FOLDER)
async def upload_resize_image_url(session, image_url):
    """Download an image, crop-resize it to a 400x400 JPEG and upload it to S3.

    Args:
        session: an open ``aiohttp.ClientSession`` used for the download.
        image_url: absolute URL of the image to fetch.

    Returns:
        The generated object file name (``"<uuid4>.jpg"``), stored under the
        ``diffusers-gallery/`` prefix of ``AWS_S3_BUCKET_NAME``, or ``None``
        when the URL could not be fetched, did not answer 200 with an image
        content type, or the payload could not be decoded. Returning ``None``
        instead of raising keeps one broken link from aborting the
        ``asyncio.gather`` over all images of a model card; ``run_classifier``
        ignores ``None`` entries.
    """
    print(f"Uploading image {image_url}")
    try:
        async with session.get(image_url) as response:
            # .get(): a response without a Content-Type header must not raise KeyError.
            content_type = response.headers.get('content-type', '')
            if response.status != 200 or not content_type.startswith('image'):
                return None
            payload = await response.read()
    except (aiohttp.ClientError, asyncio.TimeoutError) as err:
        print(f"Failed to download {image_url}: {err}")
        return None

    try:
        image = Image.open(BytesIO(payload)).convert('RGB')
    except (OSError, Image.DecompressionBombError) as err:
        # OSError covers PIL.UnidentifiedImageError (not an image / corrupt data).
        print(f"Could not decode image {image_url}: {err}")
        return None

    # ImageOps.fit crops to the target aspect ratio and resizes, so every
    # thumbnail is exactly 400x400.
    image = ImageOps.fit(image, (400, 400), Image.LANCZOS)
    image_bytes = BytesIO()
    image.save(image_bytes, format="JPEG")
    image_bytes.seek(0)
    fname = f'{uuid.uuid4()}.jpg'
    # Keys are random UUIDs, so an object never changes and can be cached for a year.
    s3.upload_fileobj(Fileobj=image_bytes, Bucket=AWS_S3_BUCKET_NAME, Key="diffusers-gallery/" + fname,
                      ExtraArgs={"ContentType": "image/jpeg", "CacheControl": "max-age=31536000"})
    return fname
def fetch_models(page=0, timeout=30):
    """Fetch one page of public text-to-image models from the Hub listing.

    Args:
        page: zero-based page index, sent as the ``p`` query parameter.
        timeout: seconds to wait for the HTTP response; without a timeout a
            stalled connection would block the sync indefinitely.

    Returns:
        dict with ``models`` (private models filtered out) plus
        ``numItemsPerPage``, ``numTotalItems`` and ``pageIndex`` copied from
        the response.

    Raises:
        requests.HTTPError: on a non-2xx response (instead of a confusing
            JSON decode error or KeyError further down).
        requests.Timeout: when the server does not answer within ``timeout``.
    """
    response = requests.get(
        f'https://huggingface.co/models-json?pipeline_tag=text-to-image&p={page}',
        timeout=timeout)
    response.raise_for_status()
    data = response.json()
    return {
        "models": [model for model in data['models'] if not model['private']],
        "numItemsPerPage": data['numItemsPerPage'],
        "numTotalItems": data['numTotalItems'],
        "pageIndex": data['pageIndex'],
    }
def fetch_model_card(model_id, timeout=30):
    """Return the raw README.md (model card) of ``model_id`` from its ``main`` branch.

    Args:
        model_id: Hub repo id such as ``"user/model"``.
        timeout: seconds to wait for the HTTP response.

    Returns:
        The README text, or ``""`` when the request answers with a non-2xx
        status (e.g. the repo has no README). This way callers scanning the
        card for image URLs never scan an error page.

    Raises:
        requests.RequestException: on network failures, including timeouts,
            so a transient outage aborts the sync instead of hanging it.
    """
    response = requests.get(
        f'https://huggingface.co/{model_id}/raw/main/README.md', timeout=timeout)
    if not response.ok:
        return ""
    return response.text
# Absolute http(s) URL ending in a common raster-image extension (any case).
# - The URL body stops at whitespace, parentheses, angle brackets, quotes and
#   square brackets. Adjacent markdown images such as
#   "![a](https://x/a.png)![b](https://x/b.png)" therefore yield two URLs
#   instead of one merged span.
# - A literal "." is required before the extension, so paths like
#   ".../demo-png" do not match.
_IMAGE_URL_RE = re.compile(
    r'https?://[^\s()<>"\'\[\]]+\.(?:png|jpe?g|webp)', re.IGNORECASE)


async def find_image_in_model_card(text, limit=3):
    """Upload thumbnails for the first image URLs found in a model card.

    Args:
        text: raw model-card markdown (``None`` is treated as empty).
        limit: maximum number of distinct image URLs to process (default 3).

    Returns:
        One entry per processed URL, in order of first appearance: the file
        name returned by ``upload_resize_image_url`` (or ``None`` when that
        upload produced nothing). Empty list when the text has no image URL.
    """
    # dict.fromkeys drops repeated URLs while keeping first-seen order, so a
    # card that links the same picture twice does not waste upload slots.
    urls = list(dict.fromkeys(_IMAGE_URL_RE.findall(text or '')))[:limit]
    if not urls:
        return []
    async with aiohttp.ClientSession() as session:
        return await asyncio.gather(
            *(upload_resize_image_url(session, url) for url in urls))
def run_classifier(images, timeout=60):
    """Classify the first uploaded thumbnail with the remote classifier Space.

    Args:
        images: uploaded file names as returned by ``upload_resize_image_url``.
            ``None`` entries (failed uploads) are ignored, and a ``None`` list
            is treated as empty.
        timeout: seconds to wait for the classifier response.

    Returns:
        dict mapping each label returned by the classifier to its score
        rounded to 3 decimals, or ``{}`` when there is no usable image.

    Raises:
        requests.HTTPError: when the classifier answers with a non-2xx status.
        requests.Timeout: when it does not answer within ``timeout``.
    """
    uploaded = [name for name in images or () if name is not None]
    if not uploaded:
        return {}
    # Only the first image is sent for classification.
    payload = {"data": [
        {"urls": [ASSETS_URL + uploaded[0]]},  # json urls: list of image urls
        False,  # enable/disable gallery image output
        None,  # single image input
        None,  # files input
    ]}
    response = requests.post(CLASSIFIER_URL, json=payload, timeout=timeout)
    response.raise_for_status()
    # Response shape: {"data": [[rows_img0, rows_img1, ...], label, gallery]}.
    # rows_img0 is iterated below as a list of {"label": ..., "score": ...} dicts.
    class_data = response.json()['data'][0][0]
    print(class_data)
    return {row['label']: round(row['score'], 3) for row in class_data}
async def get_all_new_models():
    """Return every public text-to-image model listed on the Hub, across all pages.

    Page 0 is fetched once to learn the pagination size, and its models are
    reused, so no page is requested twice.

    NOTE: fetch_models uses blocking ``requests`` calls, so this coroutine
    blocks the event loop while it runs.
    """
    initial = fetch_models(0)
    total_items = initial['numTotalItems']
    per_page = initial['numItemsPerPage']
    # A zero page size would otherwise raise ZeroDivisionError.
    num_pages = ceil(total_items / per_page) if per_page else 1
    print(f"Total items: {total_items} - Items per page: {per_page}")
    print(f"Found {num_pages} pages")
    new_models = list(initial['models'])
    for page in tqdm(range(1, num_pages)):
        print(f"Fetching page {page} of {num_pages}")
        new_models += fetch_models(page)['models']
    return new_models
async def sync_data():
    """Add newly listed Hub models to the gallery database.

    Steps:
    1. Write the full model listing to ``DB_FOLDER / "models.json"``.
    2. For each model id not yet in the ``models`` table, fetch its model
       card, upload thumbnails of the images it links, and classify the first
       successfully uploaded one.
    3. Insert one row per new model. Its ``data`` column is the listing entry
       extended with ``images`` and ``class``.
    """
    print("Fetching models")
    new_models = await get_all_new_models()
    print(f"Found {len(new_models)} models")
    # save list of all models for ids
    with open(DB_FOLDER / "models.json", "w", encoding="utf-8") as f:
        json.dump(new_models, f)

    # get existing models
    with database.get_db() as db:
        cursor = db.cursor()
        cursor.execute("SELECT id FROM models")
        existing_ids = {row['id'] for row in cursor.fetchall()}

    # Select models not in the database, keeping listing order.
    # - Set lookups avoid the O(n^2) list-membership scan over thousands of models.
    # - Tracking `seen` also drops ids that the paginated listing returned
    #   twice (it can shift while being paged). Such an id would otherwise be
    #   inserted twice.
    seen = set(existing_ids)
    models = []
    for model in new_models:
        if model['id'] not in seen:
            seen.add(model['id'])
            models.append(model)
    print(f"Found {len(models)} new models")

    for model in tqdm(models):
        model_id = model['id']
        model_card = fetch_model_card(model_id)
        images = await find_image_in_model_card(model_card)
        classifier = run_classifier(images)
        # update model row with image and classifier data
        with database.get_db() as db:
            cursor = db.cursor()
            cursor.execute("INSERT INTO models(id, data) VALUES (?, ?)",
                           [model_id, json.dumps({
                               **model,
                               "images": images,
                               "class": classifier
                           })])
            db.commit()

    # Disabled: pushing the updated dataset back to the Hub.
    # print("Updating repository")
    # subprocess.Popen(
    #     "git add . && git commit --amend -m 'update' && git push --force", cwd=DB_FOLDER, shell=True)
# ASGI application serving the gallery API.
app = FastAPI()
# Accept cross-origin browser requests from any site (presumably so a
# separately hosted gallery frontend can call this API).
# NOTE(review): with allow_origins=["*"] and allow_credentials=True, Starlette's
# CORSMiddleware echoes the caller's Origin back for credentialed requests.
# That lets any site send cookie-bearing requests. The endpoints visible in
# this file only read public data, so allow_credentials=False looks sufficient.
# Confirm the frontend does not rely on credentials before changing it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
| # @ app.get("/sync") | |
| # async def sync(background_tasks: BackgroundTasks): | |
| # await sync_data() | |
| # return "Synced data to huggingface datasets" | |
# Number of models returned per gallery page.
MAX_PAGE_SIZE = 30
# Filter shared by the page query and the fallback count query, so both
# always describe the same set of models.
_GALLERY_FILTER = "json_extract(data, '$.likes') > 4"


# NOTE(review): no route decorator is visible on this function in this copy of
# the file; confirm it is registered on `app`.
def get_page(page: int = 1):
    """Return one page of gallery models, most recently modified first.

    Only models with more than 4 likes are listed.

    Args:
        page: 1-based page number; values below 1 are treated as 1.

    Returns:
        dict with ``models`` (the decoded ``data`` JSON of up to
        ``MAX_PAGE_SIZE`` rows) and ``totalPages`` (page count of the whole
        filtered set). ``totalPages`` is correct even when ``page`` is past
        the last page.
    """
    page = max(page, 1)
    offset = (page - 1) * MAX_PAGE_SIZE
    with database.get_db() as db:
        cursor = db.cursor()
        # COUNT(*) OVER() is evaluated before LIMIT/OFFSET, so every returned
        # row carries the size of the whole filtered set.
        cursor.execute(f"""
        SELECT *, COUNT(*) OVER() AS total
        FROM models
        WHERE {_GALLERY_FILTER}
        ORDER BY datetime(json_extract(data, '$.lastModified')) DESC
        LIMIT ? OFFSET ?
        """, (MAX_PAGE_SIZE, offset))
        results = cursor.fetchall()
        if results:
            total = results[0]['total']
        elif offset:
            # Past the last page no row carries the window total, which would
            # otherwise report totalPages = 0; count the filtered set directly.
            cursor.execute(
                f"SELECT COUNT(*) AS total FROM models WHERE {_GALLERY_FILTER}")
            total = cursor.fetchone()['total']
        else:
            total = 0
    total_pages = (total + MAX_PAGE_SIZE - 1) // MAX_PAGE_SIZE
    return {
        "models": [json.loads(result['data']) for result in results],
        "totalPages": total_pages,
    }
def read_root():
    """Return a short plain-text description of this service."""
    banner = "Just a bot to sync data from diffusers gallery"
    return banner
| # @app.on_event("startup") | |
| # @repeat_every(seconds=60 * 60 * 24, wait_first=True) | |
| # async def repeat_sync(): | |
| # await sync_data() | |
| # return "Synced data to huggingface datasets" | |