mobenta commited on
Commit
0b556fe
·
verified ·
1 Parent(s): 2812df0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -44
app.py CHANGED
@@ -3,14 +3,13 @@ import torch
3
  import numpy as np
4
  from PIL import Image
5
  import gradio as gr
6
- from nodes import NODE_CLASS_MAPPINGS
7
- from totoro_extras import nodes_custom_sampler
8
- from totoro_extras import nodes_flux
9
 
10
- # Set device to GPU if available
11
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
12
 
13
- # Load the necessary models and move them to the GPU
14
  CheckpointLoaderSimple = NODE_CLASS_MAPPINGS["CheckpointLoaderSimple"]()
15
  LoraLoader = NODE_CLASS_MAPPINGS["LoraLoader"]()
16
  FluxGuidance = nodes_flux.NODE_CLASS_MAPPINGS["FluxGuidance"]()
@@ -23,14 +22,10 @@ VAELoader = NODE_CLASS_MAPPINGS["VAELoader"]()
23
  VAEDecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
24
  EmptyLatentImage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()
25
 
26
- # Load checkpoint and move to GPU
27
  with torch.inference_mode():
28
- unet, clip, vae = CheckpointLoaderSimple.load_checkpoint("flux1-dev-fp8-all-in-one.safetensors")
29
- unet = unet.to(device)
30
- clip = clip.to(device)
31
- vae = vae.to(device)
32
 
33
- # Function to find the closest multiple of a number
34
  def closestNumber(n, m):
35
  q = int(n / m)
36
  n1 = m * q
@@ -42,51 +37,51 @@ def closestNumber(n, m):
42
  return n1
43
  return n2
44
 
45
- # Main generation function
46
  @torch.inference_mode()
47
  def generate(positive_prompt, width, height, seed, steps, sampler_name, scheduler, guidance, lora_strength_model, lora_strength_clip):
48
  global unet, clip
49
  if seed == 0:
50
  seed = random.randint(0, 18446744073709551615)
51
- print(seed)
52
 
53
- # Load LoRA models and move them to GPU
54
- unet_lora, clip_lora = LoraLoader.load_lora(unet, clip, "flux_realism_lora.safetensors", lora_strength_model, lora_strength_clip)
55
- unet_lora = unet_lora.to(device)
56
- clip_lora = clip_lora.to(device)
57
 
58
- # Encode prompt and apply guidance
59
  cond, pooled = clip_lora.encode_from_tokens(clip_lora.tokenize(positive_prompt), return_pooled=True)
60
  cond = [[cond, {"pooled_output": pooled}]]
61
  cond = FluxGuidance.append(cond, guidance)[0]
62
-
63
- # Generate noise and move it to the GPU
64
- noise = RandomNoise.get_noise(seed)[0].to(device)
65
-
66
- # Setup guider and sampler
67
  guider = BasicGuider.get_guider(unet_lora, cond)[0]
68
  sampler = KSamplerSelect.get_sampler(sampler_name)[0]
69
-
70
- # Generate sigmas and latent image
71
  sigmas = BasicScheduler.get_sigmas(unet_lora, scheduler, steps, 1.0)[0]
72
- latent_image = EmptyLatentImage.generate(closestNumber(width, 16), closestNumber(height, 16))[0].to(device)
73
-
74
- # Perform sampling
 
 
75
  sample, sample_denoised = SamplerCustomAdvanced.sample(noise, guider, sampler, sigmas, latent_image)
76
-
77
- # Decode the latent image to a regular image
78
- decoded = VAEDecode.decode(vae, sample)[0].detach().cpu()
79
-
80
- # Convert to image and save
81
- output_image = Image.fromarray(np.array(decoded * 255, dtype=np.uint8)[0])
82
- output_image.save("/content/flux.png")
83
- return "/content/flux.png"
84
-
85
- # Setup the Gradio interface
86
  with gr.Blocks(analytics_enabled=False) as demo:
87
  with gr.Row():
88
  with gr.Column():
89
- positive_prompt = gr.Textbox(lines=3, interactive=True, value="cute anime girl with massive fluffy fennec ears and a big fluffy tail blonde messy long hair blue eyes wearing a maid outfit with a long black dress with a gold leaf pattern and a white apron eating a slice of an apple pie in the kitchen of an old dark victorian mansion with a bright window and very expensive stuff everywhere", label="Prompt")
 
 
 
 
 
90
  width = gr.Slider(minimum=256, maximum=2048, value=1024, step=16, label="width")
91
  height = gr.Slider(minimum=256, maximum=2048, value=1024, step=16, label="height")
92
  seed = gr.Slider(minimum=0, maximum=18446744073709551615, value=0, step=1, label="seed (0=random)")
@@ -94,13 +89,28 @@ with gr.Blocks(analytics_enabled=False) as demo:
94
  guidance = gr.Slider(minimum=0, maximum=20, value=3.5, step=0.5, label="guidance")
95
  lora_strength_model = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.1, label="lora_strength_model")
96
  lora_strength_clip = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.1, label="lora_strength_clip")
97
- sampler_name = gr.Dropdown(["euler", "heun", "heunpp2", "dpm_2", "lms", "dpmpp_2m", "ipndm", "deis", "ddim", "uni_pc", "uni_pc_bh2"], label="sampler_name", value="euler")
98
- scheduler = gr.Dropdown(["normal", "sgm_uniform", "simple", "ddim_uniform"], label="scheduler", value="simple")
 
 
 
 
 
 
 
 
99
  generate_button = gr.Button("Generate")
100
  with gr.Column():
101
  output_image = gr.Image(label="Generated image", interactive=False)
102
 
