gjnxLora / app.py
techyygarry's picture
Upload 5 files
39efaa4 verified
import gradio as gr
from gnx_flux_lora import import_custom_nodes, NODE_CLASS_MAPPINGS, get_value_at_index
import torch
import random
# One-time setup at import: all models are loaded here, outside the inference
# function, for Hugging Face ZeroGPU.
import_custom_nodes()

# Flux UNet checkpoint.
unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
unetloader_3 = unetloader.load_unet(
    unet_name="flux1-dev.sft",
    weight_dtype="default",
)

# Both text-encoder files, loaded together with type "flux".
dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
dualcliploader_8 = dualcliploader.load_clip(
    clip_name1="clip_l.safetensors",
    clip_name2="t5xxl_fp16.safetensors",
    type="flux",
    device="default",
)

# LoRA applied with strength 1 to output 0 of each loader above; run_inference
# later reads outputs 0 (model) and 1 (clip) of ``loraloader_12``.
loraloader = NODE_CLASS_MAPPINGS["LoraLoader"]()
_base_model = get_value_at_index(unetloader_3, 0)
_base_clip = get_value_at_index(dualcliploader_8, 0)
loraloader_12 = loraloader.load_lora(
    lora_name="gauravjuneja4/gauravjuneja4.safetensors",
    strength_model=1,
    strength_clip=1,
    model=_base_model,
    clip=_base_clip,
)

# VAE used to decode sampled latents.
vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
vae_9 = vaeloader.load_vae(vae_name="ae.safetensors")

# Node instances created once and reused by every run_inference call.
cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
emptylatentimage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()
def run_inference(prompt: str):
    """Generate one 1024x1024 image for ``prompt`` using the preloaded nodes.

    The pipeline encodes the prompt (and an empty negative prompt) with the
    LoRA loader's clip output, applies guidance 3.5 to the positive
    conditioning, samples 20 "euler"/"simple" steps from an empty latent with
    a fresh random seed, and decodes the result with the preloaded VAE.

    Args:
        prompt: Text description of the image to generate.

    Returns:
        Output 0 of the VAE decode node, passed through unchanged.
    """
    # randint is inclusive on both ends, so the upper bound must be
    # 2**64 - 1 for the seed to fit in an unsigned 64-bit integer
    # (torch.manual_seed rejects anything above 0xffff_ffff_ffff_ffff).
    max_seed = 2**64 - 1

    with torch.inference_mode():
        # Outputs of the LoRA loader: 0 is the model, 1 is the clip.
        lora_model = get_value_at_index(loraloader_12, 0)
        lora_clip = get_value_at_index(loraloader_12, 1)

        # Text encodings
        positive = cliptextencode.encode(text=prompt, clip=lora_clip)
        negative = cliptextencode.encode(text="", clip=lora_clip)

        latent = emptylatentimage.generate(
            width=1024, height=1024, batch_size=1
        )
        guided = fluxguidance.append(
            guidance=3.5, conditioning=get_value_at_index(positive, 0)
        )
        sample = ksampler.sample(
            seed=random.randint(1, max_seed),
            steps=20,
            cfg=1,
            sampler_name="euler",
            scheduler="simple",
            denoise=1,
            model=lora_model,
            positive=get_value_at_index(guided, 0),
            negative=get_value_at_index(negative, 0),
            latent_image=get_value_at_index(latent, 0),
        )
        decoded = vaedecode.decode(
            samples=get_value_at_index(sample, 0),
            vae=get_value_at_index(vae_9, 0),
        )
        # NOTE(review): the decode output is returned as-is and handed to
        # gr.Image by the UI; confirm its type is one gr.Image accepts
        # (numpy array, PIL image or file path) - TODO confirm.
        result = get_value_at_index(decoded, 0)
    return result
# Gradio UI: a single prompt textbox in, a single generated image out.
_prompt_input = gr.Textbox(
    label="Prompt",
    placeholder="e.g. gjnx is driving ferrari on a road in germany",
)
_image_output = gr.Image(label="Generated Image")

demo = gr.Interface(
    fn=run_inference,
    inputs=_prompt_input,
    outputs=_image_output,
    title="GNX Flux LoRA Generator",
    description="Enter a prompt using your trained LoRA with Flux .1 Dev.",
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()