cheeseman182 committed on
Commit
6c16202
·
verified ·
1 Parent(s): 2dac743

updated main code to dynamically detect device

Browse files
Files changed (1) hide show
  1. media.py +151 -151
media.py CHANGED
@@ -1,152 +1,152 @@
1
- # --- LIBRARIES ---
2
- import torch
3
- import gradio as gr
4
- import random
5
- import time
6
- from diffusers import AutoPipelineForText2Image, TextToVideoSDPipeline
7
- import gc
8
- import os
9
- import imageio
10
-
11
- # --- AUTHENTICATION FOR HUGGING FACE SPACES ---
12
- # This will read the token from a "Secret" you set in your Space's settings
13
- # It's more secure and the correct way to do it on HF Spaces.
14
- try:
15
- from huggingface_hub import login
16
- HF_TOKEN = os.environ.get('HF_TOKEN')
17
- if HF_TOKEN:
18
- login(token=HF_TOKEN)
19
- print("✅ Hugging Face Authentication successful.")
20
- else:
21
- print("⚠️ Hugging Face token not found in Space Secrets. Gated models may not be available.")
22
- except ImportError:
23
- print("Could not import huggingface_hub. Please ensure it's in requirements.txt")
24
-
25
- # --- CONFIGURATION & STATE ---
26
- available_models = {
27
- "Fast Image (SDXL Turbo)": "stabilityai/sdxl-turbo",
28
- "Quality Image (SDXL)": "stabilityai/stable-diffusion-xl-base-1.0",
29
- "Video (Zeroscope)": "cerspense/zeroscope-v2-576w"
30
- }
31
- model_state = { "current_pipe": None, "loaded_model_name": None }
32
-
33
-
34
- # --- CORE GENERATION FUNCTION ---
35
- # This is a generator function, which yields updates to the UI.
36
- def generate_media(model_key, prompt, negative_prompt, steps, cfg_scale, width, height, seed, num_frames):
37
- # --- Model Loading Logic ---
38
- # If the requested model isn't the one we have loaded, switch them.
39
- if model_state.get("loaded_model_name") != model_key:
40
- print(f"Switching to {model_key}. Unloading previous model...")
41
- yield {status_textbox: f"Unloading previous model..."} # UI Update
42
- if model_state.get("current_pipe"):
43
- del model_state["current_pipe"]
44
- gc.collect()
45
- torch.cuda.empty_cache()
46
-
47
- model_id = available_models[model_key]
48
- print(f"Loading {model_id}...")
49
- yield {status_textbox: f"Loading {model_id}... This can take a minute."} # UI Update
50
-
51
- # Load the correct pipeline based on model type
52
- if "Image" in model_key:
53
- pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16")
54
- elif "Video" in model_key:
55
- pipe = TextToVideoSDPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
56
-
57
- pipe.to("cuda")
58
- # Offload larger models to save VRAM, but keep fast models fully on GPU
59
- if "Turbo" not in model_key and "Video" not in model_key:
60
- pipe.enable_model_cpu_offload()
61
-
62
- model_state["current_pipe"] = pipe
63
- model_state["loaded_model_name"] = model_key
64
- print("✅ Model loaded successfully.")
65
-
66
- pipe = model_state["current_pipe"]
67
- generator = torch.Generator("cuda").manual_seed(seed)
68
- yield {status_textbox: f"Generating with {model_key}..."} # UI Update
69
-
70
- # --- Generation Logic ---
71
- if "Image" in model_key:
72
- print("Generating image...")
73
- if "Turbo" in model_key: # Special settings for SDXL Turbo
74
- num_steps, guidance_scale = 1, 0.0
75
- else:
76
- num_steps, guidance_scale = int(steps), float(cfg_scale)
77
-
78
- image = pipe(
79
- prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_steps,
80
- guidance_scale=guidance_scale, width=int(width), height=int(height), generator=generator
81
- ).images[0]
82
- print(" Image generation complete.")
83
- yield {output_image: image, output_video: None, status_textbox: f"Seed used: {seed}"}
84
-
85
- elif "Video" in model_key:
86
- print("Generating video...")
87
- video_frames = pipe(prompt=prompt, num_inference_steps=int(steps), height=320, width=576, num_frames=int(num_frames), generator=generator).frames
88
-
89
- video_path = f"/tmp/video_{seed}.mp4"
90
- imageio.mimsave(video_path, video_frames, fps=12)
91
- print(f"✅ Video saved to {video_path}")
92
- yield {output_image: None, output_video: video_path, status_textbox: f"Seed used: {seed}"}
93
-
94
-
95
- # --- GRADIO USER INTERFACE ---
96
- with gr.Blocks(theme='gradio/soft') as demo:
97
- gr.Markdown("# The Generative Media Suite")
98
- gr.Markdown("Create fast images, high-quality images, or short videos. Created by cheeseman182.")
99
- seed_state = gr.State(-1)
100
-
101
- with gr.Row():
102
- with gr.Column(scale=2):
103
- model_selector = gr.Radio(label="Select Model", choices=list(available_models.keys()), value=list(available_models.keys())[0])
104
- prompt_input = gr.Textbox(label="Prompt", lines=4, placeholder="An astronaut riding a horse on Mars, cinematic...")
105
- negative_prompt_input = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, deformed, watermark, text")
106
-
107
- with gr.Accordion("Settings", open=True):
108
- steps_slider = gr.Slider(1, 100, 30, step=1, label="Inference Steps")
109
- cfg_slider = gr.Slider(0.0, 15.0, 7.5, step=0.5, label="Guidance Scale (CFG)")
110
- with gr.Row():
111
- width_slider = gr.Slider(256, 1024, 768, step=64, label="Width")
112
- height_slider = gr.Slider(256, 1024, 768, step=64, label="Height")
113
- num_frames_slider = gr.Slider(12, 48, 24, step=4, label="Video Frames", visible=False)
114
- seed_input = gr.Number(-1, label="Seed (-1 for random)")
115
-
116
- generate_button = gr.Button("Generate", variant="primary")
117
-
118
- with gr.Column(scale=3):
119
- output_image = gr.Image(label="Image Result", interactive=False, height="60vh", visible=True)
120
- output_video = gr.Video(label="Video Result", interactive=False, height="60vh", visible=False)
121
- status_textbox = gr.Textbox(label="Status", interactive=False)
122
-
123
- # --- UI Logic ---
124
- def update_ui_on_model_change(model_key):
125
- is_video = "Video" in model_key
126
- is_turbo = "Turbo" in model_key
127
- return {
128
- steps_slider: gr.update(interactive=not is_turbo, value=1 if is_turbo else 30),
129
- cfg_slider: gr.update(interactive=not is_turbo, value=0.0 if is_turbo else 7.5),
130
- width_slider: gr.update(visible=not is_video),
131
- height_slider: gr.update(visible=not is_video),
132
- num_frames_slider: gr.update(visible=is_video),
133
- output_image: gr.update(visible=not is_video),
134
- output_video: gr.update(visible=is_video)
135
- }
136
- model_selector.change(update_ui_on_model_change, model_selector, [steps_slider, cfg_slider, width_slider, height_slider, num_frames_slider, output_image, output_video])
137
-
138
- # --- Button Logic ---
139
- # This chain first sets the seed, then calls the main generation function.
140
- click_event = generate_button.click(
141
- fn=lambda s: (s if s != -1 else random.randint(0, 2**32 - 1)),
142
- inputs=seed_input,
143
- outputs=seed_state,
144
- queue=False
145
- ).then(
146
- fn=generate_media,
147
- inputs=[model_selector, prompt_input, negative_prompt_input, steps_slider, cfg_slider, width_slider, height_slider, seed_state, num_frames_slider],
148
- outputs=[output_image, output_video, status_textbox]
149
- )
150
-
151
- # This is the correct way to launch on Hugging Face Spaces
152
  demo.launch()
 
