vcollos commited on
Commit
45467a5
·
verified ·
1 Parent(s): 58ff5c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -41
app.py CHANGED
@@ -8,32 +8,37 @@ import os
8
  import json
9
  from gradio_client import Client as client_gradio
10
  from supabase import create_client, Client
 
11
 
12
- # Initialize supabase
13
  url: str = os.getenv('SUPABASE_URL')
14
  key: str = os.getenv('SUPABASE_KEY')
15
  supabase: Client = create_client(url, key)
16
 
17
- # Get Hugging Face token from secrets
18
  hf_token = os.getenv("HF_TOKEN")
19
 
20
- # Initialize the base model
21
  base_model = "black-forest-labs/FLUX.1-dev"
22
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
23
 
24
- # Load the first LoRA model
25
  lora_repo_1 = "markury/AndroFlux"
26
- pipe.load_lora_weights(lora_repo_1, weight_name="AndroFlux-v19.safetensors")
27
-
28
- # Load the second LoRA model
29
  lora_repo_2 = "vcollos/VitorCollos"
30
- pipe.load_lora_weights(lora_repo_2, weight_name="Vitor.safetensors")
31
 
32
- # Combine LoRA weights manually
 
 
 
 
 
 
 
 
 
33
  def combine_lora_weights(pipe, weight_1, weight_2):
34
- for name, module in pipe.unet.named_modules():
35
- if "lora" in name:
36
- # Combine the weights of the two LoRA models
37
  module.weight.data = weight_1 * module.weight.data + weight_2 * module.weight.data
38
 
39
  pipe.to("cuda")
@@ -42,12 +47,12 @@ MAX_SEED = 2**32 - 1
42
 
43
  @spaces.GPU(duration=80)
44
  def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale_1, lora_scale_2, progress=gr.Progress(track_tqdm=True)):
45
- # Set random seed for reproducibility
46
  if randomize_seed:
47
  seed = random.randint(0, MAX_SEED)
48
  generator = torch.Generator(device="cuda").manual_seed(seed)
49
 
50
- # Moderation
51
  moderation_client = client_gradio("duchaba/Friendly_Text_Moderation")
52
  result = moderation_client.predict(
53
  msg=f"{prompt}",
@@ -56,20 +61,20 @@ def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora
56
  )
57
 
58
  if float(json.loads(result[1])['sexual_minors']) > 0.03:
59
- print('Minors')
60
  response_data = (supabase.table("requests")
61
  .insert({"prompt": prompt, "cfg_scale": cfg_scale, "steps": steps, "randomized_seed": randomize_seed, "seed": seed, "lora_scale_1": lora_scale_1, "lora_scale_2": lora_scale_2, "moderated": 'true'})
62
  .execute()
63
  )
64
- raise gr.Error("Unauthorized request 💥!")
65
 
66
- # Update progress bar (0% saat mulai)
67
- progress(0, "Starting image generation...")
68
 
69
- # Combine LoRA weights
70
  combine_lora_weights(pipe, lora_scale_1, lora_scale_2)
71
 
72
- # Generate image using the pipeline
73
  image = pipe(
74
  prompt=f"{prompt}",
75
  num_inference_steps=steps,
@@ -80,55 +85,55 @@ def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora
80
  max_sequence_length=512
81
  ).images[0]
82
 
83
- # Save the image to a file with a unique name in /tmp directory
84
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
85
  image_filename = f"generated_image_{timestamp}.png"
86
  image_path = os.path.join("/tmp/gradio", image_filename)
87
 
88
- # Add Metadata
89
- new_metadata_string = f"{prompt}\nNegative prompt: none\nSteps: {steps}, CFG scale: {cfg_scale}, Seed: {seed}, Lora hashes: AndroFlux-v19: c44afd41ece1, VitorCollos: <hash>"
90
  metadata = PngImagePlugin.PngInfo()
91
  metadata.add_text("parameters", new_metadata_string)
92
 
93
- # Save the tmp image
94
  image.save(image_path, pnginfo=metadata)
95
 
96
- # Log queries
97
  try:
98
  if "girl" not in prompt and "woman" not in prompt:
99
- # Save image in supabase
100
  response = supabase.storage.from_('generated_images').upload(image_filename, image_path, file_options={"content-type": "image/png;charset=UTF-8"})
