Spaces:
Sleeping
Sleeping
remove space from requirements.txt
Browse files- app_imagegen_new.py +863 -0
- requirements.txt +0 -1
app_imagegen_new.py
ADDED
|
@@ -0,0 +1,863 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import spaces
|
| 2 |
+
import torch
|
| 3 |
+
from diffusers import UNet2DConditionModel, DDIMScheduler, StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler, DPMSolverSDEScheduler, DPMSolverMultistepScheduler, AutoencoderKL, AutoencoderTiny, StableDiffusionXLImg2ImgPipeline
|
| 4 |
+
import ipown
|
| 5 |
+
from huggingface_hub import hf_hub_download
|
| 6 |
+
from insightface.app import FaceAnalysis
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import cv2
|
| 9 |
+
import helper
|
| 10 |
+
import random
|
| 11 |
+
from transformers import Qwen2ForSequenceClassification, AutoTokenizer
|
| 12 |
+
from PIL import Image
|
| 13 |
+
import base64
|
| 14 |
+
import json
|
| 15 |
+
import time
|
| 16 |
+
import jwt
|
| 17 |
+
import glob
|
| 18 |
+
import uuid
|
| 19 |
+
import re
|
| 20 |
+
import copy
|
| 21 |
+
from model import LoraHelper, ModelManager, Model
|
| 22 |
+
import numpy as np
|
| 23 |
+
from compel import Compel, ReturnedEmbeddingsType
|
| 24 |
+
from hidiffusion import apply_hidiffusion, remove_hidiffusion
|
| 25 |
+
import utils
|
| 26 |
+
from datetime import datetime,timezone
|
| 27 |
+
import os
|
| 28 |
+
from auth import AuthHelper
|
| 29 |
+
import download_for_imagegen
|
| 30 |
+
from safety import AgePredictor, NSFWClassifier, Qwen3Analyzer
|
| 31 |
+
from utils import generate_watermark, save_image
|
| 32 |
+
|
| 33 |
+
# ---------------------------------------------------------------------------
# Module-level configuration and shared singletons.
# NOTE(review): indentation reconstructed from a mangled diff dump; code
# tokens are unchanged.
# ---------------------------------------------------------------------------

# Upper bound for randomly generated seeds (int32 max).
MAX_SEED = np.iinfo(np.int32).max
# Main service configuration; must exist next to this script.
config = json.load(open("./config.json", "r"))

# Optional local overrides; a missing/invalid file is tolerated (best effort).
local_config = {}
try:
    local_config = json.load(open("./local_config.json", "r"))
except Exception as e:
    print(e)

# Environment variables override any file-provided values.
local_config.update({
    "lazy_load": os.environ.get("LAZY_LOAD", "False").lower() == "true", # Convert to boolean
    "max_loaded_models": int(os.environ.get("MAX_LOADED_MODELS", 1)), # Convert to integer
})



base_model_path = "John6666/pony-realism-v21main-sdxl"
# R2 credentials are injected from the environment, never stored in config.json.
config["r2"]["access_key"] = os.getenv("R2_ACCESS_KEY")
config["r2"]["secret_key"] = os.getenv("R2_SECRET_KEY")

#
# Remote catalogs refreshed every 1800 s; each list is re-keyed by item "key".
characters = utils.RemoteJson(config["characters"],1800, lambda lst: {item['key']: item for item in lst})
# Vibes are currently read from the local file (remote fetch commented out).
vibes = {item['key']: item for item in json.load(open("./vibe.json", "r"))}#utils.RemoteJson(config["vibes"],1800, lambda lst: {item['key']: item for item in lst})
styles = utils.RemoteJson(config["styles"],1800, lambda lst: {item['key']: item for item in lst})
#ip_xl_ckpt = hf_hub_download(repo_id="h94/IP-Adapter-FaceID", filename="ip-adapter-faceid_sdxl.bin", repo_type="model")
#_ = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

# download weights
if not os.path.exists('realesr-general-x4v3.pth'):
    os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")

device = "cuda"
# SDXL-style scheduler hyper-parameters.
# NOTE(review): this dict is not referenced in the visible chunk — presumably
# used further down the file; confirm before removing.
scheduler_config = {
    "num_train_timesteps": 1000,
    "beta_start": 0.00085,
    "beta_end": 0.012,
    "beta_schedule": "scaled_linear",
    "set_alpha_to_one": False,
    "steps_offset": 1,
    "prediction_type": "epsilon",
}

# DDIM scheduler exposed to generate_image() as the "NOISE" sampler option.
noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    #beta_start=0.00015,
    #beta_end=0.02,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)
euler = EulerAncestralDiscreteScheduler(
    num_train_timesteps = 1000,
    beta_start = 0.00085,
    beta_end = 0.012,
    #beta_schedule="scaled_linear",
    #device = 'cuda',
    steps_offset = 1,
)