1
+ # --- LIBRARIES ---
2
+ import torch
3
+ import gradio as gr
4
+ import random
5
+ import time
6
+ from diffusers import AutoPipelineForText2Image, TextToVideoSDPipeline
7
+ import gc
8
+ import os
9
+ import imageio
10
+
11
# --- DYNAMIC HARDWARE DETECTION ---
# Select CUDA when a GPU is present (half precision), otherwise fall back
# to the CPU with full float32 precision.
has_cuda = torch.cuda.is_available()
device = "cuda" if has_cuda else "cpu"
torch_dtype = torch.float16 if has_cuda else torch.float32
if has_cuda:
    print("✅ GPU detected. Using CUDA.")
else:
    print("⚠️ No GPU detected. Using CPU. Performance will be slower.")
21
+
22
+
23
# --- AUTHENTICATION FOR HUGGING FACE SPACES ---
# Reads the token from a Space "Secret" (the HF_TOKEN environment variable).
# This is the secure, recommended way to authenticate on HF Spaces; gated
# models stay unavailable if the secret is missing, but the app still runs.
try:
    from huggingface_hub import login

    HF_TOKEN = os.environ.get('HF_TOKEN')
    if HF_TOKEN:
        login(token=HF_TOKEN)
        # FIX: restore the "✅ " prefix that was dropped from this message,
        # keeping it consistent with the other status prints in this file.
        print("✅ Hugging Face Authentication successful.")
    else:
        print("⚠️ Hugging Face token not found in Space Secrets. Gated models may not be available.")