101
  print(response.dict)
102
- # Log request in supabase
 
103
  response_data = (supabase.table("requests")
104
  .insert({"prompt": prompt, "cfg_scale": cfg_scale, "steps": steps, "randomized_seed": randomize_seed, "seed": seed, "lora_scale_1": lora_scale_1, "lora_scale_2": lora_scale_2, "image_url": response.full_path})
105
  .execute()
106
  )
107
-
108
  except Exception as error:
109
- # handle the exception
110
- print("An exception occurred:", error)
111
 
112
  yield image, seed
113
 
 
114
  gr_theme = os.getenv("THEME")
115
  with gr.Blocks(theme=gr_theme) as app:
116
  gr.Markdown("# Androflux Image Generator")
117
  with gr.Row():
118
  with gr.Column(scale=3):
119
- prompt = gr.TextArea(label="Prompt", placeholder="Type a prompt of max 77 characters", lines=3)
120
- generate_button = gr.Button("Generate")
121
- cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5) # Use um valor padrão
122
- steps = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=25) # Use um valor padrão
123
- width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=896) # Use um valor padrão
124
- height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1152) # Use um valor padrão
125
  randomize_seed = gr.Checkbox(False, label="Randomize seed")
126
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=556215326) # Use um valor padrão
127
- lora_scale_1 = gr.Slider(label="LoRA Scale (AndroFlux)", minimum=0, maximum=1, step=0.01, value=0.7) # Use um valor padrão
128
- lora_scale_2 = gr.Slider(label="LoRA Scale (VitorCollos)", minimum=0, maximum=1, step=0.01, value=1) # Use um valor padrão
129
  with gr.Column(scale=1):
130
  result = gr.Image(label="Generated Image")
131
- gr.Markdown("Generate images using Androflux Lora and a text prompt.\n[[non-commercial license, Flux.1 Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)]")
132
 
133
  generate_button.click(
134
  run_lora,
@@ -137,4 +142,4 @@ with gr.Blocks(theme=gr_theme) as app:
137
  )
138
 
139
  app.queue()
140
- app.launch(share=True) # Set `share=True` to create a public link
 
8
  import json
9
  from gradio_client import Client as client_gradio
10
  from supabase import create_client, Client
11
+ from datetime import datetime
12
 
13
+ # Inicializa supabase
14
  url: str = os.getenv('SUPABASE_URL')
15
  key: str = os.getenv('SUPABASE_KEY')
16
  supabase: Client = create_client(url, key)
17
 
18
+ # Obtém token da Hugging Face
19
  hf_token = os.getenv("HF_TOKEN")
20
 
21
+ # Inicializa o modelo base
22
  base_model = "black-forest-labs/FLUX.1-dev"
23
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
24
 
25
+ # Carrega os adaptadores LoRA
26
  lora_repo_1 = "markury/AndroFlux"
 
 
 
27
  lora_repo_2 = "vcollos/VitorCollos"
 
28
 
29
+ try:
30
+ pipe.load_lora_weights(lora_repo_1, weight_name="AndroFlux-v19.safetensors")
31
+ print("✅ Primeiro LoRA carregado")
32
+
33
+ pipe.load_lora_weights(lora_repo_2, weight_name="Vitor.safetensors")
34
+ print("✅ Segundo LoRA carregado")
35
+ except Exception as e:
36
+ print(f"❌ Erro ao carregar os LoRA adapters: {e}")
37
+
38
+ # Função para combinar os pesos dos LoRA
39
  def combine_lora_weights(pipe, weight_1, weight_2):
40
+ for name, module in pipe.named_modules(): # Percorre os módulos do pipeline
41
+ if hasattr(module, "weight") and module.weight is not None:
 
42
  module.weight.data = weight_1 * module.weight.data + weight_2 * module.weight.data
43
 
44
  pipe.to("cuda")
 
47
 
48
  @spaces.GPU(duration=80)
49
  def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale_1, lora_scale_2, progress=gr.Progress(track_tqdm=True)):
50
+ # Define uma seed aleatória se necessário
51
  if randomize_seed:
52
  seed = random.randint(0, MAX_SEED)
53
  generator = torch.Generator(device="cuda").manual_seed(seed)
54
 