# Safety/analysis singletons shared by all requests.
age_predictor = AgePredictor()
#nsfw_classifier = NSFWClassifier()
llm_analyzer = Qwen3Analyzer()
lora_mgr = LoraHelper()
# vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)
#pipe.scheduler = euler
|
| 101 |
+
|
| 102 |
+
class PromptGenerator:
    """Builds the final diffusion prompt for a request.

    Responsibilities:
      * strip/detect child-related content via a safety regex,
      * expand ``@character`` references using the remote character catalog,
      * optionally append danbooru tags predicted by a Qwen2 classifier,
      * compose the pieces in the format expected by "il" or "pony" models.

    Fix vs. original: the helpers below were defined without ``self`` and
    without ``@staticmethod``; they only worked when accessed through the
    class object and raised when called on an instance. They are now proper
    static methods (class-level calls keep working unchanged).
    """

    # Define the regular expression
    # Matches child-related vocabulary, "1-17 years old" phrases, and
    # small/underdeveloped-chest descriptors. NOTE(review): the pattern is
    # not word-anchored, so substrings inside longer words also match.
    child_related_regex = re.compile(
        r'(child|children|kid|kids|baby|shota|loli|lolicon|babies|ll*oo*ll*ii*|miniature|toddler|infant|juvenile|minor|underage|preteen|adolescent|youngster|youth|kindergarten|preschool|young girl|young daughter|'
        r'([1-9]|1[0-7])\s*year(s)?\s*old|' # Matches 1 to 17 years old
        r'little|small|tiny|short|new\s*born\s*(boy|girl|bro|brother|sis|sister|shota|lolita|lolli))|'
        r'(flat[\s_-]*chest(?:ed|s)?|small[\s_-]*chest(?:ed|s)?|medium[\s_-]*chest(?:ed|s)?|tiny[\s_-]*chest(?:ed|s)?|petite[\s_-]*chest(?:ed|s)?|underdeveloped[\s_-]*chest(?:ed|s)?)',
        re.IGNORECASE
    )

    def __init__(self):
        # Loads the Qwen2 tag classifier eagerly (CPU); see load_tag_generator.
        self.load_tag_generator()

    # Function to remove child-related content from a prompt
    @staticmethod
    def remove_child_related_content(prompt):
        """Return *prompt* with every child-related match removed, stripped."""
        cleaned_prompt = re.sub(PromptGenerator.child_related_regex, '', prompt)
        return cleaned_prompt.strip()

    # Function to check if a prompt contains child-related content
    @staticmethod
    def contains_child_related_content(prompt):
        """Return True if *prompt* matches the child-related safety regex."""
        if PromptGenerator.child_related_regex.search(prompt):
            return True
        return False

    def load_tag_generator(self):
        """Load the Qwen2 tag-classification model, tokenizer, and tag list.

        Populates ``self.model``, ``self.tokenizer``, ``self.allowed_tags``.
        Runs on CPU with gradients globally disabled (inference only).
        """
        torch.set_grad_enabled(False)
        print("loading tag generator model")
        tag_config = config.get("tag_generator", {})
        model = Qwen2ForSequenceClassification.from_pretrained(
            tag_config.get("model_path"),
            num_labels=9086,
            device_map="cpu",
            local_files_only=True,
        )
        model.eval()
        tokenizer = AutoTokenizer.from_pretrained(tag_config.get("tokenizer_path"), local_files_only=True)
        allowed_tags = None
        with open("tags_9083.json", "r") as file:
            allowed_tags = json.load(file)
        # Labels must be sorted to line up with the classifier head; the three
        # rating labels are appended after the sorted danbooru tags.
        allowed_tags = sorted(allowed_tags)
        allowed_tags.append("explicit")
        allowed_tags.append("questionable")
        allowed_tags.append("safe")

        self.model = model
        self.tokenizer = tokenizer
        self.allowed_tags = allowed_tags
        print("done")

    def create_danbooru_tags(self, prompt, threshold):
        """Classify *prompt* into danbooru tags.

        Returns ``(csv_tags, tag_score)`` where ``csv_tags`` is a
        comma-joined string of tags scoring above *threshold* (minus a small
        skip-list) and ``tag_score`` maps every kept tag to its sigmoid score.
        """
        inputs = self.tokenizer(
            prompt,
            padding="do_not_pad",
            max_length=512,
            truncation=True,
            return_tensors="pt",
        )

        output = self.model(**inputs).logits
        output = torch.nn.functional.sigmoid(output)
        indices = torch.where(output > threshold)
        values = output[indices]
        indices = indices[1]
        values = values.squeeze()
        # Generic/meta tags that would pollute the prompt.
        skip = set(["simple_background", "multiview","safe", "questionable","mammal"])
        temp = []
        tag_score = dict()
        for i in range(indices.size(0)):
            temp.append([self.allowed_tags[indices[i]], values[i].item()])
            tag_score[self.allowed_tags[indices[i]]] = values[i].item()
        temp = [t[0] for t in temp if t[0] not in skip]
        text_no_impl = ",".join(temp)
        current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        print(f"{current_datetime}: finished.")
        return text_no_impl, tag_score

    @staticmethod
    def extract_prompt_elements(prompt):
        """Resolve ``@character`` references in *prompt*.

        Returns ``(chars_in_prompt, char_counts, prompt)``: per-character
        info dicts, a subject-count string ("1girl,", "threesome,", ...),
        and the prompt with @refs replaced by character display names.
        """
        chars = characters.get()
        mentioned_chars = utils.extract_characters("@",prompt)
        chars_in_prompt = {}
        boys = set()
        girls = set()
        for c in mentioned_chars:
            char_info = chars.get(c["key"], {})
            if char_info:
                c["info"] = char_info
                sample_prompts = char_info.get("samplePrompts", [])
                prompt = prompt.replace(c["raw"], char_info.get("name", char_info.get("key")).replace("_", " "))#TODO:replace with character name
                # Longest sample prompt = rich description; shortest = terse tag set.
                c["long_prompt"] = (lambda s: s)(max(sample_prompts, key=len)).replace("1girl","").replace("1boy","").replace("_", " ")
                c["short_prompt"] = (lambda s: s)(min(sample_prompts, key=len)).replace("1girl","").replace("1boy","").replace("_", " ")
                # Infer gender from the sample prompts; default to "girl".
                if "girl" in " ".join(char_info.get("samplePrompts",[])):
                    girls.add(c["key"])
                    c["gender"] = "girl"
                elif "boy" in " ".join(char_info.get("samplePrompts",[])):
                    boys.add(c["key"])
                    c["gender"] = "boy"
                else:
                    c["gender"] = "girl"
                    girls.add(c["key"])
            chars_in_prompt[c["key"]] = c

        # Explicit subject counts the user already wrote ("2girls", "one boy"...).
        pattern = re.compile(
            r'(\d+|one|two|three|multi|multiple)\s*(girl|girls|boy|boys|woman|women|man|man|male|female)',
            re.IGNORECASE
        )
        matches = pattern.findall(prompt)

        # Convert tuples to strings
        char_counts = ",".join(["".join(match) for match in matches])
        print(f"1 {char_counts}, {girls}, {boys}")
        print(chars_in_prompt)
        # NOTE(review): nesting below reconstructed from a mangled source dump —
        # counts are synthesized only when the user did not state them.
        if not char_counts:
            char_counts = ""
            if boys:
                if len(boys) == 1:
                    char_counts += "1boy,"
                elif len(boys)> 1:
                    char_counts += "multiple boy,"
            if girls:
                if len(girls) == 1:
                    char_counts += "1girl,"
                elif len(girls)> 1:
                    char_counts += "multiple girl,"
            if len(chars_in_prompt) ==3:
                char_counts = "threesome,"+char_counts
            elif len(chars_in_prompt) >3:
                char_counts = "gangbang,"+char_counts

        print(f"2 {char_counts}")
        return chars_in_prompt, char_counts,prompt

    def compose_prompt_il(self, prompt, chars_in_prompt, char_counts, styles, extract_danbooru_tags=False):
        """Compose the final prompt in "il" model format:
        counts + short character prompts + user prompt + styles.

        WARNING: pops "vibe" out of the caller's *styles* dict (callers rely
        on this — see generate_image).
        """
        #char count
        prompt = PromptGenerator.remove_child_related_content(prompt)

        if "vibe" in styles:
            vibe_config = styles["vibe"]
            vibe_style = vibe_config["styles"][0]
            vibe_prompt = vibe_config["prompt"]
            prompt = vibe_prompt.replace("{prompt}", prompt).replace("{style}", vibe_style)
            styles.pop("vibe")

        count_desc = char_counts
        if count_desc:
            count_desc += ","

        char_desc = ""
        for c in chars_in_prompt:
            char = chars_in_prompt[c]
            char_info = char.get("info", {})
            if char_info:
                char_desc += char.get("short_prompt") + ","
                continue
            char_desc += char.get("name", char.get("key")) + ","
        if len(chars_in_prompt) > 1:
            char_desc = char_desc + "side-by-side,"

        style_prompt = []
        for key in styles:
            #check if style need to be expanded
            if styles[key]:
                style_prompt.append(styles[key])

        style_prompt = ",".join(style_prompt)

        if extract_danbooru_tags:
            tags, _ = self.create_danbooru_tags(prompt, 0.8)
            if tags:
                prompt += "," + tags + ","
        prompt = count_desc + char_desc + prompt + style_prompt
        return prompt

    def compose_prompt_pony(self, prompt, chars_in_prompt, char_counts, styles, extract_danbooru_tags=False):
        """Compose the final prompt in "pony" model format:
        {subject:...} counts + user prompt + styles + long character blocks.

        WARNING: pops "vibe" out of the caller's *styles* dict (callers rely
        on this — see generate_image).
        """
        #char count
        prompt = PromptGenerator.remove_child_related_content(prompt)

        if "vibe" in styles:
            vibe_config = styles["vibe"]
            vibe_style = vibe_config["styles"][0]
            vibe_prompt = vibe_config["prompt"]
            prompt = vibe_prompt.replace("{prompt}", prompt).replace("{style}", vibe_style)
            styles.pop("vibe")

        count_desc = char_counts
        if count_desc:
            # Pony models take subject counts in {subject:...} syntax.
            count_desc += "{subject:" + count_desc + "},"

        char_desc = ""
        for c in chars_in_prompt:
            char = chars_in_prompt[c]
            char_info = char.get("info", {})
            if char_info:
                gender = char.get("gender", "")
                char_desc += "{subject}, " + ("male" if gender == "boy" else "female") + "," + char.get("long_prompt") + " \n "
                continue
            char_desc += char.get("name", char.get("key")) + " "
        if char_desc:
            char_desc = "\n" + char_desc

        style_prompt = []
        for key in styles:
            #check if style need to be expanded
            if styles[key]:
                style_prompt.append(styles[key])

        style_prompt = ",".join(style_prompt)

        if extract_danbooru_tags:
            tags, _ = self.create_danbooru_tags(prompt, 0.8)
            if tags:
                prompt += "," + tags + ","
        prompt = count_desc + prompt + style_prompt + char_desc
        return prompt

    def enhance_prompt(self, prompt, method="pony", styles=None, extract_danbooru_tags=True):
        """Expand and format *prompt* for the given model *method*.

        ``styles`` defaults to a fresh dict per call (fixed mutable-default
        bug); a caller-supplied dict is still mutated ("vibe" is popped) as
        the callers depend on that side effect.
        """
        styles = {} if styles is None else styles
        chars_in_prompt, char_counts, replaced_prompt = PromptGenerator.extract_prompt_elements(prompt)
        new_prompt = prompt
        if method == "il":
            new_prompt = self.compose_prompt_il(replaced_prompt, chars_in_prompt, char_counts,styles, extract_danbooru_tags)
        elif method == "pony":
            new_prompt = self.compose_prompt_pony(replaced_prompt, chars_in_prompt, char_counts,styles, extract_danbooru_tags)
        return new_prompt
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
# Lazily-initialized PromptGenerator singleton; presumably constructed later
# in the file (initialization not visible in this chunk — confirm).
prompt_generator = None
|
| 333 |
+
|
| 334 |
+
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return *seed* unchanged, or a fresh random seed in [0, MAX_SEED] when *randomize_seed* is set."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
|
| 338 |
+
|
| 339 |
+
def common_upscale(samples: torch.Tensor, width: int, height: int, upscale_method: str) -> torch.Tensor:
    """Resize an NCHW batch to exactly (height, width) using the given interpolation mode."""
    target_size = (height, width)
    return torch.nn.functional.interpolate(samples, size=target_size, mode=upscale_method)

