# Page header captured with this file: author avatar "alexander-potemkin's picture",
# commit message "Update app.py", commit 12cf82c (verified).
import gradio as gr
import spaces
import torch
import random
from diffusers import DiffusionPipeline, AutoPipelineForText2Image, FluxPipeline, StableDiffusionXLPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline
import os
# Pick the torch device once at import time: CUDA when PyTorch reports it as
# available, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device, torch.cuda.is_available())

# Checkpoint that generate() falls back to when the caller does not name one.
# An alternative default is "Heartsync/Flux-NSFW-uncensored".
image_model = "andro-flock/LUSTIFY-SDXL-NSFW-checkpoint-v2-0-INPAINTING"

# Hugging Face access token read from the environment; None when HF_TOKEN is unset.
hf_token = os.environ.get("HF_TOKEN")
def _load_pipeline(model):
    """Build the diffusers pipeline for *model* and move it to ``device``."""
    # Models that are LoRA adapters on top of FLUX.1-dev, mapped to the extra
    # keyword arguments their ``load_lora_weights`` call needs.
    flux_lora_kwargs = {
        "enhanceaiteam/Flux-uncensored-v2": {},
        "Heartsync/Flux-NSFW-uncensored": {"adapter_name": "uncensored"},
    }
    if model in flux_lora_kwargs:
        # Load the base model, then apply the LoRA weights from the selected repo.
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.bfloat16,
            token=hf_token,
        ).to(device)
        pipe.load_lora_weights(
            model,
            weight_name="lora.safetensors",
            **flux_lora_kwargs[model],
        )
        return pipe

    if model == "andro-flock/LUSTIFY-SDXL-NSFW-checkpoint-v2-0-INPAINTING":
        # NOTE(review): the repo name suggests an SDXL (inpainting) checkpoint,
        # yet it is loaded through StableDiffusionPipeline -- confirm this is
        # the intended pipeline class.
        return StableDiffusionPipeline.from_pretrained(
            model,
            torch_dtype=torch.float16,
        ).to(device)

    # Any other repo id: let DiffusionPipeline choose the pipeline class.
    return DiffusionPipeline.from_pretrained(
        model,
        torch_dtype=torch.float16,
    ).to(device)


@spaces.GPU
def generate(prompt, negative_prompt, model=image_model):
    """Generate images for *prompt* with the selected diffusion model.

    Args:
        prompt: Text describing the image to generate.
        negative_prompt: Text describing what to avoid. When empty or None a
            generic quality-related negative prompt is used instead.
        model: Hugging Face repo id of the model to use. When empty or None
            (e.g. nothing selected in the UI) ``image_model`` is used.

    Returns:
        The ``images`` attribute of the pipeline output.
    """
    print("Generating image...")
    if not negative_prompt:
        negative_prompt = "ugly, deformed, disfigured, poor quality, low resolution"
    if not model:
        # An unselected dropdown arrives as None; fall back to the default
        # instead of handing None to from_pretrained.
        model = image_model

    # The pipeline is rebuilt on every call, as in the original implementation.
    pipe = _load_pipeline(model)
    print(f"Using model: {model}")
    # NOTE(review): negative_prompt is passed to every pipeline, including
    # FluxPipeline -- confirm the installed diffusers version accepts it there.
    return pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=7.0,
        num_inference_steps=30,
        width=1024,
        height=1024,
    ).images
# Gradio UI: two text inputs (prompt, negative prompt) and a model selector
# feed generate(); the returned images are shown in a gallery.
app = gr.Interface(
    fn=generate,
    inputs=[
        gr.Text(label="Prompt"),
        gr.Text("", label="Negative Prompt"),
        gr.Dropdown(
            [
                "Heartsync/Flux-NSFW-uncensored",
                "enhanceaiteam/Flux-uncensored-v2",
                "Heartsync/NSFW-Uncensored",
                "UnfilteredAI/NSFW-gen-v2",
                "andro-flock/LUSTIFY-SDXL-NSFW-checkpoint-v2-0-INPAINTING",
            ],
            # Preselect the configured default so the UI agrees with
            # generate()'s own default and never submits an empty selection.
            value=image_model,
            label="Image model",
            info="Select the image model:",
        ),
    ],
    outputs=gr.Gallery(),
)
app.launch(show_api=True)