except ImportError:
    print("Could not import huggingface_hub. Please ensure it's in requirements.txt")
34
+
35
# --- CONFIGURATION & STATE ---
# UI label -> Hugging Face model id for every pipeline the app can serve.
available_models = {
    "Fast Image (SDXL Turbo)": "stabilityai/sdxl-turbo",
    "Quality Image (SDXL)": "stabilityai/stable-diffusion-xl-base-1.0",
    "Video (Zeroscope)": "cerspense/zeroscope-v2-576w",
}

# Holds the currently loaded pipeline so consecutive generations with the
# same model skip the expensive reload.
model_state = {
    "current_pipe": None,
    "loaded_model_name": None,
}
42
+
43
+
44
# --- CORE GENERATION FUNCTION ---
def generate_media(model_key, prompt, negative_prompt, steps, cfg_scale, width, height, seed, num_frames):
    """Load the requested pipeline (switching models if needed) and generate.

    Generator function: yields dicts keyed by the module-level Gradio
    components (status_textbox, output_image, output_video) so the UI can
    stream progress updates while the model loads and generates.
    """
    # --- Model loading / switching ---
    if model_state.get("loaded_model_name") != model_key:
        print(f"Switching to {model_key}. Unloading previous model...")
        yield {status_textbox: "Unloading previous model..."}
        if model_state.get("current_pipe"):
            del model_state["current_pipe"]
            gc.collect()
            if device == "cuda":
                torch.cuda.empty_cache()

        model_id = available_models[model_key]
        print(f"Loading {model_id}...")
        yield {status_textbox: f"Loading {model_id}... This can take a minute."}

        if "Image" in model_key:
            # FIX: the previous code passed variant="fp32" on CPU, but SDXL
            # repos only publish the default weights plus an "fp16" variant,
            # so from_pretrained(variant="fp32") raises. Only request a
            # variant when running on CUDA.
            if device == "cuda":
                pipe = AutoPipelineForText2Image.from_pretrained(
                    model_id, torch_dtype=torch_dtype, variant="fp16"
                )
            else:
                pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch_dtype)
        elif "Video" in model_key:
            pipe = TextToVideoSDPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)

        # CPU offloading only makes sense on a GPU setup. FIX: do not call
        # pipe.to("cuda") before enable_model_cpu_offload() — offloading
        # manages device placement itself, and moving the whole model onto
        # the GPU first defeats the VRAM saving.
        if device == "cuda" and "Turbo" not in model_key and "Video" not in model_key:
            pipe.enable_model_cpu_offload()
        else:
            pipe.to(device)

        model_state["current_pipe"] = pipe
        model_state["loaded_model_name"] = model_key
        print(f"✅ Model loaded successfully on {device.upper()}.")

    pipe = model_state["current_pipe"]
    # FIX: gr.Number/gr.State deliver floats; manual_seed requires an int.
    generator = torch.Generator(device).manual_seed(int(seed))
    yield {status_textbox: f"Generating with {model_key} on {device.upper()}..."}

    # --- Generation ---
    if "Image" in model_key:
        print("Generating image...")
        if "Turbo" in model_key:
            # SDXL Turbo is distilled for single-step, guidance-free sampling.
            num_steps, guidance_scale = 1, 0.0
        else:
            num_steps, guidance_scale = int(steps), float(cfg_scale)

        image = pipe(
            prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_steps,
            guidance_scale=guidance_scale, width=int(width), height=int(height), generator=generator
        ).images[0]
        print("✅ Image generation complete.")
        yield {output_image: image, output_video: None, status_textbox: f"Seed used: {seed}"}

    elif "Video" in model_key:
        print("Generating video...")
        video_frames = pipe(
            prompt=prompt, num_inference_steps=int(steps), height=320, width=576,
            num_frames=int(num_frames), generator=generator
        ).frames
        # FIX: newer diffusers return frames batched as (batch, frames, H, W, C);
        # older versions return a flat list of (H, W, C) frames. Normalize to a
        # flat sequence so imageio writes one image per video frame.
        if len(video_frames) and getattr(video_frames[0], "ndim", 3) == 4:
            video_frames = video_frames[0]

        video_path = f"/tmp/video_{seed}.mp4"
        imageio.mimsave(video_path, video_frames, fps=12)
        print(f"✅ Video saved to {video_path}")
        yield {output_image: None, output_video: video_path, status_textbox: f"Seed used: {seed}"}