def upscale(samples: torch.Tensor, upscale_method: str, scale_by: float) -> torch.Tensor:
    """Scale an NCHW batch by *scale_by*, rounding each spatial dimension to the nearest integer."""
    new_w = round(samples.shape[3] * scale_by)
    new_h = round(samples.shape[2] * scale_by)
    return common_upscale(samples, new_w, new_h, upscale_method)
|
| 346 |
+
|
| 347 |
+
@spaces.GPU(enable_queue=True)
def generate_image(model_id, prompt, negative_prompt, width, height, styles={}, scheduler_name=None, use_hd=True, cfg=7.5, steps=30, seed=0, options={}, progress=gr.Progress(track_tqdm=True)):
    """Run one SDXL text-to-image pass and return a list of
    ``{"image": ..., "should_publish": bool}`` dicts.

    Pipeline: LLM safety analysis of the prompt -> prompt enhancement
    (characters/styles/vibe) -> SDXL generation -> per-image age check,
    safety flagging, optional QR watermark, and save via utils.save_image.

    NOTE(review): `styles` and `options` are mutable default arguments;
    `mdoel_mgr` (sic) is presumably a module-level ModelManager defined
    outside this chunk — confirm before renaming. Indentation reconstructed
    from a mangled source dump; code tokens unchanged.
    """
    # Clear GPU memory
    torch.cuda.empty_cache()
    model = mdoel_mgr.get_model_set_style(model_id)
    # The "vibe" style key doubles as the display model name.
    model_name = styles.get("vibe",{}).get("key", "realism")
    print(model.config)
    prediction_type = model.config.get("prediction_type", "epsilon")
    # Start the process
    pipe = model.pipe
    pipe.scheduler.config.prediction_type = prediction_type
    print(pipe.scheduler.config.prediction_type)
    # Named samplers, all derived from the pipeline's current scheduler config.
    samplers = {
        "Euler a": EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config),
        "DPM++ SDE Karras": DPMSolverSDEScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True),
        "DPM2 a": DPMSolverMultistepScheduler.from_config(pipe.scheduler.config),
        "DPM++ SDE": DPMSolverSDEScheduler.from_config(pipe.scheduler.config),
        "DPM++ 2M SDE": DPMSolverSDEScheduler.from_config(pipe.scheduler.config, use_2m=True),
        "DPM++ 2S a": DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_2s=True),
        "NOISE": noise_scheduler,
    }
    # Ask the safety LLM for a structured verdict on the raw user prompt.
    prompt_analysis_keys = {"is_porn_involved": "boolean", "is_underage_involved": "boolean", "is_non_human_animal_involved": "boolean", "is_revenge_porn": "boolean", "detail_improved_prompt_in_natural_language": "string"}
    prompt_analysis_result = llm_analyzer.analyze(f"Analyze the following prompt, determine if it contains porn instruction, underage characters, sexual interactino with non-human animals, revenge porn instruction, and generate a JSON with the given format: {json.dumps(prompt_analysis_keys)}\nPrompt: {prompt}", prompt_analysis_keys) or {}
    print("options", options)
    print("prompt_analysis_result", prompt_analysis_result)
    pipe.to(device)
    #pipe.enable_xformers_memory_efficient_attention()
    # Shadows the module-level `config` for the rest of this function.
    config = model.config
    total_negative_prompt = config.get("negative_prompt", "")
    # Caller values win; model-config values are fallbacks (0/None falls through).
    steps = steps or config.get("inference_steps", 30)
    # NOTE(review): guidance_scale is computed but the pipe call below uses
    # raw `cfg` — looks unintentional; confirm before changing.
    guidance_scale = cfg or config.get("guidance_scale", 6)
    width = width or config.get("width", 512)
    height = height or config.get("height", 512)
    # Both branches perform the same lookup; kept as-is.
    if not scheduler_name:
        scheduler_name = config.get("sampler", "")
        scheduler = samplers.get(scheduler_name, None)
    else:
        scheduler = samplers.get(scheduler_name, None)

    # enhance_prompt also pops "vibe" out of `styles` (later reads rely on that).
    prompt_str = prompt_generator.enhance_prompt(prompt,method=model.config.get("model_version","pony"),styles=styles, extract_danbooru_tags=False) #!Test!True
    if config.get("upsample_prompt", True):
        prompt_str = prompt_str + "," + prompt_analysis_result.get("detail_improved_prompt_in_natural_language", "")

    # Model-level prompt template wraps the enhanced prompt.
    prompt_str = config.get("prompt", "{prompt}").replace("{prompt}", prompt_str)

    # seed == 0 always triggers randomization.
    seed = seed or int(randomize_seed_fn(seed, True))
    generator = torch.Generator(pipe.device).manual_seed(seed)
    total_negative_prompt = total_negative_prompt + negative_prompt + vibes.get(styles.get("vibe", ""), {}).get("negative_prompt", "")

    # Compel is constructed but its embeddings path is currently disabled below.
    compel = Compel(
        tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
        text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
        returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
        requires_pooled=[False, True],
        truncate_long_prompts=False
    )
    use_hd = options.get("use_hd", model.config.get("use_hd", True))
    # NOTE(review): the second operand is redundant given the line above; kept as-is.
    if use_hd == True or options.get("use_hd", False):
        apply_hidiffusion(pipe, apply_window_attn=True)

    '''
    conditioning, pooled = compel(prompt_str) # get weighted embeddings for prompt
    neg_cond, neg_pooled = compel(total_negative_prompt)

    print(conditioning.dtype) # Output: torch.bfloat16
    print(neg_cond.dtype) # Output: torch.bfloat16
    # embeddings for negative prompt
    # Ensure the embedding sequences are the same length (pad if needed)
    [conditioning, neg_cond] = compel.pad_conditioning_tensors_to_same_length([conditioning, neg_cond])
    '''

    #prompt = helper.get_embed_new(prompt, pipe, compel, only_convert_string=True)
    #negative_prompt = helper.get_embed_new(total_negative_prompt, pipe, compel, only_convert_string=True)
    #conditioning, pooled = compel([prompt, negative_prompt])
    print(prompt_str)
    if scheduler:
        pipe.scheduler = scheduler
    print(f"Generating SDXL, scheduler={scheduler_name}, usd_hd={use_hd}")
    images = pipe(
        #prompt_embeds=conditioning, pooled_prompt_embeds=pooled,
        #negative_prompt_embeds=neg_cond, negative_pooled_prompt_embeds=neg_pooled,

        prompt=prompt_str, negative_prompt=total_negative_prompt,

        #prompt_embeds=conditioning[0:1],
        #pooled_prompt_embeds=pooled[0:1],
        #negative_prompt_embeds=conditioning[1:2],
        #negative_pooled_prompt_embeds=pooled[1:2],
        width=width, height=height, guidance_scale=cfg, num_inference_steps=steps,
        num_images_per_prompt=1, generator=generator,eta=1.0,
        #upscaling
        #output_type="latent",
    ).images

    #upscaling
    '''
    upscale_by = 1.5
    upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
    upscaled_latents = upscale(images, "nearest-exact", upscale_by)
    images = upscaler_pipe(
        prompt=prompt_str,
        negative_prompt=total_negative_prompt,
        image=upscaled_latents,
        guidance_scale=cfg,
        num_inference_steps=steps,
        strength=0.5,
        generator=generator,
        output_type="pil",
    ).images
    '''
    ##

    # Post-process each image: age check, safety flags, watermark, save.
    webp_images = []
    nsfw_flags = ["a","h","n","p","s"]
    for i in images:
        flags = ""
        print(i)
        # Visual minor-detection on the generated image (threshold 0.6).
        result = age_predictor.predict(i, 0.6)
        print(result)
        underage_flag = "y" if result.get("is_minor", False) else "n"
        prompt_underage_flag = "y" if prompt_analysis_result.get("is_underage_involved", False) else "n"
        nsfw_flag = ""
        bestiality_flag = "n"
        revenge_porn_flag = "n"

        if prompt_analysis_result.get("is_porn_involved", False):
            nsfw_flag = "y"
        else:
            nsfw_flag = "n"
        # Bestiality is flagged only when the prompt is also sexual.
        if prompt_analysis_result.get("is_non_human_animal_involved", False):
            if nsfw_flag == "y":
                bestiality_flag = "y"
            else:
                bestiality_flag = "n"
        else:
            bestiality_flag = "n"
        if prompt_analysis_result.get("is_revenge_porn", False):
            revenge_porn_flag = "y"
        else:
            revenge_porn_flag = "n"
        watermark = options.get("watermark", {})
        file_uid = str(uuid.uuid4())
        # Generate UTC timestamp
        file_publish_time = int(datetime.now(timezone.utc).timestamp())
        image_id = (hex(file_publish_time)[2:] + file_uid[-2:]).lower()
        filename = f"hf_upif_{file_uid}_{file_publish_time}_{image_id}"
        # image_id= hex of publish_time + last 4 char of file uid
        should_publish = False
        model_underage_check = model.config.get("underage_check", ["Teenager 13-20", "Child 0-12"])
        underage_flag1 = result.get("age_group", "") in model_underage_check
        # Soften the visual minor flag: non-NSFW content is cleared; otherwise a
        # narrower model-specific age-group check can override at high confidence.
        if underage_flag == 'y':
            if nsfw_flag == 'n':
                underage_flag = 'n'
            elif len(model_underage_check) < 2 and result.get("confidence", 0) > 0.4:
                underage_flag = 'y' if underage_flag1 else 'n'

        # Publishable only when all blocking flags are clear.
        if bestiality_flag == "n" and underage_flag == "n" and prompt_underage_flag == "n":
            if watermark and watermark.get("url", ""):
                watermark_url = watermark["url"].replace("{image_id}", image_id).replace("{seed}", str(seed)).replace("{prompt}", prompt)
                i = generate_watermark(original_image=i, qr_code_url=watermark_url)
            should_publish = True
        #nsfw_result = nsfw_classifier.predict(i)
        #print(nsfw_result)
        #nsfw_lvl = nsfw_result.get("category", 2)

        #flags += nsfw_flags[nsfw_ lvl]
        # Flag string order: underage, nsfw, bestiality, revenge-porn, prompt-underage.
        i = save_image(i, filename, f"{underage_flag}{nsfw_flag}{bestiality_flag}{revenge_porn_flag}{prompt_underage_flag}",
            {
                "seed": seed,
                "domain": options.get("domain", ""),
                "model_id": model.config.get("model_id", ""),
                "model_name": model_name,
                "prompt": prompt,
                "publish_time": file_publish_time,
                "file_uid": file_uid,
                "image_id": image_id,
                "underage_flag": underage_flag,
                "nsfw_flag": nsfw_flag,
                "bestiality_flag": bestiality_flag,
                "revenge_porn_flag": revenge_porn_flag,
                "prompt_underage_flag": prompt_underage_flag,
            }
        )
        webp_images.append({
            "image": i,
            "should_publish": should_publish
        })
    return webp_images
