cheeseman182 committed on
Commit
6cc0b7a
·
verified ·
1 Parent(s): 9b3af1f

Update media.py

Browse files
Files changed (1) hide show
  1. media.py +46 -50
media.py CHANGED
@@ -1,5 +1,3 @@
1
- # --- START OF FILE media.py (FINAL WITH LIVE PROGRESS) ---
2
-
3
  # --- LIBRARIES ---
4
  import torch
5
  import gradio as gr
@@ -14,24 +12,21 @@ import threading
14
  from queue import Queue, Empty as QueueEmpty
15
  from PIL import Image
16
 
17
- # --- SECURE AUTHENTICATION FOR HUGGING FACE SPACES ---
18
- import os
19
- from huggingface_hub import login
 
 
 
 
 
 
20
 
21
- # This code will attempt to read the HF_TOKEN from the Space's secrets.
22
- # On your local machine, this will do nothing unless you set it up, which isn't necessary.
23
- # On the Hugging Face server, it will find the secret you just saved.
24
- HF_TOKEN = os.environ.get('HF_TOKEN')
25
 
26
- if HF_TOKEN:
27
- print("✅ Found HF_TOKEN secret. Logging in...")
28
- try:
29
- login(token=HF_TOKEN)
30
- print("✅ Hugging Face Authentication successful.")
31
- except Exception as e:
32
- print(f"❌ Hugging Face login failed: {e}")
33
- else:
34
- print("⚠️ No HF_TOKEN secret found. Gated models may not be available on the deployed app.")
35
 
36
  # --- CONFIGURATION & STATE ---
