Spaces:
Runtime error
Runtime error
| import os | |
| import uvicorn | |
| import gradio as gr | |
| from pathlib import Path | |
| from huggingface_hub import Repository | |
| import json | |
| from db import Database | |
| from fastapi import FastAPI | |
| from datetime import datetime | |
| import subprocess | |
# Optional Hub token from the Space secrets; None when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
# NOTE(review): not referenced anywhere in this file.
S3_DATA_FOLDER = Path("sd-multiplayer-data")
# Local checkout of the annotation dataset repo (opened by Database below).
DB_FOLDER = Path("diffusers-gallery-data")
# Prefix joined onto each image file name stored in a model's `data` JSON.
ASSETS_URL = "https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/"
# Clone the annotation dataset repo into DB_FOLDER, pull the latest state,
# then open the database stored inside it.
# An explicit HF_TOKEN from the environment is used when present; otherwise
# `True` keeps the original behaviour of relying on the locally cached login.
repo = Repository(
    local_dir=DB_FOLDER,
    repo_type="dataset",
    clone_from="huggingface-projects/diffusers-gallery-data",
    use_auth_token=HF_TOKEN or True,
)
repo.git_pull()
database = Database(DB_FOLDER)
# Choice lists shown in the UI; the selected values are what flag_model stores.
styles_cls = "anime 3D realistic other".split()
nsfw_cls = "safe suggestive explicit".split()
# Client-side hook run by `blocks.load` before `next_model`.
# Gradio hands it the load inputs positionally (query_params JSON, styles,
# nsfw) and expects them back as an array. It copies `?model_id=...` from the
# page URL into the query_params object, then strips the query string so the
# parameter is applied only once. The current path is kept, rather than being
# forced to "/", so the hook still works if the UI is mounted elsewhere.
js_get_url_params = """
function (query_params, styles, nsfw) {
    const params = new URLSearchParams(window.location.search);
    query_params.model_id = params.get("model_id") || "";
    window.history.replaceState({}, document.title, window.location.pathname);
    return [query_params, styles, nsfw];
}
"""
def next_model(query_params, styles=None, nsfw=None):
    """Pick a model to annotate and return the values that refresh the form.

    Args:
        query_params: dict from the hidden JSON input. If it carries a truthy
            "model_id", that model is loaded. Otherwise a random model that has
            at least one image and no flags yet is chosen.
        styles: current style selection; only logged.
        nsfw: current NSFW selection; only logged.

    Returns:
        Tuple ``(images, title, styles, nsfw, state)`` for the outputs
        ``[gallery, model_title, styles, nsfw, current_model]``:
        - the model's ``.jpg`` image URLs;
        - a markdown title with the number of models still awaiting flags;
        - the model's saved styles (``[]`` if none) and NSFW value (``None`` if none);
        - ``{"model_id": ...}`` for the state.

    Raises:
        gr.Error: if the requested model does not exist, or no unflagged model
            with images is left.
    """
    # A missing/None params dict falls back to picking a random model.
    model_id = (query_params or {}).get("model_id") or None
    print(model_id, styles, nsfw)
    with database.get_db() as db:
        # The remaining-work count uses a scalar subquery over the whole table.
        # A window function (SUM(...) OVER ()) only sees rows that pass this
        # query's WHERE clause. In the by-id lookup that is a single row, so
        # it reported 0/1 instead of the real count.
        if model_id:
            cursor = db.execute(
                """SELECT *,
                (SELECT COUNT(*) FROM models
                 WHERE json_array_length(data, '$.images') > 0 AND flags IS NULL) AS total_unflagged
                FROM models
                WHERE id = ?""",
                (model_id,),
            )
            row = cursor.fetchone()
            if row is None:
                raise gr.Error("Cannot find model to annotate")
        else:
            cursor = db.execute(
                """SELECT *,
                (SELECT COUNT(*) FROM models
                 WHERE json_array_length(data, '$.images') > 0 AND flags IS NULL) AS total_unflagged
                FROM models
                WHERE json_array_length(data, '$.images') > 0 AND flags IS NULL
                ORDER BY RANDOM()
                LIMIT 1"""
            )
            row = cursor.fetchone()
            if row is None:
                raise gr.Error("Cannot find any more models to annotate")

    total_unflagged = row["total_unflagged"]
    model_id = row["id"]
    data = json.loads(row["data"])
    # A model requested by id may have no "images" key; show an empty gallery.
    images = [ASSETS_URL + x for x in data.get("images", []) if x.endswith(".jpg")]
    # Pre-fill the form with any flags saved earlier for this model.
    flags_data = json.loads(row["flags"] or "{}")
    styles = flags_data.get("styles", [])
    nsfw = flags_data.get("nsfw", None)
    title = f'''#### [Model {model_id}](https://huggingface.co/{model_id})
**Unflagged** {total_unflagged}'''
    return images, title, styles, nsfw, {"model_id": model_id}