|
| 535 |
+
|
| 536 |
+
# AuthHelper placeholder; presumably assigned later in the file (not visible
# in this chunk — confirm).
auth = None
|
| 537 |
+
|
| 538 |
+
@spaces.GPU(enable_queue=True)
def generate_image_with_ipa(model_id, image, prompt, negative_prompt, width, height, styles=None, scheduler=None, face_strength=7.5, likeness_strength=0.1, steps=30, seed=0, options=None, progress=gr.Progress(track_tqdm=True)):
    """Generate images guided by a reference face via an IP-Adapter FaceID pipeline.

    Args:
        model_id: key resolved through the global model manager (``mdoel_mgr``).
        image: filesystem path to the reference face image (read with cv2.imread).
        prompt, negative_prompt: user prompts, merged with per-model config values.
        width, height: output dimensions; model-config defaults are used when falsy.
        styles: optional style dict; ``styles["vibe"]`` contributes extra negative-prompt text.
        scheduler: sampler name (a key of ``samplers``); falls back to the model config.
        face_strength: diffusion guidance scale.
        likeness_strength: IP-Adapter scale (weight of the face embedding).
        steps, seed, options, progress: standard generation knobs.

    Returns:
        List of saved image records produced by helper.save_image.

    Raises:
        Exception: when the reference image appears to show a minor, or when
            no face can be detected in it.
    """
    # Avoid shared mutable defaults (original signature used styles={}, options={}).
    styles = styles if styles is not None else {}
    options = options if options is not None else {}
    # Clear GPU memory left behind by previous generations.
    torch.cuda.empty_cache()
    model = mdoel_mgr.get_model_set_style(model_id)
    # Hard-reject reference photos that appear to show a minor.
    underage_check = age_predictor.predict(image, 0.6)
    if underage_check.get("is_minor", False):
        raise Exception("Uploaded image contains inappropriate content")

    prediction_type = model.config.get("prediction_type", "epsilon")
    # Start the process
    pipe = model.pipe
    pipe.scheduler.config.prediction_type = prediction_type
    print(pipe.scheduler.config.prediction_type)
    # NOTE(review): diffusers' from_config ignores unknown kwargs, so use_2m / use_2s
    # below may have no effect — confirm against the installed diffusers version.
    samplers = {
        "Euler a": EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config),
        "DPM++ SDE Karras": DPMSolverSDEScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True),
        "DPM2 a": DPMSolverMultistepScheduler.from_config(pipe.scheduler.config),
        "DPM++ SDE": DPMSolverSDEScheduler.from_config(pipe.scheduler.config),
        "DPM++ 2M SDE": DPMSolverSDEScheduler.from_config(pipe.scheduler.config, use_2m=True),
        "DPM++ 2S a": DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_2s=True),
        "NOISE": noise_scheduler,
    }
    pipe.to(device)
    #pipe.enable_xformers_memory_efficient_attention()
    config = model.config
    total_negative_prompt = config.get("negative_prompt", "")
    steps = steps or config.get("inference_steps", 30)
    likeness_strength = likeness_strength or config.get("likeness_strength", 0.1)
    face_strength = face_strength or config.get("face_strength", 7.0)
    # Screen the prompt with the LLM and get an enriched natural-language prompt back.
    prompt_analysis_keys = {"is_porn_involved": "boolean", "is_underage_involved": "boolean", "is_non_human_animal_involved": "boolean", "is_revenge_porn": "boolean", "detail_improved_prompt_in_natural_language": "string"}
    # BUG FIX: "interactino" -> "interaction" in the LLM instruction text.
    prompt_analysis_result = llm_analyzer.analyze(f"Analyze the following prompt, determine if it contains porn instruction, underage characters, sexual interaction with non-human animals, revenge porn instruction, and generate a JSON with the given format: {json.dumps(prompt_analysis_keys)}\nPrompt: {prompt}", prompt_analysis_keys) or {}

    width = width or config.get("width", 512)
    height = height or config.get("height", 512)
    # Both branches of the original if/else ended in the same samplers lookup.
    if not scheduler:
        scheduler = config.get("sampler", "")
    scheduler = samplers.get(scheduler, None)

    prompt_str = prompt_generator.enhance_prompt(prompt, method=model.config.get("model_version", "pony"), styles=styles, extract_danbooru_tags=False)  #!Test!True
    prompt_str = config.get("prompt", "{prompt}").replace("{prompt}", prompt_str)
    prompt_str = prompt_str + "," + prompt_analysis_result.get("detail_improved_prompt_in_natural_language", "")
    seed = seed or int(randomize_seed_fn(seed, True))
    generator = torch.Generator(pipe.device).manual_seed(seed)
    # BUG FIX: callers (generate / image_gen_api_entry) place a vibe *config dict* in
    # styles["vibe"]; using that dict as a key of `vibes` raised TypeError (unhashable).
    # Accept either a vibe name or a vibe config dict.
    vibe_style = styles.get("vibe", "")
    if isinstance(vibe_style, dict):
        vibe_negative = vibe_style.get("negative_prompt", "")
    else:
        vibe_negative = vibes.get(vibe_style, {}).get("negative_prompt", "")
    total_negative_prompt = total_negative_prompt + negative_prompt + vibe_negative

    app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
    app.prepare(ctx_id=0, det_size=(512, 512))

    faceid_all_embeds = []

    face = cv2.imread(image)
    faces = app.get(face)
    if not faces:
        # BUG FIX: faces[0] raised a bare IndexError when no face was detected.
        raise Exception("No face detected in the uploaded image")
    faceid_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
    faceid_all_embeds.append(faceid_embed)

    average_embedding = torch.mean(torch.stack(faceid_all_embeds, dim=0), dim=0)

    print(prompt_str)
    if scheduler:
        pipe.scheduler = scheduler

    # BUG FIX: the original f-string referenced `use_hd` ("usd_hd"), which is not defined
    # in this function — it resolved to the module-level gradio checkbox component.
    print(f"Generating SDXL, scheduler={scheduler}")
    # NOTE(review): the enhanced `prompt_str` is computed (and printed) above but the raw
    # `prompt` is what is passed to the model; `generator` is likewise unused. Confirm
    # whether that is intentional before changing it.
    images = model.ip_model.generate(
        prompt=prompt, negative_prompt=total_negative_prompt, faceid_embeds=average_embedding, num_samples=1, num_images_per_prompt=1,
        scale=likeness_strength, width=width, height=height, guidance_scale=face_strength, num_inference_steps=steps
    )

    webp_images = []
    for i in images:
        print(i)
        # Route images that look like minors to the "fua" bucket instead of "fip".
        result = age_predictor.predict(i, 0.6)
        print(result)
        webp_images.append(helper.save_image(i, "fua" if result.get("is_minor", False) else "fip"))
    return webp_images
