Spanicin commited on
Commit
2cec91c
Β·
verified Β·
1 Parent(s): f5fea08

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -149
app.py CHANGED
@@ -390,42 +390,6 @@ DEVICE = None
390
  CONFIG = None
391
 
392
 
393
- def save_to_hub(save_name, repo_id=None):
394
- """Save model checkpoint to HuggingFace Hub for persistence."""
395
- global MODEL, TEXT_ENCODER, CONFIG
396
-
397
- if MODEL is None:
398
- return "❌ No model loaded to save"
399
-
400
- try:
401
- from huggingface_hub import HfApi, upload_file
402
- import tempfile
403
-
404
- # Save to temp file
405
- with tempfile.NamedTemporaryFile(suffix='.pt', delete=False) as f:
406
- torch.save({
407
- "model_state_dict": MODEL.state_dict(),
408
- "text_encoder_state_dict": TEXT_ENCODER.state_dict(),
409
- "config": CONFIG
410
- }, f.name)
411
- temp_path = f.name
412
-
413
- # Upload to Hub (same Space repo)
414
- api = HfApi()
415
- api.upload_file(
416
- path_or_fileobj=temp_path,
417
- path_in_repo=f"checkpoints/{save_name}.pt",
418
- repo_id=repo_id or "Spanicin/candlestick-diffusion",
419
- repo_type="space"
420
- )
421
-
422
- os.unlink(temp_path)
423
- return f"βœ… Model saved to Hub: checkpoints/{save_name}.pt"
424
-
425
- except Exception as e:
426
- return f"❌ Failed to save to Hub: {str(e)}"
427
-
428
-
429
  def load_model(checkpoint_path=None):
430
  global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
431
 
@@ -478,33 +442,25 @@ def generate_dataset_ui(num_samples, image_size):
478
  import os
479
  import json
480
  import random
481
- from PIL import Image, ImageDraw
 
 
 
482
 
483
  output_dir = "./dataset"
484
  os.makedirs(output_dir, exist_ok=True)
485
  os.makedirs(os.path.join(output_dir, "images"), exist_ok=True)
486
 
 
 
 
487
  num_candles = 20
488
- labels = {}
489
-
490
- # Colors
491
- bg_color = (26, 26, 46)
492
- bullish_color = (0, 255, 136)
493
- bearish_color = (255, 68, 102)
494
-
495
- patterns = ["bullish", "bearish", "sideways"]
496
- volatilities = {"low": 1.0, "medium": 3.0, "high": 6.0}
497
 
498
- for i in range(int(num_samples)):
499
- # Generate candle data
500
- pattern = random.choice(patterns)
501
- vol_name = random.choice(list(volatilities.keys()))
502
- vol = volatilities[vol_name]
503
-
504
  candles = []
505
  price = 100 if pattern != "bearish" else 150
506
 
507
- for j in range(num_candles):
508
  if pattern == "bullish":
509
  trend = random.uniform(0.5, 2.0)
510
  o = price + random.gauss(0, vol)
@@ -513,7 +469,7 @@ def generate_dataset_ui(num_samples, image_size):
513
  trend = random.uniform(0.5, 2.0)
514
  o = price + random.gauss(0, vol)
515
  c = o - random.uniform(0, vol*2) - trend
516
- else:
517
  o = price + random.gauss(0, vol)
518
  c = o + random.gauss(0, vol)
519
 
@@ -521,62 +477,63 @@ def generate_dataset_ui(num_samples, image_size):
521
  l = min(o, c) - random.uniform(0, vol)
522
  candles.append({"o": o, "h": h, "l": l, "c": c})
523
  price = c
524
-
525
- # Render with PIL (no matplotlib)
526
- img = Image.new('RGB', (image_size, image_size), bg_color)
527
- draw = ImageDraw.Draw(img)
 
 
528
 
529
  highs = [c["h"] for c in candles]
530
  lows = [c["l"] for c in candles]
