File size: 2,857 Bytes
f9fbfff
39efaa4
 
f9fbfff
39efaa4
f9fbfff
39efaa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f9fbfff
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94

import gradio as gr
from gnx_flux_lora import import_custom_nodes, NODE_CLASS_MAPPINGS, get_value_at_index
import torch
import random

# Initialize once (model loading outside inference function for Hugging Face ZeroGPU)
# NOTE(review): everything below runs at import time; on ZeroGPU this keeps the
# heavyweight checkpoint loads out of the per-request GPU window.
import_custom_nodes()

# Flux UNet (diffusion backbone). "default" lets the loader pick the weight dtype.
unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
unetloader_3 = unetloader.load_unet(
    unet_name="flux1-dev.sft", weight_dtype="default"
)

# Dual text encoders required by Flux: CLIP-L plus T5-XXL (fp16 weights).
dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
dualcliploader_8 = dualcliploader.load_clip(
    clip_name1="clip_l.safetensors",
    clip_name2="t5xxl_fp16.safetensors",
    type="flux",
    device="default",
)

# Apply the custom-trained LoRA to both the UNet and the CLIP stack at full strength.
# get_value_at_index unwraps the node's tuple-style return — presumably (model,) /
# (clip,) ComfyUI outputs; verify against gnx_flux_lora.
loraloader = NODE_CLASS_MAPPINGS["LoraLoader"]()
loraloader_12 = loraloader.load_lora(
    lora_name="gauravjuneja4/gauravjuneja4.safetensors",
    strength_model=1,
    strength_clip=1,
    model=get_value_at_index(unetloader_3, 0),
    clip=get_value_at_index(dualcliploader_8, 0),
)

# VAE used to decode latents back to pixel space.
vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
vae_9 = vaeloader.load_vae(vae_name="ae.safetensors")

# Stateless node instances reused by every run_inference call.
cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
emptylatentimage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()

def run_inference(prompt: str):
    """Generate one 1024x1024 image from *prompt* using the preloaded Flux+LoRA pipeline.

    Runs the full ComfyUI graph: encode prompt -> empty latent -> FluxGuidance ->
    KSampler (euler/simple, 20 steps, cfg=1) -> VAE decode.

    Args:
        prompt: Positive text prompt; the negative prompt is fixed to "".

    Returns:
        The decoded image tensor/object produced by the VAEDecode node
        (whatever type get_value_at_index yields — consumed by gr.Image).
    """
    with torch.inference_mode():
        # Text encodings — both use the LoRA-patched CLIP (index 1 of the LoraLoader output).
        positive = cliptextencode.encode(
            text=prompt,
            clip=get_value_at_index(loraloader_12, 1),
        )

        negative = cliptextencode.encode(
            text="",
            clip=get_value_at_index(loraloader_12, 1),
        )

        latent = emptylatentimage.generate(
            width=1024, height=1024, batch_size=1
        )

        # Flux uses distilled guidance injected into the conditioning rather than CFG.
        guided = fluxguidance.append(
            guidance=3.5, conditioning=get_value_at_index(positive, 0)
        )

        sample = ksampler.sample(
            # BUGFIX: randint is inclusive on both ends, so randint(1, 2**64)
            # could return 2**64, overflowing the uint64 seed range the sampler
            # expects. Valid seeds are 0 .. 2**64 - 1.
            seed=random.randint(0, 2**64 - 1),
            steps=20,
            cfg=1,  # cfg=1 disables classifier-free guidance (Flux-dev convention)
            sampler_name="euler",
            scheduler="simple",
            denoise=1,
            model=get_value_at_index(loraloader_12, 0),
            positive=get_value_at_index(guided, 0),
            negative=get_value_at_index(negative, 0),
            latent_image=get_value_at_index(latent, 0),
        )

        decoded = vaedecode.decode(
            samples=get_value_at_index(sample, 0),
            vae=get_value_at_index(vae_9, 0),
        )

        return get_value_at_index(decoded, 0)

# Gradio UI: single textbox in, single image out, backed by run_inference.
prompt_box = gr.Textbox(
    label="Prompt",
    placeholder="e.g. gjnx is driving ferrari on a road in germany",
)
image_out = gr.Image(label="Generated Image")

demo = gr.Interface(
    fn=run_inference,
    inputs=prompt_box,
    outputs=image_out,
    title="GNX Flux LoRA Generator",
    description="Enter a prompt using your trained LoRA with Flux .1 Dev.",
)

if __name__ == "__main__":
    demo.launch()