|
| 636 |
+
|
| 637 |
+
mdoel_mgr = None  # ModelManager sentinel, set in __main__; name is (sic) "mdoel_mgr" throughout the file
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
# Example usage in generate function
@spaces.GPU(enable_queue=True)
def generate(model_id, image, prompt, negative_prompt, scheduler, use_hd, face_strength, likeness_strength, width, height, publish, seed=0, progress=gr.Progress(track_tqdm=True)):
    """Gradio entry point: route to face-guided (IPA) or plain text-to-image generation.

    When `image` is given (numpy array from the gr.Image component), it is persisted to
    a temp file and passed to generate_image_with_ipa; otherwise generate_image is used
    and the first result is optionally published.

    Returns:
        The IPA result list, or a single-element list with the opened PIL image.
    """
    try:
        if image is not None:
            # BUG FIX: the original temp filename used time.time(), which can collide
            # under concurrent requests; mkstemp guarantees a unique file.
            import tempfile
            fd, image_path = tempfile.mkstemp(suffix=".jpg")
            os.close(fd)
            Image.fromarray(image).save(image_path)
            try:
                return generate_image_with_ipa(model_id, image_path, prompt, negative_prompt, width, height, {"vibe": vibes[model_id]},
                                               scheduler=None, face_strength=face_strength,
                                               likeness_strength=likeness_strength, steps=30, seed=seed, options={"domain": "nsfwais.io"}, progress=progress)
            finally:
                os.remove(image_path)
        else:
            result = generate_image(model_id, prompt, negative_prompt, width, height, {"vibe": vibes[model_id]},
                                    scheduler, use_hd, cfg=face_strength, steps=30, seed=seed, options={"domain": "nsfwais.io"}, progress=progress)
            display_image_path = result[0]["image"]
            if not os.path.exists(display_image_path):
                raise FileNotFoundError(f"Image file not found: {display_image_path}")

            # Open it with PIL for the gallery output.
            display_image = Image.open(display_image_path)

            if publish:
                # Best-effort publish: failures are logged, never fatal.
                try:
                    publish_result = utils.publish_url_sync(local_config.get("download_url",""), result[0]["image"], config.get("publish_url", ""))
                    print(publish_result)
                except Exception as e:
                    print(e)
            print(result)

            return [display_image]
    finally:
        # Force CUDA memory cleanup after generation
        torch.cuda.empty_cache()
