# NOTE: removed non-Python scrape artifacts that preceded this line
# (Hugging Face file-viewer header, commit hashes, and a line-number gutter);
# they made the file syntactically invalid.
from click import command
import streamlit as st
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
import os
from authtoken import auth_token
from datasets import load_dataset
from PIL import Image
import re
import os
import requests
# --- Streamlit page + Stable Diffusion pipeline setup (runs once at import) ---
st.title('Stable Diffusion (AI that can imagine)')
# Hugging Face Hub model id for Stable Diffusion v1.4.
model_id = "CompVis/stable-diffusion-v1-4"
# NOTE(review): device is hard-coded; this script will fail on a CPU-only
# machine (and fp16 weights below assume CUDA) — confirm deployment target.
device = "cuda"
# word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
# word_list = word_list_dataset["train"]['text']
# Load the fp16 pipeline weights; auth_token comes from the local authtoken module.
pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=auth_token, revision="fp16", torch_dtype=torch.float16)
pipe = pipe.to(device)
# Let cuDNN pick the fastest convolution algorithms for fixed input sizes.
torch.backends.cudnn.benchmark = True
# Crude single-process mutex: True while a generation request is running.
is_gpu_busy = False
def infer(prompt, seed=0):
    """Generate a batch of images for *prompt* with the global SD pipeline.

    Parameters
    ----------
    prompt : str
        Text prompt passed to the diffusion pipeline.
    seed : int, optional
        Seed for the torch Generator so results are reproducible.
        BUG FIX: the original referenced an undefined name ``seed`` and
        raised NameError on every call; it is now a parameter with a default.

    Returns
    -------
    list
        Generated PIL images (NSFW-flagged ones replaced by ``unsafe.png``);
        empty list when the GPU is already busy.
    """
    global is_gpu_busy
    samples = 4
    steps = 50
    scale = 7.5
    # When running locally you can also remove this filter
    # for filter in word_list:
    #     if re.search(rf"\b{filter}\b", prompt):
    #         raise gr.Error("Unsafe content found. Please try again with different prompts.")
    generator = torch.Generator(device=device).manual_seed(seed)
    images = []
    if not is_gpu_busy:
        is_gpu_busy = True
        try:
            images_list = pipe(
                [prompt] * samples,
                num_inference_steps=steps,
                guidance_scale=scale,
                generator=generator,
            )
        finally:
            # BUG FIX: release the busy flag even if the pipeline raises,
            # otherwise the app would be stuck "busy" forever.
            is_gpu_busy = False
        # Placeholder shown instead of any image the safety checker flags.
        safe_image = Image.open(r"unsafe.png")
        # BUG FIX: this loop is now inside the guard — previously it ran even
        # when the GPU was busy and images_list was never assigned
        # (UnboundLocalError).
        for i, image in enumerate(images_list["sample"]):
            if images_list["nsfw_content_detected"][i]:
                images.append(safe_image)
            else:
                images.append(image)
    # else:
    #     url = os.getenv('JAX_BACKEND_URL')
    #     payload = {'prompt': prompt}
    #     images_request = requests.post(url, json = payload)
    #     for image in images_request.json()["images"]:
    #         image_b64 = (f"data:image/jpeg;base64,{image}")
    #         images.append(image_b64)
    return images
# --- Streamlit UI: prompt box + generate button ---
prompt = st.text_input('Enter your prompt')
clicked = st.button('Generate')
if clicked:
    results = infer(prompt)
    for img in results:
        # BUG FIX: the original called st.image(a) inside the loop, which
        # rendered the entire result list once per image instead of showing
        # each image exactly once.
        st.image(img)
# from diffusers import StableDiffusionPipeline
# st.session_state.ac=True
# if st.session_state.ac==True:
# os.system('pip3 install torch torchvision torchaudio')
# import torch
# from torch import autocast
# st.session_state.ac=False
# st.title('Stable Diffusion (AI that can imagine)')
# prompt=st.text_input('Enter the prompt')
# modelid = "CompVis/stable-diffusion-v1-4"
# device = "cuda"
# pipe = StableDiffusionPipeline.from_pretrained(modelid, revision="fp16", torch_dtype=torch.float16, use_auth_token=auth_token)
# pipe.to(device)
# def generate():
# with autocast(device):
# image = pipe(prompt.get(), guidance_scale=8.5)["sample"][0]
# image.save(prompt+'.png')
# img = st.image(prompt+'.png')
# st.button('Generate',command=generate)