Spaces:
Runtime error
Runtime error
| from enum import Enum | |
| import os | |
| import re | |
| import aiohttp | |
| import requests | |
| import json | |
| import subprocess | |
| import asyncio | |
| from io import BytesIO | |
| import uuid | |
| import yaml | |
| from math import ceil | |
| from tqdm import tqdm | |
| from pathlib import Path | |
| from huggingface_hub import Repository | |
| from PIL import Image, ImageOps | |
| from fastapi import FastAPI, BackgroundTasks | |
| from fastapi.responses import HTMLResponse | |
| from fastapi_utils.tasks import repeat_every | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import boto3 | |
| from datetime import datetime | |
| from db import Database | |
# --- Credentials, read from the environment (None when a variable is unset) ---
AWS_ACCESS_KEY_ID = os.environ.get("MY_AWS_ACCESS_KEY_ID")
AWS_SECRET_KEY = os.environ.get("MY_AWS_SECRET_KEY")
AWS_S3_BUCKET_NAME = os.environ.get("MY_AWS_S3_BUCKET_NAME")
HF_TOKEN = os.getenv("HF_TOKEN")

# --- Local folders and remote endpoints ---
S3_DATA_FOLDER = Path("sd-multiplayer-data")
# Local checkout of the dataset repository that holds the database.
DB_FOLDER = Path("diffusers-gallery-data")
CLASSIFIER_URL = "https://radames-aesthetic-style-nsfw-classifier.hf.space/run/inference"
# Public base URL under which uploaded thumbnails are served.
ASSETS_URL = "https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/"
# Model ids matching this pattern are never added to the database.
BLOCKED_MODELS_REGEX = re.compile(r"(CyberHarem)", re.IGNORECASE)

# --- Clients created at import time ---
s3 = boto3.client(
    service_name="s3",
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_KEY,
)

# Clone the dataset repo into DB_FOLDER and bring it up to date before the
# database is opened on top of it.
repo = Repository(
    local_dir=DB_FOLDER,
    repo_type="dataset",
    clone_from="huggingface-projects/diffusers-gallery-data",
    use_auth_token=True,
)
repo.git_pull()

database = Database(DB_FOLDER)
async def upload_resize_image_url(session, image_url):
    """Download an image, fit it to 400x400, and upload it to S3 as a JPEG.

    Args:
        session: Object whose ``get(url)`` returns an async context manager
            yielding a response with ``status``, ``headers`` and ``read()``
            (an ``aiohttp.ClientSession`` at the call site in this file).
        image_url: URL of the image to fetch.

    Returns:
        The generated ``<uuid>.jpg`` file name (stored under the
        ``diffusers-gallery/`` key prefix) on success, otherwise ``None``.
        Failures are logged with ``print`` and never raised (best effort).
    """
    print(f"Uploading image {image_url}")
    try:
        async with session.get(image_url) as response:
            # A missing content-type header is treated as "not an image"
            # instead of raising KeyError into the generic handler below.
            content_type = response.headers.get("content-type", "")
            if response.status != 200 or not content_type.startswith(
                ("image", "application")
            ):
                return None
            payload = await response.read()

        # The connection is released above; the rest works on the bytes only.
        image = Image.open(BytesIO(payload)).convert("RGB")
        # Scale and crop to exactly 400x400.
        image = ImageOps.fit(image, (400, 400), Image.LANCZOS)
        image_bytes = BytesIO()
        image.save(image_bytes, format="JPEG")
        image_bytes.seek(0)

        fname = f"{uuid.uuid4()}.jpg"
        s3.upload_fileobj(
            Fileobj=image_bytes,
            Bucket=AWS_S3_BUCKET_NAME,
            Key="diffusers-gallery/" + fname,
            ExtraArgs={
                "ContentType": "image/jpeg",
                "CacheControl": "max-age=31536000",
            },
        )
        return fname
    except Exception as e:
        # Best effort: one bad image must not abort the whole sync.
        print(f"Error uploading image {image_url}: {e}")
        return None