generate.zerogpu=True
|
| 681 |
+
|
| 682 |
+
def ipa_image_gen_api(source_image, prompt, negative_prompt, styles, options, token, request: gr.Request, progress=gr.Progress(track_tqdm=True)):
    """API endpoint: face-guided generation when a source image is supplied, plain
    generation otherwise. `styles`/`options` may arrive JSON-encoded.

    NOTE(review): the auth check is commented out — confirm this endpoint is meant
    to be unauthenticated.
    """
    #auth.check_auth("/image_gen_api", request, token)
    print(styles)
    if isinstance(styles, str):
        styles = json.loads(styles)
    if isinstance(options, str):
        options = json.loads(options)
    if source_image is not None:
        source_image = Image.fromarray(source_image)
        # BUG FIX: the original temp filename used time.time(), which can collide
        # under concurrent requests; mkstemp guarantees a unique file.
        import tempfile
        fd, image_path = tempfile.mkstemp(suffix=".jpg")
        os.close(fd)
        source_image.save(image_path)
        try:
            return image_gen_api_entry("ipa", image_path, prompt, negative_prompt, styles, options, token, mdoel_mgr, progress)
        finally:
            os.remove(image_path)
    else:
        return image_gen_api_entry("image", None, prompt, negative_prompt, styles, options, token, mdoel_mgr, progress)