55
+ # Moderação de texto
56
  moderation_client = client_gradio("duchaba/Friendly_Text_Moderation")
57
  result = moderation_client.predict(
58
  msg=f"{prompt}",
 
61
  )
62
 
63
  if float(json.loads(result[1])['sexual_minors']) > 0.03:
64
+ print('🔴 Conteúdo não permitido')
65
  response_data = (supabase.table("requests")
66
  .insert({"prompt": prompt, "cfg_scale": cfg_scale, "steps": steps, "randomized_seed": randomize_seed, "seed": seed, "lora_scale_1": lora_scale_1, "lora_scale_2": lora_scale_2, "moderated": 'true'})
67
  .execute()
68
  )
69
+ raise gr.Error("🚫 Requisição não autorizada!")
70
 
71
+ # Atualiza barra de progresso (0% no início)
72
+ progress(0, "Iniciando a geração de imagem...")
73
 
74
+ # Combina os LoRA weights corretamente
75
  combine_lora_weights(pipe, lora_scale_1, lora_scale_2)
76
 
77
+ # Gera imagem com o pipeline
78
  image = pipe(
79
  prompt=f"{prompt}",
80
  num_inference_steps=steps,
 
85
  max_sequence_length=512
86
  ).images[0]
87
 
88
+ # Salva a imagem em um diretório temporário
89
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
90
  image_filename = f"generated_image_{timestamp}.png"
91
  image_path = os.path.join("/tmp/gradio", image_filename)
92
 
93
+ # Adiciona metadados à imagem
94
+ new_metadata_string = f"{prompt}\nNegative prompt: none\nSteps: {steps}, CFG scale: {cfg_scale}, Seed: {seed}, Lora hashes: AndroFlux-v19, VitorCollos"
95
  metadata = PngImagePlugin.PngInfo()
96
  metadata.add_text("parameters", new_metadata_string)
97
 
98
+ # Salva a imagem gerada
99
  image.save(image_path, pnginfo=metadata)
100
 
101
+ # Registra a imagem no Supabase
102
  try:
103
  if "girl" not in prompt and "woman" not in prompt:
104
+ # Salva a imagem no Supabase Storage
105
  response = supabase.storage.from_('generated_images').upload(image_filename, image_path, file_options={"content-type": "image/png;charset=UTF-8"})
106
  print(response.dict)
107
+
108
+ # Registra a requisição no Supabase
109
  response_data = (supabase.table("requests")
110
  .insert({"prompt": prompt, "cfg_scale": cfg_scale, "steps": steps, "randomized_seed": randomize_seed, "seed": seed, "lora_scale_1": lora_scale_1, "lora_scale_2": lora_scale_2, "image_url": response.full_path})
111
  .execute()
112
  )
 
113
  except Exception as error:
114
+ print("⚠️ Erro ao salvar no Supabase:", error)
 
115
 
116
  yield image, seed
117
 
118
+ # Interface Gradio
119
  gr_theme = os.getenv("THEME")
120
  with gr.Blocks(theme=gr_theme) as app:
121
  gr.Markdown("# Androflux Image Generator")
122
  with gr.Row():
123
  with gr.Column(scale=3):
124
+ prompt = gr.TextArea(label="Prompt", placeholder="Digite um prompt (máx 77 caracteres)", lines=3)
125
+ generate_button = gr.Button("Gerar")
126
+ cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
127
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=25)
128
+ width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=896)
129
+ height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1152)
130
  randomize_seed = gr.Checkbox(False, label="Randomize seed")
131
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=556215326)
132
+ lora_scale_1 = gr.Slider(label="LoRA Scale (AndroFlux)", minimum=0, maximum=1, step=0.01, value=0.7)
133
+ lora_scale_2 = gr.Slider(label="LoRA Scale (VitorCollos)", minimum=0, maximum=1, step=0.01, value=1)
134
  with gr.Column(scale=1):
135
  result = gr.Image(label="Generated Image")
136
+ gr.Markdown("Gere imagens usando Androflux LoRA e um prompt de texto.\n[[Licença não comercial, Flux.1 Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)]")
137
 
138
  generate_button.click(
139
  run_lora,
 
142
  )
143
 
144
  app.queue()
145
+ app.launch(share=True) # `share=True` cria um link público