def fetch_models(page=0, timeout=30):
    """Fetch one page of public text-to-image models from the Hub listing.

    Args:
        page: Zero-based page index sent as the ``p`` query parameter.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        A dict with ``models`` (entries not flagged ``private``) plus the
        paging fields ``numItemsPerPage``, ``numTotalItems`` and ``pageIndex``
        copied from the response.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
    """
    response = requests.get(
        "https://huggingface.co/models-json",
        params={"pipeline_tag": "text-to-image", "p": page},
        timeout=timeout,
    )
    # Fail loudly on 4xx/5xx rather than on an opaque JSON decode error.
    response.raise_for_status()
    data = response.json()
    return {
        # A missing "private" flag is treated as public.
        "models": [model for model in data["models"] if not model.get("private")],
        "numItemsPerPage": data["numItemsPerPage"],
        "numTotalItems": data["numTotalItems"],
        "pageIndex": data["pageIndex"],
    }
def fetch_model_card(model_id, timeout=30):
    """Return the raw README.md text of a model, or "" if it cannot be read.

    Args:
        model_id: Hub model id, e.g. ``"user/model"``.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The README text on a successful response; an empty string when the
        server answers with an error status, so that an error page is never
        parsed as if it were a model card.
    """
    response = requests.get(
        f"https://huggingface.co/{model_id}/raw/main/README.md",
        timeout=timeout,
    )
    if not response.ok:
        print(f"Model card for {model_id} not available: HTTP {response.status_code}")
        return ""
    return response.text
# First span enclosed by two "---" markers (front matter of a model card).
REGEX = re.compile(r'---(.*?)---', re.DOTALL)


def get_yaml_data(text_content):
    """Parse the first ``---``-delimited block of *text_content* as YAML.

    Args:
        text_content: Full text of a model card.

    Returns:
        The parsed mapping, or ``{}`` when there is no block, the block is
        empty, the YAML is invalid, or the parsed value is not a mapping
        (for example plain text between two horizontal rules).
    """
    found = REGEX.search(text_content)
    yaml_block = found.group(1).strip() if found else ""
    if not yaml_block:
        return {}
    try:
        parsed = yaml.safe_load(yaml_block)
    except yaml.YAMLError as exc:
        print(exc)
        return {}
    # safe_load may yield None, a string or a list; callers store the result
    # as the model's "meta" object, so only mappings are passed through.
    return parsed if isinstance(parsed, dict) else {}
# Matches markdown images ![alt](url.ext) and html src="url.ext"> attributes.
# The extension group is non-capturing so findall() yields only the two URL
# groups, and the alt text is matched lazily so that several images on one
# line are found individually.
_IMAGE_EXT = r"(?:png|jpg|jpeg|gif|bmp|webp)"
_IMAGE_REGEX = re.compile(
    r"!\[.*?\]\((.*?\." + _IMAGE_EXT + r")\)"
    r"|src=\"(.*?\." + _IMAGE_EXT + r")\">",
    re.IGNORECASE,
)


def _extract_image_urls(text, model_id):
    """Return absolute image URLs referenced in *text*, in document order."""
    base_url = f"https://huggingface.co/{model_id}/resolve/main/"
    urls = []
    for markdown_url, html_url in _IMAGE_REGEX.findall(text):
        url = markdown_url or html_url
        if not url:
            continue
        # Anything without an explicit http(s) scheme is a path in the repo.
        if not url.startswith(("http://", "https://")):
            url = base_url + url
        urls.append(url)
    return urls


async def find_image_in_model_card(text, model_id):
    """Upload up to three images referenced by a model card.

    Args:
        text: Model card text (markdown, possibly with inline html).
        model_id: Hub model id used to resolve relative image paths.

    Returns:
        ``[]`` when the card references no image; otherwise a list with one
        entry per attempted image (at most three), each being the uploaded
        file name or ``None`` if that upload failed.
    """
    urls = _extract_image_urls(text, model_id)
    if not urls:
        return []
    print(urls)
    async with aiohttp.ClientSession() as session:
        uploads = [
            upload_resize_image_url(session, image_url) for image_url in urls[0:3]
        ]
        return await asyncio.gather(*uploads)