ipa_image_gen_api.zerogpu=True
|
| 701 |
+
|
| 702 |
+
def image_gen_api(prompt, negative_prompt, styles, options, token, request: gr.Request, progress=gr.Progress(track_tqdm=True)):
    """API endpoint for plain text-to-image generation (no reference face).

    Decodes JSON-encoded `styles`/`options` if needed, defaults the publishing
    domain, and delegates to image_gen_api_entry in "image" mode.
    """
    #auth.check_auth("/image_gen_api", request, token)
    print(styles)
    styles = json.loads(styles) if isinstance(styles, str) else styles
    options = json.loads(options) if isinstance(options, str) else options
    options.setdefault("domain", "nsfwais.io")
    return image_gen_api_entry("image", None, prompt, negative_prompt, styles, options, token, mdoel_mgr, progress)
image_gen_api.zerogpu=True
|
| 713 |
+
|
| 714 |
+
def image_gen_api_entry(mode, source_image, prompt, negative_prompt, styles, options, token, use_model_mgr:ModelManager, progress=gr.Progress(track_tqdm=True)):
    """Shared API backend: resolve styles/size/vibe, then run "image" or "ipa" generation.

    In "image" mode the first generated record is returned and, when flagged
    `should_publish`, best-effort published; in "ipa" mode the first record of a
    face-guided generation is returned. Raises for unknown vibes or modes.
    """
    # Aspect-ratio -> [[standard W,H], [high-res W,H]]; all values are multiples of 8.
    size_2_width_height = {
        "1_1": [[1280, 1280], [2048, 2048]], # square
        "4_3": [[1200, 896], [2048, 1536]], # landscape 4:3
        "16_9": [[1280, 720], [1920, 1088]], # widescreen 16:9
        "4_5": [[1080, 1344], [2048, 2560]], # portrait 4:5 (Instagram-style)
        "9_16": [[720, 1280], [1080, 1920]], # vertical 9:16 (TikTok-style)
    }
    # "Auto" from the UI means "no explicit choice" — normalize to empty string.
    for s in styles:
        if styles[s] == "Auto":
            styles[s] = ""
    vibe = styles.get("vibe","") or "realistic"
    vibe_config = vibes.get(vibe, {})
    if not vibe_config:
        raise Exception(f"invalid vibe {vibe}")
    # NOTE(review): style_name is computed but never used; it also raises TypeError if
    # vibe_config has no "models" key — confirm whether it can be removed.
    style_name = vibe_config.get("models")[0]
    # Replace the vibe *name* with its full config dict for downstream consumers.
    styles["vibe"] = vibe_config
    #lighting = styles.get("lighting", "")
    #camera = styles.get("camera", "")
    size = styles.get("size", "") or "4_3"
    styles.pop("size",None)
    steps = options.get("steps", 30)
    seed = options.get("seed",0)

    # Always use the standard (index-0) resolution tier.
    width, height = size_2_width_height.get(size)[0]

    #prompt = prompt_generator.enhance_prompt(prompt)
    #source_image = Image.fromarray(source_image)


    '''
    chars = characters.get()
    mentioned_chars = utils.extract_characters(prompt)
    converted_chars = set()
    for c in mentioned_chars:
        char_info = chars.get(c["key"], {})
        char_prompts = char_info.get("samplePrompts", [])
        if char_prompts:
            func = max
            if c["key"] in converted_chars:
                func = min

            char_prompt = (lambda s: s)(func(char_prompts, key=len))
            if char_prompt:
                prompt = prompt.replace(c["raw"], char_prompt)
                converted_chars.add(c["key"])
        else:
            print("char "+c["key"]+" info not found, skip interpretting")
    '''

    if mode == "image":
        generated_image = generate_image(vibe, prompt, negative_prompt, width, height, styles, scheduler_name=None,use_hd=None, cfg=5, steps=steps, seed=seed, options=options, progress=progress)[0]
        if generated_image.get("should_publish", False):
            # Best-effort publish; on any failure we fall through and return no URL.
            try:
                result = utils.publish_url_sync(local_config.get("download_url",""), generated_image["image"], config.get("publish_url", ""), timeout=10)
                if "url" in result:
                    return generated_image["image"], result["url"]
                else:
                    print(f"error in result{result}")
            except Exception as e:
                print(e)

        return generated_image["image"], ""

    elif mode == "ipa":
        face_strength = options.get("face_strength", 11)
        likeness_strength = options.get("likeness_strength", 0.5)
        # NOTE(review): returns a single record here (no URL tuple), unlike "image" mode.
        return generate_image_with_ipa(vibe, source_image, prompt, negative_prompt, width, height, styles, scheduler=None, face_strength=face_strength, likeness_strength=likeness_strength, steps=steps, seed=seed, options=options, progress=progress)[0]
    else:
        raise Exception(f"invalid mode {mode}")