103
- generate_button.click(fn=generate, inputs=[positive_prompt, width, height, seed, steps, sampler_name, scheduler, guidance, lora_strength_model, lora_strength_clip], outputs=output_image)
 
 
 
 
 
 
 
 
104
 
105
- # Launch the Gradio interface
106
  demo.queue().launch(inline=False, share=True, debug=True)
 
3
  import numpy as np
4
  from PIL import Image
5
  import gradio as gr
 
 
 
6
 
7
+ # Import nodes and custom modules
8
+ import nodes
9
+ from nodes import NODE_CLASS_MAPPINGS
10
+ from totoro_extras import nodes_custom_sampler, nodes_flux
11
 
12
+ # Initialize necessary components from the nodes
13
  CheckpointLoaderSimple = NODE_CLASS_MAPPINGS["CheckpointLoaderSimple"]()
14
  LoraLoader = NODE_CLASS_MAPPINGS["LoraLoader"]()
15
  FluxGuidance = nodes_flux.NODE_CLASS_MAPPINGS["FluxGuidance"]()
 
22
  VAEDecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
23
  EmptyLatentImage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()
24
 
25
+ # Load checkpoints and models
26
  with torch.inference_mode():
27
+ unet, clip, vae = CheckpointLoaderSimple.load_checkpoint("models/checkpoints/flux1-dev-fp8-all-in-one.safetensors")
 
 
 
28
 
 
29
  def closestNumber(n, m):
30
  q = int(n / m)
31
  n1 = m * q
 
37
  return n1
38
  return n2
39
 
 
40
  @torch.inference_mode()
41
  def generate(positive_prompt, width, height, seed, steps, sampler_name, scheduler, guidance, lora_strength_model, lora_strength_clip):
42
  global unet, clip
43
  if seed == 0:
44
  seed = random.randint(0, 18446744073709551615)
45
+ print(f"Seed used: {seed}")
46
 
47
+ # Load LoRA models
48
+ unet_lora, clip_lora = LoraLoader.load_lora(unet, clip, "models/loras/flux_realism_lora.safetensors", lora_strength_model, lora_strength_clip)
 
 
49
 
50
+ # Encode the prompt
51
  cond, pooled = clip_lora.encode_from_tokens(clip_lora.tokenize(positive_prompt), return_pooled=True)
52
  cond = [[cond, {"pooled_output": pooled}]]
53
  cond = FluxGuidance.append(cond, guidance)[0]
54
+
55
+ # Generate noise
56
+ noise = RandomNoise.get_noise(seed)[0]
57
+
58
+ # Get guider and sampler
59
  guider = BasicGuider.get_guider(unet_lora, cond)[0]
60
  sampler = KSamplerSelect.get_sampler(sampler_name)[0]
61
+
62
+ # Get scheduling sigmas
63
  sigmas = BasicScheduler.get_sigmas(unet_lora, scheduler, steps, 1.0)[0]
64
+
65
+ # Generate latent image
66
+ latent_image = EmptyLatentImage.generate(closestNumber(width, 16), closestNumber(height, 16))[0]
67
+
68
+ # Sample and decode the image
69
  sample, sample_denoised = SamplerCustomAdvanced.sample(noise, guider, sampler, sigmas, latent_image)
70
+ decoded = VAEDecode.decode(vae, sample)[0].detach()
71
+
72
+ # Convert to image and return
73
+ return Image.fromarray(np.array(decoded * 255, dtype=np.uint8)[0])
74
+
75
+ # Define Gradio interface
 
 
 
 
76
  with gr.Blocks(analytics_enabled=False) as demo:
77
  with gr.Row():
78
  with gr.Column():
79
+ positive_prompt = gr.Textbox(
80
+ lines=3,
81
+ interactive=True,
82
+ value="cute anime girl with massive fluffy fennec ears and a big fluffy tail blonde messy long hair blue eyes wearing a maid outfit with a long black dress with a gold leaf pattern and a white apron eating a slice of an apple pie in the kitchen of an old dark victorian mansion with a bright window and very expensive stuff everywhere",
83
+ label="Prompt"
84
+ )
85
  width = gr.Slider(minimum=256, maximum=2048, value=1024, step=16, label="width")
86
  height = gr.Slider(minimum=256, maximum=2048, value=1024, step=16, label="height")
87
  seed = gr.Slider(minimum=0, maximum=18446744073709551615, value=0, step=1, label="seed (0=random)")
 
89
  guidance = gr.Slider(minimum=0, maximum=20, value=3.5, step=0.5, label="guidance")
90
  lora_strength_model = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.1, label="lora_strength_model")
91
  lora_strength_clip = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.1, label="lora_strength_clip")
92
+ sampler_name = gr.Dropdown(
93
+ ["euler", "heun", "heunpp2", "dpm_2", "lms", "dpmpp_2m", "ipndm", "deis", "ddim", "uni_pc", "uni_pc_bh2"],
94
+ label="sampler_name",
95
+ value="euler"
96
+ )
97
+ scheduler = gr.Dropdown(
98
+ ["normal", "sgm_uniform", "simple", "ddim_uniform"],
99
+ label="scheduler",
100
+ value="simple"
101
+ )
102
  generate_button = gr.Button("Generate")
103
  with gr.Column():
104
  output_image = gr.Image(label="Generated image", interactive=False)
105
 
106
+ generate_button.click(
107
+ fn=generate,
108
+ inputs=[
109
+ positive_prompt, width, height, seed, steps,
110
+ sampler_name, scheduler, guidance,
111
+ lora_strength_model, lora_strength_clip
112
+ ],
113
+ outputs=output_image
114
+ )
115
 
 
116
  demo.queue().launch(inline=False, share=True, debug=True)