37
  available_models = {
@@ -42,13 +37,19 @@ available_models = {
42
  }
43
  model_state = { "current_pipe": None, "loaded_model_name": None }
44
 
45
- # --- THE FINAL GENERATION FUNCTION WITH LIVE PROGRESS ---
46
- def generate_media_live_progress(model_key, prompt, negative_prompt, steps, cfg_scale, width, height, seed, num_frames):
47
- # --- Model Loading (Unchanged) ---
 
 
48
  if model_state.get("loaded_model_name") != model_key:
49
  yield {output_image: None, output_video: None, status_textbox: f"Loading {model_key}..."}
50
  if model_state.get("current_pipe"):
51
- del model_state["current_pipe"]; gc.collect(); torch.cuda.empty_cache()
 
 
 
 
52
  model_id = available_models[model_key]
53
  if "Video" in model_key:
54
  pipe = TextToVideoSDPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
@@ -70,8 +71,8 @@ def generate_media_live_progress(model_key, prompt, negative_prompt, steps, cfg_
70
 
71
  # --- Generation Logic ---
72
  if "Video" in model_key:
73
- # For video, we'll keep the simple status updates for now
74
  yield {output_image: None, output_video: None, status_textbox: "Generating video..."}
 
75
  video_frames = pipe(prompt=prompt, num_inference_steps=int(steps), height=320, width=576, num_frames=int(num_frames), generator=generator).frames
76
  video_frames_5d = np.array(video_frames)
77
  video_frames_4d = np.squeeze(video_frames_5d)
@@ -81,77 +82,72 @@ def generate_media_live_progress(model_key, prompt, negative_prompt, steps, cfg_
81
  imageio.mimsave(video_path, list_of_frames, fps=12)
82
  yield {output_image: None, output_video: video_path, status_textbox: f"Video saved! Seed: {seed}"}
83
 
84
- else: # Image Generation with Live Progress
85
  progress_queue = Queue()
86
 
87
  def run_pipe():
88
- # This function runs in a separate thread
89
  start_time = time.time()
90
 
91
- def progress_callback(pipe, step, timestep, callback_kwargs):
92
- # This is called by the pipeline at each step
93
  elapsed_time = time.time() - start_time
94
- # Avoid division by zero on the first step
95
  if elapsed_time > 0:
96
  its_per_sec = (step + 1) / elapsed_time
97
- progress_queue.put((step + 1, its_per_sec))
98
- return callback_kwargs
99
 
100
  try:
101
- # The final image is still generated using the pipeline's high-quality VAE
102
  final_image = pipe(
103
  prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=int(steps),
104
  guidance_scale=float(cfg_scale), width=int(width), height=int(height),
105
  generator=generator,
106
  callback_on_step_end=progress_callback
107
  ).images[0]
108
- progress_queue.put(final_image) # Put the final result on the queue
109
  except Exception as e:
110
  print(f"An error occurred in the generation thread: {e}")
111
- progress_queue.put(None) # Signal an error
112
 
113
- # Start the generation in the background
114
  thread = threading.Thread(target=run_pipe)
115
  thread.start()
116
 
117
- # In the main thread, listen for updates from the queue and yield to Gradio
118
  total_steps = int(steps)
119
- yield {status_textbox: "Generating..."} # Initial status
 
120
 
121
  while True:
122
  try:
123
- update = progress_queue.get(timeout=1.0) # Wait for an update
124
 
125
- if isinstance(update, Image.Image): # It's the final image
126
- yield {output_image: update, status_textbox: f"Generation complete! Seed: {seed}"}
 
127
  break
128
- elif isinstance(update, tuple): # It's a progress update (step, speed)
129
- current_step, its_per_sec = update
130
  progress_percent = (current_step / total_steps) * 100
131
  steps_remaining = total_steps - current_step
132
  eta_seconds = steps_remaining / its_per_sec if its_per_sec > 0 else 0
133
  eta_minutes, eta_seconds_rem = divmod(int(eta_seconds), 60)
134
-
135
  status_text = (
136
  f"Generating... {progress_percent:.0f}% ({current_step}/{total_steps}) | "
137
  f"{its_per_sec:.2f}it/s | "
138
  f"ETA: {eta_minutes:02d}:{eta_seconds_rem:02d}"
139
  )
140
  yield {status_textbox: status_text}
141
- elif update is None: # An error occurred
142
- yield {status_textbox: "Error during generation. Check console."}
143
  break
144
  except QueueEmpty:
145
  if not thread.is_alive():
146
- print("⚠️ Generation thread finished unexpectedly.")
147
  yield {status_textbox: "Generation failed. Check console for details."}
148
  break
149
 
150
  thread.join()
151
 
152
- # --- GRADIO UI ---
153
  with gr.Blocks(theme='gradio/soft') as demo:
154
- # (UI layout is the same, just point to the new function)
155
  gr.Markdown("# The Generative Media Suite")
156
  gr.Markdown("Create fast images, high-quality images, or short videos. Created by cheeseman182. (note: the speed on the status bar is wrong)")
157
  seed_state = gr.State(-1)
@@ -159,7 +155,7 @@ with gr.Blocks(theme='gradio/soft') as demo:
159
  with gr.Column(scale=2):
160
  model_selector = gr.Radio(label="Select Model", choices=list(available_models.keys()), value=list(available_models.keys())[0])
161
  prompt_input = gr.Textbox(label="Prompt", lines=4, placeholder="An astronaut riding a horse on Mars, cinematic...")
162
- negative_prompt_input = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, deformed, watermark, text")
163
  with gr.Accordion("Settings", open=True):
164
  steps_slider = gr.Slider(1, 100, 30, step=1, label="Inference Steps")
165
  cfg_slider = gr.Slider(0.0, 15.0, 7.5, step=0.5, label="Guidance Scale (CFG)")
@@ -194,9 +190,9 @@ with gr.Blocks(theme='gradio/soft') as demo:
194
  outputs=seed_state,
195
  queue=False
196
  ).then(
197
- fn=generate_media_live_progress, # Use the new function with progress
198
  inputs=[model_selector, prompt_input, negative_prompt_input, steps_slider, cfg_slider, width_slider, height_slider, seed_state, num_frames_slider],
199
  outputs=[output_image, output_video, status_textbox]
200
  )
201
 
202
- demo.launch()
 
 
 
1
  # --- LIBRARIES ---
2
  import torch
3
  import gradio as gr
 
12
  from queue import Queue, Empty as QueueEmpty
13
  from PIL import Image
14
 
15
+ # --- DYNAMIC HARDWARE DETECTION & AUTH ---
16
+ if torch.cuda.is_available():
17
+ device = "cuda"
18
+ torch_dtype = torch.float16
19
+ print("✅ GPU detected. Using CUDA.")
20
+ else:
21
+ device = "cpu"
22
+ torch_dtype = torch.float32
23
+ print("⚠️ No GPU detected.")
24
 
25
+ HF_TOKEN = os.getenv("HF_TOKEN") # Will read the token from Space secrets
26
+ if HF_TOKEN is None:
27
+ raise ValueError("❌ HF_TOKEN is not set in the environment variables!")
 
28
 
29
+ login(token=HF_TOKEN)
 
 
 
 
 
 
 
 
30
 
31
  # --- CONFIGURATION & STATE ---
32
  available_models = {
 
37
  }
38
  model_state = { "current_pipe": None, "loaded_model_name": None }
39
 
40
+ # --- THE FINAL, STABLE GENERATION FUNCTION ---
41
+ def generate_media_with_progress(model_key, prompt, negative_prompt, steps, cfg_scale, width, height, seed, num_frames):
42
+ global model_state
43
+
44
+ # --- Model Loading ---
45
  if model_state.get("loaded_model_name") != model_key:
46
  yield {output_image: None, output_video: None, status_textbox: f"Loading {model_key}..."}
47
  if model_state.get("current_pipe"):
48
+ pipe_to_delete = model_state.pop("current_pipe", None)
49
+ if pipe_to_delete: del pipe_to_delete
50
+ gc.collect()
51
+ torch.cuda.empty_cache()
52
+
53
  model_id = available_models[model_key]
54
  if "Video" in model_key:
55
  pipe = TextToVideoSDPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
 
71
 
72
  # --- Generation Logic ---
73
  if "Video" in model_key:
 
74
  yield {output_image: None, output_video: None, status_textbox: "Generating video..."}
75
+ # (Your working video code)
76
  video_frames = pipe(prompt=prompt, num_inference_steps=int(steps), height=320, width=576, num_frames=int(num_frames), generator=generator).frames
77
  video_frames_5d = np.array(video_frames)
78
  video_frames_4d = np.squeeze(video_frames_5d)
 
82
  imageio.mimsave(video_path, list_of_frames, fps=12)
83
  yield {output_image: None, output_video: video_path, status_textbox: f"Video saved! Seed: {seed}"}
84
 
85
+ else: # Image Generation with your brilliant text-based progress bar
86
  progress_queue = Queue()
87
 
88
  def run_pipe():
 
89
  start_time = time.time()
90
 
91
+ # This callback correctly accepts all arguments
92
+ def progress_callback(step, timestep, latents, **kwargs):
93
  elapsed_time = time.time() - start_time
 
94
  if elapsed_time > 0:
95
  its_per_sec = (step + 1) / elapsed_time
96
+ progress_queue.put(("progress", step + 1, its_per_sec))
97
+ return kwargs
98
 
99
  try:
 
100
  final_image = pipe(
101
  prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=int(steps),
102
  guidance_scale=float(cfg_scale), width=int(width), height=int(height),
103
  generator=generator,
104
  callback_on_step_end=progress_callback
105
  ).images[0]
106
+ progress_queue.put(("final", final_image))
107
  except Exception as e:
108
  print(f"An error occurred in the generation thread: {e}")
109
+ progress_queue.put(("error", str(e)))
110
 
 
111
  thread = threading.Thread(target=run_pipe)
112
  thread.start()
113
 
 
114
  total_steps = int(steps)
115
+ final_image_result = None
116
+ yield {status_textbox: "Generating..."}
117
 
118
  while True:
119
  try:
120
+ update_type, data = progress_queue.get(timeout=1.0)
121
 
122
+ if update_type == "final":
123
+ final_image_result = data
124
+ yield {output_image: final_image_result, status_textbox: f"Generation complete! Seed: {seed}"}
125
  break
126
+ elif update_type == "progress":
127
+ current_step, its_per_sec = data
128
  progress_percent = (current_step / total_steps) * 100
129
  steps_remaining = total_steps - current_step
130
  eta_seconds = steps_remaining / its_per_sec if its_per_sec > 0 else 0
131
  eta_minutes, eta_seconds_rem = divmod(int(eta_seconds), 60)
 
132
  status_text = (
133
  f"Generating... {progress_percent:.0f}% ({current_step}/{total_steps}) | "
134
  f"{its_per_sec:.2f}it/s | "
135
  f"ETA: {eta_minutes:02d}:{eta_seconds_rem:02d}"
136
  )
137
  yield {status_textbox: status_text}
138
+ elif update_type == "error":
139
+ yield {status_textbox: f"Error: {data}"}
140
  break
141
  except QueueEmpty:
142
  if not thread.is_alive():
 
143
  yield {status_textbox: "Generation failed. Check console for details."}
144
  break
145
 
146
  thread.join()
147
 
148
+ # --- GRADIO UI (Unchanged) ---
149
  with gr.Blocks(theme='gradio/soft') as demo:
150
+ # (Your UI code is perfect)
151
  gr.Markdown("# The Generative Media Suite")
152
  gr.Markdown("Create fast images, high-quality images, or short videos. Created by cheeseman182. (note: the speed on the status bar is wrong)")
153
  seed_state = gr.State(-1)
 
155
  with gr.Column(scale=2):
156
  model_selector = gr.Radio(label="Select Model", choices=list(available_models.keys()), value=list(available_models.keys())[0])
157
  prompt_input = gr.Textbox(label="Prompt", lines=4, placeholder="An astronaut riding a horse on Mars, cinematic...")
158
+ negative_prompt_input = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, deformed, watermark, text, overblown, high contrast, not photorealistic")
159
  with gr.Accordion("Settings", open=True):
160
  steps_slider = gr.Slider(1, 100, 30, step=1, label="Inference Steps")
161
  cfg_slider = gr.Slider(0.0, 15.0, 7.5, step=0.5, label="Guidance Scale (CFG)")
 
190
  outputs=seed_state,
191
  queue=False
192
  ).then(
193
+ fn=generate_media_with_progress,
194
  inputs=[model_selector, prompt_input, negative_prompt_input, steps_slider, cfg_slider, width_slider, height_slider, seed_state, num_frames_slider],
195
  outputs=[output_image, output_video, status_textbox]
196
  )
197
 
198
+ demo.launch(share=True)