|
| 784 |
+
|
| 785 |
+
|
| 786 |
+
|
| 787 |
+
def swap_to_gallery(images):
    """Show the gallery filled with *images*, reveal the clear button, hide the file picker."""
    show_gallery = gr.update(value=images, visible=True)
    show_clear_button = gr.update(visible=True)
    hide_file_picker = gr.update(visible=False)
    return show_gallery, show_clear_button, hide_file_picker
|
| 789 |
+
|
| 790 |
+
def remove_back_to_files():
    """Hide the gallery and clear button, bring the file picker back."""
    hide_gallery = gr.update(visible=False)
    hide_clear_button = gr.update(visible=False)
    show_file_picker = gr.update(visible=True)
    return hide_gallery, hide_clear_button, show_file_picker
|
| 792 |
+
css = '''
|
| 793 |
+
h1{margin-bottom: 0 !important}
|
| 794 |
+
'''
|
| 795 |
+
# UI layout: left column holds all inputs, right column the output gallery.
with gr.Blocks(css=css) as demo:
    #gr.Markdown("# IP-Adapter-FaceID SDXL demo")
    #gr.Markdown("A simple Demo for the [h94/IP-Adapter-FaceID SDXL model](https://huggingface.co/h94/IP-Adapter-FaceID) together with [Juggernaut XL v7](https://huggingface.co/stablediffusionapi/juggernaut-xl-v7). You should run this on at least 24 GB of VRAM.")
    with gr.Row():
        with gr.Column():
            '''
            files = gr.Files(
                label="Drag 1 or more photos of your face",
                file_types=["image"]
            )
            '''
            upload_image = gr.Image(label="Upload an image")
            #uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=250)
            #with gr.Column(visible=False) as clear_button:
                #remove_and_reupload = gr.ClearButton(value="Remove files and upload new ones", components=files, size="sm")
            model = gr.Textbox(label="",
                    info="",
                    placeholder="",
                    value=config.get("default_model", ""))

            prompt = gr.Textbox(label="Prompt",
                    info="Try something like 'a photo of a man/woman/person'",
                    placeholder="A photo of a man/woman/person ...",
                    lines=10,
                    value="")
            # BUG FIX: this textbox was created twice with identical settings; the second
            # creation rendered a duplicate box in the UI and silently rebound the
            # variable. Only one is kept.
            negative_prompt = gr.Textbox(label="Negative Prompt", info="What the model should NOT produce.",placeholder="low quality", value="worst quality, low quality")
            # Text box for the width
            width = gr.Number(value=config.get("width", 1280), label="Width", info="Width of the generated image")
            # Text box for the height
            height = gr.Number(value=config.get("height", 1280), label="Height", info="Height of the generated image")
            scheduler = gr.Textbox(label="scheduler", info="",placeholder="low quality", value="Euler a")
            use_hd = gr.Checkbox(label="use hidiffusion", value=config.get("use_hd", False))
            publish = gr.Checkbox(label="publish", value=False, visible=config.get("publish_url", "") != "")
            seed = gr.Number(label="Seed", info="Seed for the generated image", value=0)
            # NOTE(review): `style` appears unused in the visible code — confirm before removing.
            style = "Photorealistic"
            face_strength = gr.Slider(label="Prompt Strength", info="How much the written prompt weighs into the generated images.", value=7.5, step=0.1, minimum=0, maximum=15)
            likeness_strength = gr.Slider(label="Photo Embedding Strength", info="How much your uploaded files weigh into the generated images.", value=1.0, step=0.1, minimum=0, maximum=5)

            submit = gr.Button("Submit", variant="primary")
            # Hidden buttons exposing the API endpoints to programmatic clients.
            submit1 = gr.Button("image_gen_api", variant="primary", visible=False)
            submit2 = gr.Button("image_gen_api_ipa", variant="primary", visible=False)
        with gr.Column():
            gallery = gr.Gallery(label="Generated Images")
    #files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files, clear_button, files])
    #remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, files])
    submit.click(fn=generate,
                inputs=[model, upload_image,prompt,negative_prompt,scheduler, use_hd, face_strength, likeness_strength, width, height, publish, seed],
                outputs=gallery)

    # NOTE(review): negative_prompt is deliberately reused as a placeholder component for
    # the styles/options/token inputs of the API endpoints — confirm this stand-in wiring.
    submit1.click(fn=image_gen_api,
                inputs=[prompt,negative_prompt, negative_prompt, negative_prompt, negative_prompt],
                outputs=[gr.Image(visible=False), gr.Textbox(visible=False)])
    submit2.click(fn=ipa_image_gen_api,
                inputs=[upload_image, prompt,negative_prompt, negative_prompt, negative_prompt, negative_prompt],
                outputs=gr.Image(visible=False))


    # gr.Markdown("This demo includes extra features to mitigate the implicit bias of the model and prevent explicit usage of it to generate content with faces of people, including third parties, that is not safe for all audiences, including naked or semi-naked people.")
|
| 854 |
+
|
| 855 |
+
if __name__ == "__main__":
    # Startup sequence: auth, model downloads, prompt enhancer, model manager, UI launch.
    # Read a local.json for local config
    # NOTE(review): `local_config` is referenced below but not loaded here — presumably
    # loaded earlier in the file; confirm.
    auth = AuthHelper()
    download_for_imagegen.download_models()  # fetch model weights before the manager starts
    prompt_generator = PromptGenerator()
    # (sic) "mdoel_mgr" — this misspelled name is used consistently across the module.
    mdoel_mgr = ModelManager(config, vibes, lazy_load=local_config.get("lazy_load", False), max_loaded_models=local_config.get("max_loaded_models", 1))
    print("launching")
    demo.launch()
|
requirements.txt
CHANGED
|
@@ -3,7 +3,6 @@
|
|
| 3 |
torch==2.4.0
|
| 4 |
|
| 5 |
# Core CUDA stack (these three versions are a matched set)
|
| 6 |
-
spaces==0.30.0
|
| 7 |
pydantic==2.10.6
|
| 8 |
insightface==0.7.3
|
| 9 |
diffusers
|
|
|
|
| 3 |
torch==2.4.0
|
| 4 |
|
| 5 |
# Core CUDA stack (these three versions are a matched set)
|
|
|
|
| 6 |
pydantic==2.10.6
|
| 7 |
insightface==0.7.3
|
| 8 |
diffusers
|