def run_classifier(images):
    """Classify the first uploaded image through the remote classifier.

    Args:
        images: Uploaded file names; ``None`` entries are ignored.

    Returns:
        ``{label: score rounded to 3 decimals}`` for the first usable image,
        or ``{}`` when there is none.
    """
    uploaded = [name for name in images if name is not None]
    if len(uploaded) == 0:
        return {}

    # classifying only the first image
    images_urls = [ASSETS_URL + uploaded[0]]
    payload = {
        "data": [
            {"urls": images_urls},  # json urls: list of images urls
            False,  # enable/disable gallery image output
            None,  # single image input
            None,  # files input
        ]
    }
    reply = requests.post(CLASSIFIER_URL, json=payload).json()
    # data response is array data:[[{img0}, {img1}, {img2}...], Label, Gallery],
    first_image_rows = reply["data"][0][0]
    scores = {}
    for row in first_image_rows:
        scores[row["label"]] = round(row["score"], 3)
    return scores
async def get_all_new_models():
    """Collect the models from every page of the Hub listing.

    The first page is requested once and reused both for the paging
    information and for its models.

    Returns:
        A list with the ``models`` entries of all pages, in page order.
    """
    first_page = fetch_models(0)
    total_items = first_page["numTotalItems"]
    per_page = first_page["numItemsPerPage"]
    # Guard against a zero page size so the division cannot fail.
    num_pages = ceil(total_items / per_page) if per_page else 0
    print(f"Total items: {total_items} - Items per page: {per_page}")
    print(f"Found {num_pages} pages")

    # Page 0 is already in hand; only the remaining pages are fetched.
    new_models = list(first_page["models"])
    for page in tqdm(range(1, num_pages)):
        print(f"Fetching page {page} of {num_pages}")
        new_models += fetch_models(page)["models"]
    return new_models
