padufour commited on
Commit
eb2afd5
·
verified ·
1 Parent(s): f9f834c

Upload generate.py

Browse files
Files changed (1) hide show
  1. generate.py +51 -0
generate.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from diffusers import StableDiffusionPipeline
4
+
5
+ # Chemin vers ton modèle .safetensors
6
+ model_path = "/Users/arthurdufour/Documents/ComfyUI/models/checkpoints/v1-5-pruned-emaonly.safetensors"
7
+
8
+ # Charger le modèle directement (évite load_state_dict)
9
+ pipeline = StableDiffusionPipeline.from_single_file(model_path, torch_dtype=torch.float32)
10
+
11
+ # Vérification du backend MPS pour MacBook M3
12
+ device = "mps" if torch.backends.mps.is_available() else "cpu"
13
+ pipeline.to(device)
14
+
15
+ def generate_image(positive_prompt, negative_prompt, steps, seed):
16
+ torch.mps.empty_cache() # Nettoyage mémoire
17
+ generator = torch.manual_seed(int(seed))
18
+
19
+ try:
20
+ image = pipeline(
21
+ prompt=positive_prompt,
22
+ negative_prompt=negative_prompt if "negative_prompt" in pipeline.__call__.__code__.co_varnames else None,
23
+ num_inference_steps=int(steps),
24
+ width=512,
25
+ height=512,
26
+ generator=generator
27
+ ).images[0]
28
+ except Exception as e:
29
+ return f"Erreur : {str(e)}"
30
+
31
+ return image
32
+
33
+ # Interface Gradio
34
+ with gr.Blocks() as demo:
35
+ gr.Markdown("## Génération d'images Stable Diffusion (MPS)")
36
+
37
+ with gr.Row():
38
+ prompt_input = gr.Textbox(label="Prompt Positif", value="a horse")
39
+ negative_input = gr.Textbox(label="Prompt Négatif", value="text, watermark")
40
+
41
+ with gr.Row():
42
+ steps_slider = gr.Slider(1, 50, 20, step=1, label="Nombre de Steps")
43
+ seed_input = gr.Number(value=580029479038533, label="Seed")
44
+
45
+ output_image = gr.Image(label="Image Générée")
46
+
47
+ generate_button = gr.Button("Générer")
48
+ generate_button.click(generate_image, inputs=[prompt_input, negative_input, steps_slider, seed_input], outputs=output_image)
49
+
50
+ # Lancer l'interface
51
+ demo.launch()