Leteint commited on
Commit
80dc654
·
verified ·
1 Parent(s): 0996984

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -30
app.py CHANGED
@@ -6,7 +6,7 @@ from huggingface_hub import hf_hub_download, login
6
  import random
7
  import os
8
 
9
- # Authentification (gardez ça)
10
  hf_token = os.getenv("HF_TOKEN")
11
  print(f"Token trouvé : {bool(hf_token)}")
12
  if hf_token:
@@ -23,40 +23,121 @@ def load_lora(repo_id, style):
23
  filename = f"{style}_lora.safetensors"
24
  lora_path = hf_hub_download(repo_id=repo_id, filename=filename)
25
  lora_repo = repo_id
26
- return f"✅ LoRA: {repo_id} ({filename})"
27
  except Exception as e:
28
- return f"❌ {e}"
29
 
30
  @spaces.GPU(duration=120)
31
- def generate(prompt, negative_prompt, width=1024, height=1024, steps=4, seed=-1, lora_scale=0.8):
32
- pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
33
- pipe.to("cuda")
34
-
35
- if lora_repo and lora_path:
36
- pipe.load_lora_weights(lora_path)
37
- pipe.fuse_lora(lora_scale=lora_scale) # Après to("cuda")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- pipe.enable_model_cpu_offload() # Optim VRAM
 
 
 
 
 
 
40
 
41
- generator = torch.Generator("cuda").manual_seed(seed if seed != -1 else random.randint(0, 2**32))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- image = pipe(
44
- prompt,
45
- negative_prompt=negative_prompt,
46
- height=height, width=width,
47
- num_inference_steps=steps,
48
- guidance_scale=0.0,
49
- generator=generator,
50
- max_sequence_length=256 # Pour FLUX
51
- ).images[0]
52
 
53
- del pipe # Nettoyage
54
- torch.cuda.empty_cache()
55
- return image
56
-
57
- # Interface Gradio (inchangée)
58
- with gr.Blocks(title="Flux Schnell + LoRA") as demo:
59
- # ... (votre interface reste identique)
60
- pass
61
 
62
- demo.launch()
 
 
 
 
 
 
 
 
 
6
  import random
7
  import os
8
 
9
+ # Authentification
10
  hf_token = os.getenv("HF_TOKEN")
11
  print(f"Token trouvé : {bool(hf_token)}")
12
  if hf_token:
 
23
  filename = f"{style}_lora.safetensors"
24
  lora_path = hf_hub_download(repo_id=repo_id, filename=filename)
25
  lora_repo = repo_id
26
+ return f"✅ LoRA chargé : {repo_id} ({filename})"
27
  except Exception as e:
28
+ return f"❌ Erreur : {str(e)}"
29
 
30
  @spaces.GPU(duration=120)
31
+ def generate(prompt, negative_prompt="", width=1024, height=1024, steps=4, seed=-1, lora_scale=0.8):
32
+ try:
33
+ # Initialisation du pipeline
34
+ pipe = FluxPipeline.from_pretrained(
35
+ model_id,
36
+ torch_dtype=torch.bfloat16
37
+ )
38
+ pipe.to("cuda")
39
+
40
+ # Chargement du LoRA si disponible
41
+ if lora_repo and lora_path:
42
+ pipe.load_lora_weights(lora_path)
43
+ pipe.fuse_lora(lora_scale=lora_scale)
44
+
45
+ # Optimisation mémoire
46
+ pipe.enable_model_cpu_offload()
47
+
48
+ # Génération du seed
49
+ if seed == -1:
50
+ seed = random.randint(0, 2**32 - 1)
51
+ generator = torch.Generator("cuda").manual_seed(seed)
52
+
53
+ # Génération de l'image
54
+ image = pipe(
55
+ prompt=prompt,
56
+ negative_prompt=negative_prompt if negative_prompt else None,
57
+ height=height,
58
+ width=width,
59
+ num_inference_steps=steps,
60
+ guidance_scale=0.0,
61
+ generator=generator,
62
+ max_sequence_length=256
63
+ ).images[0]
64
+
65
+ # Nettoyage de la mémoire
66
+ del pipe
67
+ torch.cuda.empty_cache()
68
+
69
+ return image
70
 
71
+ except Exception as e:
72
+ print(f"❌ Erreur de génération : {str(e)}")
73
+ raise gr.Error(f"Erreur : {str(e)}")
74
+
75
+ # Interface Gradio
76
+ with gr.Blocks(title="Flux Schnell + LoRA") as demo:
77
+ gr.Markdown("# 🎨 Générateur Flux Schnell + LoRA")
78
 
79
+ with gr.Row():
80
+ with gr.Column():
81
+ # LoRA
82
+ gr.Markdown("### Charger un LoRA (optionnel)")
83
+ lora_repo_input = gr.Textbox(
84
+ label="Repository LoRA",
85
+ placeholder="username/repo-name"
86
+ )
87
+ lora_style = gr.Textbox(
88
+ label="Style LoRA",
89
+ placeholder="style_name",
90
+ value="style"
91
+ )
92
+ load_btn = gr.Button("Charger LoRA")
93
+ lora_status = gr.Textbox(label="Status", interactive=False)
94
+
95
+ # Paramètres
96
+ gr.Markdown("### Paramètres de génération")
97
+ prompt = gr.Textbox(
98
+ label="Prompt",
99
+ placeholder="Décrivez votre image...",
100
+ lines=3
101
+ )
102
+ negative_prompt = gr.Textbox(
103
+ label="Negative Prompt (optionnel)",
104
+ placeholder="Ce que vous ne voulez pas...",
105
+ lines=2
106
+ )
107
+
108
+ with gr.Row():
109
+ width = gr.Slider(512, 2048, 1024, step=64, label="Largeur")
110
+ height = gr.Slider(512, 2048, 1024, step=64, label="Hauteur")
111
+
112
+ with gr.Row():
113
+ steps = gr.Slider(1, 10, 4, step=1, label="Steps")
114
+ seed = gr.Number(label="Seed (-1 = aléatoire)", value=-1)
115
+ lora_scale = gr.Slider(0, 1, 0.8, step=0.1, label="LoRA Scale")
116
+
117
+ generate_btn = gr.Button("🚀 Générer", variant="primary")
118
+
119
+ with gr.Column():
120
+ output_image = gr.Image(label="Image générée", type="pil")
121
 
122
+ # Actions
123
+ load_btn.click(
124
+ fn=load_lora,
125
+ inputs=[lora_repo_input, lora_style],
126
+ outputs=lora_status
127
+ )
 
 
 
128
 
129
+ generate_btn.click(
130
+ fn=generate,
131
+ inputs=[prompt, negative_prompt, width, height, steps, seed, lora_scale],
132
+ outputs=output_image
133
+ )
 
 
 
134
 
135
+ # Lancement avec gestion propre de l'event loop
136
+ if __name__ == "__main__":
137
+ demo.queue() # Active la file d'attente
138
+ demo.launch(
139
+ server_name="0.0.0.0",
140
+ server_port=7860,
141
+ show_error=True,
142
+ quiet=False
143
+ )