async def sync_data():
    """Synchronise the local models database with the Hub listing.

    Steps, in order:
      1. Pull the dataset repo and fetch the full model listing.
      2. Insert models that are not stored yet (blocked ids are skipped),
         together with model-card metadata, uploaded images and classifier
         scores.
      3. Retry image discovery for stored models that have no image.
      4. Refresh likes/downloads for every listed model.
      5. Commit and force-push the database folder.
    """
    print("Fetching models")
    repo.git_pull()
    all_models = await get_all_new_models()
    print(f"Found {len(all_models)} models")
    # save list of all models for ids
    with open(DB_FOLDER / "models.json", "w") as f:
        json.dump(all_models, f)

    new_models_ids = {
        model["id"]
        for model in all_models
        if not BLOCKED_MODELS_REGEX.match(model["id"])
    }

    # get existing models
    with database.get_db() as db:
        cursor = db.cursor()
        cursor.execute("SELECT id FROM models")
        existing_models = {row["id"] for row in cursor.fetchall()}

    # Keep listing order, and take each id only once: the same model can show
    # up on two pages, which would otherwise be inserted twice.
    pending_ids = new_models_ids - existing_models
    models = []
    for model in all_models:
        if model["id"] in pending_ids:
            pending_ids.remove(model["id"])
            models.append(model)
    print(f"Found {len(models)} new models")

    for model in tqdm(models):
        model_id = model["id"]
        print(f"\n\nFetching model {model_id}")
        likes = model["likes"]
        downloads = model["downloads"]
        print("Fetching model card")
        model_card = fetch_model_card(model_id)
        print("Parsing model card")
        model_card_data = get_yaml_data(model_card)
        print("Finding images in model card")
        images = await find_image_in_model_card(model_card, model_id)
        classifier = run_classifier(images)
        print(images, classifier)
        # update model row with image and classifier data
        with database.get_db() as db:
            cursor = db.cursor()
            cursor.execute(
                "INSERT INTO models(id, data, likes, downloads) VALUES (?, ?, ?, ?)",
                [
                    model_id,
                    json.dumps(
                        {
                            **model,
                            "meta": model_card_data,
                            "images": images,
                            "class": classifier,
                        }
                    ),
                    likes,
                    downloads,
                ],
            )
            db.commit()

    print("\n\n\n\nTry to update images again\n\n\n")
    with database.get_db() as db:
        cursor = db.cursor()
        cursor.execute("SELECT * from models")
        stored_rows = list(cursor.fetchall())

    # Stored models whose image list is empty or holds only failed uploads.
    models_no_images = []
    for row in stored_rows:
        model_data = json.loads(row["data"])
        if not any(image is not None for image in model_data["images"]):
            models_no_images.append((row["id"], model_data))

    for model_id, model_data in tqdm(models_no_images):
        print(f"\n\nFetching model {model_id}")
        model_card = fetch_model_card(model_id)
        print("Parsing model card")
        model_card_data = get_yaml_data(model_card)
        print("Finding images in model card")
        images = await find_image_in_model_card(model_card, model_id)
        classifier = run_classifier(images)
        model_data["images"] = images
        model_data["class"] = classifier
        model_data["meta"] = model_card_data
        # update model row with image and classifier data
        with database.get_db() as db:
            cursor = db.cursor()
            cursor.execute(
                "UPDATE models SET data = ? WHERE id = ?",
                [json.dumps(model_data), model_id],
            )
            db.commit()

    print("Update likes and downloads")
    # One connection and one batched statement instead of one connection per
    # model.
    with database.get_db() as db:
        cursor = db.cursor()
        cursor.executemany(
            "UPDATE models SET likes = ?, downloads = ? WHERE id = ?",
            [
                (model["likes"], model["downloads"], model["id"])
                for model in all_models
            ],
        )
        db.commit()

    print("Updating DB repository")
    time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    cmd = f"git add . && git commit --amend -m 'update at {time}' && git push --force"
    print(cmd)
    # Wait for git to finish so the child process is reaped, a failure is
    # visible in the log, and the next sync's pull cannot overlap this push.
    completed = subprocess.run(cmd, cwd=DB_FOLDER, shell=True)
    if completed.returncode != 0:
        print(f"git update failed with exit code {completed.returncode}")
# ASGI application object for this module.
app = FastAPI()

# CORS: every origin, method and header is accepted, with credentials allowed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Manual sync endpoint, currently disabled:
# @ app.get("/sync")
# async def sync(background_tasks: BackgroundTasks):
#     await sync_data()
#     return "Synced data to huggingface datasets"

# Number of models returned per page by get_page.
MAX_PAGE_SIZE = 30
# Sort orders accepted by get_page. Members are also plain strings, so they
# compare equal to their values.
class Sort(str, Enum):
    trending = "trending"  # likes weighted down by age since lastModified
    recent = "recent"  # most recently modified first
    likes = "likes"  # most liked first
# Style filters accepted by get_page. Members are also plain strings, so they
# compare equal to their values.
class Style(str, Enum):
    all = "all"  # everything that is not flagged NSFW
    anime = "anime"
    s3D = "3d"  # the value starts with a digit, hence the prefixed member name
    realistic = "realistic"
    nsfw = "nsfw"
    lora = "lora"
# ORDER BY clause for each sort mode (fixed strings, never user input).
_SORT_QUERIES = {
    Sort.trending: "likes / MYPOWER((JULIANDAY('now') - JULIANDAY(datetime(json_extract(data, '$.lastModified')))) + 2, 2) DESC",
    Sort.recent: "datetime(json_extract(data, '$.lastModified')) DESC",
    Sort.likes: "likes DESC",
}

