|
|
# `spaces` (Hugging Face ZeroGPU) is kept first so it is imported before torch.
import spaces

import datetime
import os
import subprocess
|
|
|
|
|
|
|
|
''' |
|
|
subprocess.run(['sh', './conda.sh']) |
|
|
|
|
|
import sys |
|
|
conda_prefix = os.path.expanduser("~/miniconda3") |
|
|
conda_bin = os.path.join(conda_prefix, "bin") |
|
|
|
|
|
# Add Conda's bin directory to your PATH |
|
|
os.environ["PATH"] = conda_bin + os.pathsep + os.environ["PATH"] |
|
|
|
|
|
# Activate the base environment (adjust if needed) |
|
|
os.system(f'{conda_bin}/conda init --all') |
|
|
os.system(f'{conda_bin}/conda activate base') |
|
|
os.system(f'{conda_bin}/conda install nvidia/label/cudnn-9.3.0::cudnn') |
|
|
''' |
|
|
|
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import paramiko |
|
|
from image_gen_aux import UpscaleWithModel |
|
|
import cyper |
|
|
from PIL import Image |
|
|
os.environ['JAX_PLATFORMS'] = 'cpu' |
|
|
import random |
|
|
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00" |
|
|
import keras |
|
|
import keras_hub |
|
|
import torch |
|
|
|
|
|
# Run on the first CUDA GPU when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


def _configure_torch_numerics():
    """Disable TF32 / reduced-precision reductions and request full fp32 matmuls.

    cuDNN determinism and autotuning (benchmark) are both explicitly left off.
    """
    matmul = torch.backends.cuda.matmul
    matmul.allow_tf32 = False
    matmul.allow_bf16_reduced_precision_reduction = False
    matmul.allow_fp16_reduced_precision_reduction = False

    cudnn = torch.backends.cudnn
    cudnn.allow_tf32 = False
    cudnn.deterministic = False
    cudnn.benchmark = False

    torch.set_float32_matmul_precision("highest")


_configure_torch_numerics()
|
|
|
|
|
# Lazily-populated model handles; both are filled in by load_model().
upscaler_2 = text_to_image = None
|
|
|
|
|
@spaces.GPU(duration=60)
def load_model():
    """Load the upscaler and the SD3 text-to-image pipeline into module globals.

    Sets ``upscaler_2`` (ClearRealityV1 moved to ``device``) and
    ``text_to_image`` (keras_hub SD3 Medium preset, 768x768, bfloat16).
    Every call reloads both models.

    Returns:
        The freshly loaded ``text_to_image`` pipeline.
    """
    global upscaler_2, text_to_image

    upscaler = UpscaleWithModel.from_pretrained("Kim2091/ClearRealityV1")
    upscaler_2 = upscaler.to(device)

    pipeline = keras_hub.models.StableDiffusion3TextToImage.from_preset(
        "stable_diffusion_3_medium",
        width=768,
        height=768,
        dtype="bfloat16",
    )
    text_to_image = pipeline
    return pipeline
|
|
|
|
|
# Source for the upload helper. It is compiled via `cyper.inline` when this
# module is imported, and exposed as `pyx.upload_to_ftp`.
# NOTE(review): despite the FTP_* names, this connects with paramiko over SSH
# (port 22) and uploads via SFTP.
code = r'''
import paramiko
import os

FTP_HOST = '1ink.us'
FTP_USER = 'ford442'
FTP_PASS = os.getenv("FTP_PASS")
FTP_DIR = '1ink.us/stable_diff/'

def upload_to_ftp(filename):
    # Best-effort upload of local file `filename` to FTP_DIR on FTP_HOST.
    # Errors are printed, never raised. The SFTP session and the transport are
    # closed in `finally`, so a failed connect/put no longer leaves the
    # connection open.
    transport = None
    sftp = None
    try:
        transport = paramiko.Transport((FTP_HOST, 22))
        destination_path = FTP_DIR + filename
        transport.connect(username=FTP_USER, password=FTP_PASS)
        sftp = paramiko.SFTPClient.from_transport(transport)
        sftp.put(filename, destination_path)
        print(f"Uploaded {filename} to FTP server")
    except Exception as e:
        print(f"FTP upload error: {e}")
    finally:
        if sftp is not None:
            sftp.close()
        if transport is not None:
            transport.close()
'''

pyx = cyper.inline(code, fast_indexing=True, directives=dict(boundscheck=False, wraparound=False, language_level=3))
|
|
|
|
|
# Largest signed 32-bit integer; the infer_* handlers draw seeds from [0, MAX_SEED].
MAX_SEED = int(np.iinfo(np.int32).max)

# NOTE(review): not referenced anywhere in this file.
MAX_IMAGE_SIZE = 4 * 1024
|
|
|
|
|
def _as_pil_image(image):
    """Return ``image`` as a PIL image, converting array-like pixels if needed.

    keras_hub's ``generate`` returns an array rather than a PIL image, and the
    callers below need ``.save`` / ``.resize``.
    """
    if isinstance(image, Image.Image):
        return image
    pixels = np.asarray(image)
    # NOTE(review): assumes uint8 HxWx3 pixels from keras_hub; other dtypes are
    # clipped to [0, 255] and cast, which is only correct for 0-255 data.
    if pixels.dtype != np.uint8:
        pixels = np.clip(pixels, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)