531
- price_min, price_max = min(lows), max(highs)
532
- price_range = price_max - price_min or 1
533
-
534
- padding = 10
535
- chart_width = image_size - 2 * padding
536
- chart_height = image_size - 2 * padding
537
- candle_width = chart_width // (num_candles + 2)
538
 
539
- for j, c in enumerate(candles):
540
- x = padding + (j + 1) * candle_width
541
-
542
- # Scale prices to pixels
543
- def scale_y(p):
544
- return padding + chart_height - int((p - price_min) / price_range * chart_height)
545
-
546
- y_open = scale_y(c["o"])
547
- y_close = scale_y(c["c"])
548
- y_high = scale_y(c["h"])
549
- y_low = scale_y(c["l"])
550
-
551
  color = bullish_color if c["c"] >= c["o"] else bearish_color
552
-
553
- # Draw wick
554
- draw.line([(x, y_high), (x, y_low)], fill=color, width=1)
555
-
556
- # Draw body
557
- body_top = min(y_open, y_close)
558
- body_bottom = max(y_open, y_close)
559
- if body_bottom - body_top < 2:
560
- body_bottom = body_top + 2
561
- draw.rectangle(
562
- [(x - candle_width//3, body_top), (x + candle_width//3, body_bottom)],
563
- fill=color
564
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
565
 
566
- # Save image
567
  filename = f"chart_{i:06d}.png"
568
  img.save(os.path.join(output_dir, "images", filename))
569
  labels[filename] = f"{pattern} trend {vol_name} volatility"
 
 
 
570
 
571
- # Save labels
572
  with open(os.path.join(output_dir, "labels.json"), "w") as f:
573
  json.dump(labels, f)
574
 
575
- return f"βœ… Generated {int(num_samples)} samples in ./dataset"
576
 
577
  except Exception as e:
578
- import traceback
579
- return f"❌ Failed: {str(e)}\n{traceback.format_exc()}"
580
 
581
 
582
  # ============== Gradio Interface ==============
@@ -625,20 +582,12 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
625
  global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
626
 
627
  try:
628
- # Clear GPU memory
629
- import gc
630
- gc.collect()
631
- if torch.cuda.is_available():
632
- torch.cuda.empty_cache()
633
-
634
  # Setup
635
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
636
-
637
- # Use smaller model for T4 GPU
638
  CONFIG = {
639
- "base_channels": 48, # Reduced from 64
640
- "channel_mults": (1, 2, 4), # Keep same
641
- "context_dim": 192, # Reduced from 256
642
  "image_size": image_size,
643
  "timesteps": 1000
644
  }
@@ -674,14 +623,13 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
674
  logs = [f"πŸš€ Training started on {DEVICE}"]
675
  logs.append(f"πŸ“Š Model parameters: {num_params:,}")
676
  logs.append(f"πŸ“ Training samples: {len(train_dataset)}")
677
- logs.append(f"πŸ–ΌοΈ Image size: {image_size}x{image_size}")
678
- logs.append(f"πŸ“¦ Batch size: {batch_size}")
679
  logs.append("-" * 40)
680
 
681
- for epoch in range(int(epochs)):
 
 
 
682
  epoch_loss = 0
683
- batch_count = 0
684
-
685
  for images, texts in train_loader:
686
  images = images.to(DEVICE)
687
  context = TEXT_ENCODER(texts, DEVICE)
@@ -693,28 +641,10 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
693
  optimizer.step()
694
 
695
  epoch_loss += loss.item()
696
- batch_count += 1
697
-
698
- # Clear cache periodically
699
- if batch_count % 50 == 0:
700
- if torch.cuda.is_available():
701
- torch.cuda.empty_cache()
702
 
703
- avg_loss = epoch_loss / max(batch_count, 1)
704
- logs.append(f"Epoch {epoch+1}/{int(epochs)}: loss = {avg_loss:.4f}")
705
-
706
- # Save checkpoint every 10 epochs
707
- if (epoch + 1) % 10 == 0:
708
- print(f"Epoch {epoch+1}: loss = {avg_loss:.4f}")
709
- checkpoint_path = f"checkpoints/{save_name}_epoch{epoch+1}.pt"
710
- os.makedirs("checkpoints", exist_ok=True)
711
- torch.save({
712
- "model_state_dict": MODEL.state_dict(),
713
- "text_encoder_state_dict": TEXT_ENCODER.state_dict(),
714
- "config": CONFIG,
715
- "epoch": epoch + 1
716
- }, checkpoint_path)
717
- logs.append(f"πŸ’Ύ Checkpoint saved: {checkpoint_path}")
718
 
719
  # Save model
720
  MODEL.eval()
@@ -729,26 +659,10 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
729
  logs.append("-" * 40)
730
  logs.append(f"βœ… Model saved to {save_path}")
731
 
732
- # Also try to save to Hub for persistence
733
- try:
734
- from huggingface_hub import HfApi
735
- api = HfApi()
736
- api.upload_file(
737
- path_or_fileobj=save_path,
738
- path_in_repo=f"checkpoints/{save_name}.pt",
739
- repo_id="Spanicin/candlestick-diffusion",
740
- repo_type="space"
741
- )
742
- logs.append(f"☁️ Model uploaded to Hub (persistent)")
743
- except Exception as hub_error:
744
- logs.append(f"⚠️ Could not upload to Hub: {hub_error}")
745
- logs.append(" Model saved locally but may be lost on restart")
746
-
747
  return "\n".join(logs)
748
 
749
  except Exception as e:
750
- import traceback
751
- return f"❌ Training failed: {str(e)}\n{traceback.format_exc()}"
752
 
753
 
754
  def load_checkpoint(checkpoint_file):
 
390
  CONFIG = None
391
 
392
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  def load_model(checkpoint_path=None):
394
  global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
395
 
 
442
  import os
443
  import json
444
  import random
445
+ from PIL import Image
446
+ import matplotlib.pyplot as plt
447
+ from matplotlib.patches import Rectangle
448
+ import io
449
 
450
  output_dir = "./dataset"
451
  os.makedirs(output_dir, exist_ok=True)
452
  os.makedirs(os.path.join(output_dir, "images"), exist_ok=True)
453
 
454
+ bg_color = "#1a1a2e"
455
+ bullish_color = "#00ff88"
456
+ bearish_color = "#ff4466"
457
  num_candles = 20
 
 
 
 
 
 
 
 
 
458
 
459
+ def generate_candles(pattern, vol):
 
 
 
 
 
460
  candles = []
461
  price = 100 if pattern != "bearish" else 150
462
 
463
+ for i in range(num_candles):
464
  if pattern == "bullish":
465
  trend = random.uniform(0.5, 2.0)
466
  o = price + random.gauss(0, vol)
 
469
  trend = random.uniform(0.5, 2.0)
470
  o = price + random.gauss(0, vol)
471
  c = o - random.uniform(0, vol*2) - trend
472
+ else: # sideways
473
  o = price + random.gauss(0, vol)
474
  c = o + random.gauss(0, vol)
475
 
 
477
  l = min(o, c) - random.uniform(0, vol)
478
  candles.append({"o": o, "h": h, "l": l, "c": c})
479
  price = c
480
+ return candles
481
+
482
+ def render(candles):
483
+ fig, ax = plt.subplots(figsize=(image_size/100, image_size/100), dpi=100)
484
+ fig.patch.set_facecolor(bg_color)
485
+ ax.set_facecolor(bg_color)
486
 
487
  highs = [c["h"] for c in candles]
488
  lows = [c["l"] for c in candles]
489
+ price_min, price_max = min(lows)*0.98, max(highs)*1.02
 
 
 
 
 
 
490
 
491
+ for i, c in enumerate(candles):
 
 
 
 
 
 
 
 
 
 
 
492
  color = bullish_color if c["c"] >= c["o"] else bearish_color
493
+ ax.plot([i, i], [c["l"], c["h"]], color=color, linewidth=1)
494
+ body_bottom = min(c["o"], c["c"])
495
+ body_height = abs(c["c"] - c["o"]) or 0.1
496
+ rect = Rectangle((i-0.3, body_bottom), 0.6, body_height, facecolor=color)
497
+ ax.add_patch(rect)
498
+
499
+ ax.set_xlim(-1, len(candles))
500
+ ax.set_ylim(price_min, price_max)
501
+ ax.axis("off")
502
+
503
+ buf = io.BytesIO()
504
+ plt.savefig(buf, format="png", facecolor=bg_color, bbox_inches="tight", pad_inches=0.1)
505
+ plt.close(fig)
506
+ buf.seek(0)
507
+
508
+ img = Image.open(buf).convert("RGB")
509
+ return img.resize((image_size, image_size), Image.Resampling.LANCZOS)
510
+
511
+ patterns = ["bullish", "bearish", "sideways"]
512
+ volatilities = {"low": 1.0, "medium": 3.0, "high": 6.0}
513
+ labels = {}
514
+
515
+ for i in range(int(num_samples)):
516
+ pattern = random.choice(patterns)
517
+ vol_name = random.choice(list(volatilities.keys()))
518
+ vol = volatilities[vol_name]
519
+
520
+ candles = generate_candles(pattern, vol)
521
+ img = render(candles)
522
 
 
523
  filename = f"chart_{i:06d}.png"
524
  img.save(os.path.join(output_dir, "images", filename))
525
  labels[filename] = f"{pattern} trend {vol_name} volatility"
526
+
527
+ if i % 500 == 0:
528
+ print(f"Generated {i}/{num_samples}")
529
 
 
530
  with open(os.path.join(output_dir, "labels.json"), "w") as f:
531
  json.dump(labels, f)
532
 
533
+ return f"βœ… Generated {num_samples} samples in ./dataset"
534
 
535
  except Exception as e:
536
+ return f"❌ Failed: {str(e)}"
 
537
 
538
 
539
  # ============== Gradio Interface ==============
 
582
  global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
583
 
584
  try:
 
 
 
 
 
 
585
  # Setup
586
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
587
  CONFIG = {
588
+ "base_channels": 64,
589
+ "channel_mults": (1, 2, 4),
590
+ "context_dim": 256,
591
  "image_size": image_size,
592
  "timesteps": 1000
593
  }
 
623
  logs = [f"πŸš€ Training started on {DEVICE}"]
624
  logs.append(f"πŸ“Š Model parameters: {num_params:,}")
625
  logs.append(f"πŸ“ Training samples: {len(train_dataset)}")
 
 
626
  logs.append("-" * 40)
627
 
628
+ total_steps = epochs * len(train_loader)
629
+ current_step = 0
630
+
631
+ for epoch in range(epochs):
632
  epoch_loss = 0
 
 
633
  for images, texts in train_loader:
634
  images = images.to(DEVICE)
635
  context = TEXT_ENCODER(texts, DEVICE)
 
641
  optimizer.step()
642
 
643
  epoch_loss += loss.item()
644
+ current_step += 1
 
 
 
 
 
645
 
646
+ avg_loss = epoch_loss / len(train_loader)
647
+ logs.append(f"Epoch {epoch+1}/{epochs}: loss = {avg_loss:.4f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
648
 
649
  # Save model
650
  MODEL.eval()
 
659
  logs.append("-" * 40)
660
  logs.append(f"βœ… Model saved to {save_path}")
661
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
662
  return "\n".join(logs)
663
 
664
  except Exception as e:
665
+ return f"❌ Training failed: {str(e)}"
 
666
 
667
 
668
  def load_checkpoint(checkpoint_file):