Spanicin commited on
Commit
a976fe7
·
verified ·
1 Parent(s): 0f04521

Upload 4 files

Browse files
Files changed (2) hide show
  1. app.py +4 -8
  2. requirements.txt +1 -2
app.py CHANGED
@@ -438,7 +438,7 @@ def load_model(checkpoint_path=None):
438
 
439
  # ============== Gradio Interface ==============
440
 
441
- def generate_chart(prompt, num_steps, guidance_scale, seed, progress=gr.Progress()):
442
  global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
443
 
444
  if MODEL is None:
@@ -453,9 +453,6 @@ def generate_chart(prompt, num_steps, guidance_scale, seed, progress=gr.Progress
453
  if DEVICE.type == "cuda":
454
  torch.cuda.manual_seed(seed)
455
 
456
- def update_progress(p):
457
- progress(p, desc="Generating...")
458
-
459
  with torch.no_grad():
460
  context = TEXT_ENCODER([prompt], DEVICE)
461
  context_uncond = TEXT_ENCODER.get_uncond(1, DEVICE)
@@ -465,7 +462,7 @@ def generate_chart(prompt, num_steps, guidance_scale, seed, progress=gr.Progress
465
  shape=(1, 3, CONFIG["image_size"], CONFIG["image_size"]),
466
  steps=num_steps,
467
  guidance_scale=guidance_scale,
468
- progress_callback=update_progress
469
  )
470
 
471
  # Convert to image
@@ -481,7 +478,7 @@ def generate_chart(prompt, num_steps, guidance_scale, seed, progress=gr.Progress
481
  return None, f"❌ Error: {str(e)}"
482
 
483
 
484
- def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_name, progress=gr.Progress()):
485
  global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
486
 
487
  try:
@@ -545,7 +542,6 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
545
 
546
  epoch_loss += loss.item()
547
  current_step += 1
548
- progress(current_step / total_steps, desc=f"Epoch {epoch+1}/{epochs}")
549
 
550
  avg_loss = epoch_loss / len(train_loader)
551
  logs.append(f"Epoch {epoch+1}/{epochs}: loss = {avg_loss:.4f}")
@@ -583,7 +579,7 @@ def load_checkpoint(checkpoint_file):
583
  # ============== Gradio UI ==============
584
 
585
  def create_demo():
586
- with gr.Blocks(title="Candlestick Chart Generator", theme=gr.themes.Soft()) as demo:
587
  gr.Markdown("""
588
  # 📈 Candlestick Chart Diffusion Generator
589
 
 
438
 
439
  # ============== Gradio Interface ==============
440
 
441
+ def generate_chart(prompt, num_steps, guidance_scale, seed):
442
  global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
443
 
444
  if MODEL is None:
 
453
  if DEVICE.type == "cuda":
454
  torch.cuda.manual_seed(seed)
455
 
 
 
 
456
  with torch.no_grad():
457
  context = TEXT_ENCODER([prompt], DEVICE)
458
  context_uncond = TEXT_ENCODER.get_uncond(1, DEVICE)
 
462
  shape=(1, 3, CONFIG["image_size"], CONFIG["image_size"]),
463
  steps=num_steps,
464
  guidance_scale=guidance_scale,
465
+ progress_callback=None
466
  )
467
 
468
  # Convert to image
 
478
  return None, f"❌ Error: {str(e)}"
479
 
480
 
481
+ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_name):
482
  global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
483
 
484
  try:
 
542
 
543
  epoch_loss += loss.item()
544
  current_step += 1
 
545
 
546
  avg_loss = epoch_loss / len(train_loader)
547
  logs.append(f"Epoch {epoch+1}/{epochs}: loss = {avg_loss:.4f}")
 
579
  # ============== Gradio UI ==============
580
 
581
  def create_demo():
582
+ with gr.Blocks(title="Candlestick Chart Generator") as demo:
583
  gr.Markdown("""
584
  # 📈 Candlestick Chart Diffusion Generator
585
 
requirements.txt CHANGED
@@ -1,7 +1,6 @@
1
  torch>=2.0.0
2
  torchvision>=0.15.0
3
- gradio==4.31.0
4
- huggingface_hub>=0.22.0
5
  Pillow>=9.5.0
6
  numpy>=1.24.0
7
  einops>=0.6.1
 
1
  torch>=2.0.0
2
  torchvision>=0.15.0
3
+ gradio==3.50.2
 
4
  Pillow>=9.5.0
5
  numpy>=1.24.0
6
  einops>=0.6.1