def flag_model(current_model, styles=None, nsfw=None):
    """Save the chosen labels for the shown model, then load another one.

    Args:
        current_model: state dict produced by ``next_model``; its "model_id"
            names the row to update.
        styles: selected style classes, stored as-is.
        nsfw: selected NSFW class, stored as-is.

    Returns:
        The result of ``next_model({}, ...)``, i.e. a random model still
        awaiting flags.

    Raises:
        gr.Error: if no model is loaded (empty state). ``next_model`` may also
            raise ``gr.Error`` when nothing is left to annotate.
    """
    # The state starts as {} and stays that way if the initial load failed;
    # report that clearly instead of surfacing a bare KeyError.
    model_id = (current_model or {}).get("model_id")
    if not model_id:
        raise gr.Error("No model loaded to flag")
    print("Flagging model", model_id, styles, nsfw)
    with database.get_db() as db:
        db.execute(
            """UPDATE models SET flags = ? WHERE id = ?""",
            (json.dumps({"styles": styles, "nsfw": nsfw}), model_id),
        )
    # Leave the write context before opening a new one for the next model.
    return next_model({}, styles, nsfw)
# Annotation UI: instructions, the model title, its sample images, the two
# label inputs, hidden state, and Next/Submit buttons.
with gr.Blocks() as blocks:
    gr.Markdown('''### Diffusers Gallery annotation tool
Please select multiple classes for each image. If you are unsure, select "other" and also check the model card for more information.
''')
    model_title = gr.Markdown()
    gallery = gr.Gallery(label="Images", show_label=False, elem_id="gallery").style(grid=[3])
    styles = gr.CheckboxGroup(
        styles_cls,
        info="Classify the image as one or more of the following classes",
    )
    nsfw = gr.Radio(nsfw_cls, info="Is the image NSFW?")

    # Hidden inputs: URL query params (filled client-side by js_get_url_params
    # on page load) and the id of the model currently displayed.
    query_params = gr.JSON(value={}, visible=False)
    current_model = gr.State({})

    next_btn = gr.Button("Next")
    submit_btn = gr.Button("Submit")

    # All three events refresh the same set of components.
    form_outputs = [gallery, model_title, styles, nsfw, current_model]
    next_btn.click(next_model, inputs=[query_params, styles, nsfw], outputs=form_outputs)
    submit_btn.click(flag_model, inputs=[current_model, styles, nsfw], outputs=form_outputs)
    blocks.load(
        next_model,
        inputs=[query_params, styles, nsfw],
        outputs=form_outputs,
        _js=js_get_url_params,
    )
app = FastAPI()


# Without a route decorator nothing in this file calls read_main, so the flags
# were never pushed. NOTE(review): the "/sync" path is chosen here; confirm it
# matches whatever is expected to trigger syncs.
@app.get("/sync")
def read_main():
    """Kick off a push of the annotation database and acknowledge the request.

    ``sync_data`` starts git with ``subprocess.Popen`` and does not wait for
    it, so this response is returned before the push has finished.
    """
    sync_data()
    return "Synced flagged"
def sync_data():
    """Commit everything in the dataset clone and force-push it.

    Each sync amends the previous commit instead of adding a new one, then
    force-pushes. The shell command runs in DB_FOLDER via ``subprocess.Popen``
    and is not waited on, so this returns immediately and git's exit status is
    never checked.
    """
    print("Updating DB repository")
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    git_cmd = f"git add . && git commit --amend -m 'update at flags {stamp}' && git push --force"
    subprocess.Popen(git_cmd, cwd=DB_FOLDER, shell=True)
# Serve the Gradio UI at the root of the FastAPI app.
app = gr.mount_gradio_app(app, blocks, path="/")

if __name__ == "__main__":
    # Local/Space entry point: listen on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")