def _generate_upscale_upload(prompt, negative_prompt, guidance_scale, num_inference_steps):
    """Shared body of the infer_* handlers.

    Generates one image with a random seed, saves and uploads it, then runs the
    upscaler, shrinks the result by 4x per side, and saves and uploads that too.

    Returns:
        ``(image, prompt)`` for the Gradio ``[result, expanded_prompt_output]`` outputs.
    """
    global text_to_image
    if text_to_image is None:
        text_to_image = load_model()

    # NOTE(review): keras selects its backend at import time, so setting these
    # after `import keras` probably has no effect on the already-loaded backend.
    os.environ['JAX_PLATFORMS'] = 'gpu'
    os.environ['KERAS_BACKEND'] = 'jax'

    seed = random.randint(0, MAX_SEED)
    # NOTE(review): negative_prompt comes from the UI but is not passed to
    # generate(); wiring it in depends on the installed keras_hub API.
    sd_image = text_to_image.generate(
        prompt,
        num_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        seed=seed,
    )
    print('-- got image --')
    sd_image = _as_pil_image(sd_image)

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    sd35_path = f"sd3keras_{timestamp}.png"
    sd_image.save(sd35_path, optimize=False, compress_level=0)
    pyx.upload_to_ftp(sd35_path)

    with torch.no_grad():
        upscale2 = upscaler_2(sd_image, tiling=True, tile_width=256, tile_height=256)
    print('-- got upscaled image --')
    # Shrink by 4x per side (presumably undoing a 4x upscaler model).
    downscale2 = upscale2.resize((upscale2.width // 4, upscale2.height // 4), Image.LANCZOS)
    upscale_path = f"sd3keras_upscale_{timestamp}.png"
    downscale2.save(upscale_path, optimize=False, compress_level=0)
    pyx.upload_to_ftp(upscale_path)

    return sd_image, prompt


@spaces.GPU(duration=40)
def infer_30(
    prompt,
    negative_prompt,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate/upscale/upload with a 40 s GPU allocation (the "Run 30" button)."""
    return _generate_upscale_upload(prompt, negative_prompt, guidance_scale, num_inference_steps)


@spaces.GPU(duration=70)
def infer_60(
    prompt,
    negative_prompt,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate/upscale/upload with a 70 s GPU allocation (the "Run 60" button)."""
    return _generate_upscale_upload(prompt, negative_prompt, guidance_scale, num_inference_steps)


@spaces.GPU(duration=100)
def infer_90(
    prompt,
    negative_prompt,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate/upscale/upload with a 100 s GPU allocation (the "Run 90" button)."""
    return _generate_upscale_upload(prompt, negative_prompt, guidance_scale, num_inference_steps)
|
|
|
|
|
|
|
|
css = """
#col-container {margin: 0 auto;max-width: 640px;}
body{background-color: blue;}
"""

with gr.Blocks(theme=gr.themes.Origin(), css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # StableDiffusion 3 Medium from Keras-hub")
        expanded_prompt_output = gr.Textbox(label="Prompt", lines=1)
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            load_button = gr.Button("Load model", scale=0, variant="primary")
            run_button_30 = gr.Button("Run 30", scale=0, variant="primary")
            run_button_60 = gr.Button("Run 60", scale=0, variant="primary")
            run_button_90 = gr.Button("Run 90", scale=0, variant="primary")
        result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=True):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=True,
                value="bad anatomy, poorly drawn hands, distorted face, blurry, out of frame, low resolution, grainy, pixelated, disfigured, mutated, extra limbs, bad composition"
            )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=30.0,
                    step=0.1,
                    value=4.2,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=500,
                    step=1,
                    value=50,
                )

    gr.on(
        triggers=[load_button.click],
        fn=load_model,
        inputs=[],
        outputs=[],
    )

    # All three run buttons share the same inputs and outputs.
    generate_inputs = [prompt, negative_prompt, guidance_scale, num_inference_steps]
    generate_outputs = [result, expanded_prompt_output]

    # Pressing Enter in the prompt box triggers only the "Run 30" handler.
    # Previously prompt.submit was attached to all three handlers, so one Enter
    # started three concurrent generations that raced on the same outputs.
    gr.on(
        triggers=[run_button_30.click, prompt.submit],
        fn=infer_30,
        inputs=generate_inputs,
        outputs=generate_outputs,
    )
    gr.on(
        triggers=[run_button_60.click],
        fn=infer_60,
        inputs=generate_inputs,
        outputs=generate_outputs,
    )
    gr.on(
        triggers=[run_button_90.click],
        fn=infer_90,
        inputs=generate_inputs,
        outputs=generate_outputs,
    )

if __name__ == "__main__":
    demo.launch()