# WHERE fragment for each style filter (fixed strings, never user input).
_STYLE_QUERIES = {
    Style.all: "isNFSW = false",
    Style.anime: "json_extract(data, '$.class.anime') > 0.1 AND isNFSW = false",
    Style.s3D: "json_extract(data, '$.class.3d') > 0.1 AND isNFSW = false",
    Style.realistic: "json_extract(data, '$.class.real_life') > 0.1 AND isNFSW = false",
    Style.lora: "json_extract(data, '$.meta.tags') LIKE '%lora%' AND isNFSW = false",
    Style.nsfw: "isNFSW = true",
}


# Return one page of models from the database.
#
# Parameters:
#   page:  1-based page number; values below 1 are treated as 1.
#   sort:  ordering, a Sort member or its string value.
#   style: filter, a Style member or its string value.
#   tag:   when given, only models whose meta.tags contain this exact tag are
#          returned; when omitted, only models with more than one like are.
#
# Returns {"models": [...], "totalPages": n}. Each model is the stored JSON
# with null images removed, likes/downloads taken from the table columns and
# an "isNFSW" boolean. totalPages is 0 when the requested page has no rows.
#
# Raises ValueError for a sort/style value that is not a member of its enum.
#
# NOTE(review): no route decorator is visible on this function in this view;
# presumably it is registered on `app` elsewhere - confirm.
def get_page(
    page: int = 1, sort: Sort = Sort.trending, style: Style = Style.all, tag: str = None
):
    page = max(page, 1)
    # Coercing through the enum rejects unknown values with a clear error
    # instead of leaving the query fragment undefined.
    sort_query = _SORT_QUERIES[Sort(sort)]
    style_query = _STYLE_QUERIES[Style(style)]
    offset = (page - 1) * MAX_PAGE_SIZE

    with database.get_db() as db:
        cursor = db.cursor()
        # Only the fixed fragments above are interpolated; every value is
        # passed as a bound parameter.
        cursor.execute(
            f"""
            SELECT *,
                COUNT(*) OVER() AS total,
                isNFSW
            FROM (
                SELECT *,
                    json_extract(data, '$.class.explicit') > 0.3 OR json_extract(data, '$.class.suggestive') > 0.3 AS isNFSW
                FROM models
            ) AS subquery
            WHERE (? IS NULL AND likes > 1 OR ? IS NOT NULL)
                AND {style_query}
                AND (? IS NULL OR EXISTS (
                    SELECT 1
                    FROM json_each(json_extract(data, '$.meta.tags'))
                    WHERE json_each.value = ?
                ))
            ORDER BY {sort_query}
            LIMIT ? OFFSET ?;
            """,
            (tag, tag, tag, tag, MAX_PAGE_SIZE, offset),
        )
        results = cursor.fetchall()

    # The window COUNT(*) is repeated on every row; read it from the first.
    total = results[0]["total"] if results else 0
    total_pages = (total + MAX_PAGE_SIZE - 1) // MAX_PAGE_SIZE

    models_data = []
    for result in results:
        data = json.loads(result["data"])
        # clean nulls; rows without an "images" entry yield an empty list
        data["images"] = [x for x in (data.get("images") or []) if x is not None]
        # update downloads and likes from db table
        data["downloads"] = result["downloads"]
        data["likes"] = result["likes"]
        data["isNFSW"] = bool(result["isNFSW"])
        models_data.append(data)
    return {"models": models_data, "totalPages": total_pages}
# Serve a small static HTML notice that links to the gallery Space.
# NOTE(review): no route decorator is visible on this function in this view;
# presumably it is registered on `app` elsewhere - confirm.
def read_root():
    landing_html = """
<p>Just a bot to sync data from diffusers gallery please go to
<a href="https://huggingface.co/spaces/huggingface-projects/diffusers-gallery" target="_blank" rel="noopener noreferrer">https://huggingface.co/spaces/huggingface-projects/diffusers-gallery</a>
</p>"""
    # return html page from string
    return HTMLResponse(landing_html)
async def repeat_sync():
    """Run one full sync_data() pass and return a short status message.

    NOTE(review): no scheduling decorator is visible in this view; presumably
    this is wired to a periodic task elsewhere - confirm.
    """
    status_message = "Synced data to huggingface datasets"
    await sync_data()
    return status_message