102
+
103
+
104
# --- GRADIO USER INTERFACE ---
with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown("# The Generative Media Suite")
    gr.Markdown("Create fast images, high-quality images, or short videos. Created by cheeseman182.")
    # Holds the resolved seed so generation sees the concrete value even
    # when the user asked for a random one (-1).
    seed_state = gr.State(-1)

    with gr.Row():
        with gr.Column(scale=2):
            model_selector = gr.Radio(label="Select Model", choices=list(available_models.keys()), value=list(available_models.keys())[0])
            prompt_input = gr.Textbox(label="Prompt", lines=4, placeholder="An astronaut riding a horse on Mars, cinematic...")
            negative_prompt_input = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, deformed, watermark, text")

            with gr.Accordion("Settings", open=True):
                steps_slider = gr.Slider(1, 100, 30, step=1, label="Inference Steps")
                cfg_slider = gr.Slider(0.0, 15.0, 7.5, step=0.5, label="Guidance Scale (CFG)")
                with gr.Row():
                    width_slider = gr.Slider(256, 1024, 768, step=64, label="Width")
                    height_slider = gr.Slider(256, 1024, 768, step=64, label="Height")
                num_frames_slider = gr.Slider(12, 48, 24, step=4, label="Video Frames", visible=False)
                seed_input = gr.Number(-1, label="Seed (-1 for random)")

            generate_button = gr.Button("Generate", variant="primary")

        with gr.Column(scale=3):
            output_image = gr.Image(label="Image Result", interactive=False, height="60vh", visible=True)
            output_video = gr.Video(label="Video Result", interactive=False, height="60vh", visible=False)
            status_textbox = gr.Textbox(label="Status", interactive=False)

    # --- UI Logic ---
    def update_ui_on_model_change(model_key):
        """Retune and show/hide controls when the selected model changes."""
        is_video = "Video" in model_key
        is_turbo = "Turbo" in model_key
        return {
            steps_slider: gr.update(interactive=not is_turbo, value=1 if is_turbo else 30),
            cfg_slider: gr.update(interactive=not is_turbo, value=0.0 if is_turbo else 7.5),
            width_slider: gr.update(visible=not is_video),
            height_slider: gr.update(visible=not is_video),
            num_frames_slider: gr.update(visible=is_video),
            output_image: gr.update(visible=not is_video),
            output_video: gr.update(visible=is_video)
        }
    model_selector.change(update_ui_on_model_change, model_selector, [steps_slider, cfg_slider, width_slider, height_slider, num_frames_slider, output_image, output_video])

    # --- Button Logic ---
    # The click chain first resolves the seed, then calls generate_media.
    # FIX: cast the user-supplied seed to int — gr.Number delivers floats,
    # and torch.Generator.manual_seed rejects them downstream. Also dropped
    # the unused `click_event` binding.
    generate_button.click(
        fn=lambda s: (int(s) if s != -1 else random.randint(0, 2**32 - 1)),
        inputs=seed_input,
        outputs=seed_state,
        queue=False
    ).then(
        fn=generate_media,
        inputs=[model_selector, prompt_input, negative_prompt_input, steps_slider, cfg_slider, width_slider, height_slider, seed_state, num_frames_slider],
        outputs=[output_image, output_video, status_textbox]
    )

# This is the correct way to launch on Hugging Face Spaces
demo.launch()