Spanicin commited on
Commit
0a1421d
Β·
verified Β·
1 Parent(s): 457f91c

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -0
app.py CHANGED
@@ -436,6 +436,106 @@ def load_model(checkpoint_path=None):
436
  return True
437
 
438
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
  # ============== Gradio Interface ==============
440
 
441
  def generate_chart(prompt, num_steps, guidance_scale, seed):
@@ -592,6 +692,30 @@ def create_demo():
592
  """)
593
 
594
  with gr.Tabs():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
595
  # Generation Tab
596
  with gr.TabItem("🎨 Generate"):
597
  with gr.Row():
 
436
  return True
437
 
438
 
439
+ def generate_dataset_ui(num_samples, image_size):
440
+ """Generate training dataset."""
441
+ try:
442
+ import os
443
+ import json
444
+ import random
445
+ from PIL import Image
446
+ import matplotlib.pyplot as plt
447
+ from matplotlib.patches import Rectangle
448
+ import io
449
+
450
+ output_dir = "./dataset"
451
+ os.makedirs(output_dir, exist_ok=True)
452
+ os.makedirs(os.path.join(output_dir, "images"), exist_ok=True)
453
+
454
+ bg_color = "#1a1a2e"
455
+ bullish_color = "#00ff88"
456
+ bearish_color = "#ff4466"
457
+ num_candles = 20
458
+
459
+ def generate_candles(pattern, vol):
460
+ candles = []
461
+ price = 100 if pattern != "bearish" else 150
462
+
463
+ for i in range(num_candles):
464
+ if pattern == "bullish":
465
+ trend = random.uniform(0.5, 2.0)
466
+ o = price + random.gauss(0, vol)
467
+ c = o + random.uniform(0, vol*2) + trend
468
+ elif pattern == "bearish":
469
+ trend = random.uniform(0.5, 2.0)
470
+ o = price + random.gauss(0, vol)
471
+ c = o - random.uniform(0, vol*2) - trend
472
+ else: # sideways
473
+ o = price + random.gauss(0, vol)
474
+ c = o + random.gauss(0, vol)
475
+
476
+ h = max(o, c) + random.uniform(0, vol)
477
+ l = min(o, c) - random.uniform(0, vol)
478
+ candles.append({"o": o, "h": h, "l": l, "c": c})
479
+ price = c
480
+ return candles
481
+
482
+ def render(candles):
483
+ fig, ax = plt.subplots(figsize=(image_size/100, image_size/100), dpi=100)
484
+ fig.patch.set_facecolor(bg_color)
485
+ ax.set_facecolor(bg_color)
486
+
487
+ highs = [c["h"] for c in candles]
488
+ lows = [c["l"] for c in candles]
489
+ price_min, price_max = min(lows)*0.98, max(highs)*1.02
490
+
491
+ for i, c in enumerate(candles):
492
+ color = bullish_color if c["c"] >= c["o"] else bearish_color
493
+ ax.plot([i, i], [c["l"], c["h"]], color=color, linewidth=1)
494
+ body_bottom = min(c["o"], c["c"])
495
+ body_height = abs(c["c"] - c["o"]) or 0.1
496
+ rect = Rectangle((i-0.3, body_bottom), 0.6, body_height, facecolor=color)
497
+ ax.add_patch(rect)
498
+
499
+ ax.set_xlim(-1, len(candles))
500
+ ax.set_ylim(price_min, price_max)
501
+ ax.axis("off")
502
+
503
+ buf = io.BytesIO()
504
+ plt.savefig(buf, format="png", facecolor=bg_color, bbox_inches="tight", pad_inches=0.1)
505
+ plt.close(fig)
506
+ buf.seek(0)
507
+
508
+ img = Image.open(buf).convert("RGB")
509
+ return img.resize((image_size, image_size), Image.Resampling.LANCZOS)
510
+
511
+ patterns = ["bullish", "bearish", "sideways"]
512
+ volatilities = {"low": 1.0, "medium": 3.0, "high": 6.0}
513
+ labels = {}
514
+
515
+ for i in range(int(num_samples)):
516
+ pattern = random.choice(patterns)
517
+ vol_name = random.choice(list(volatilities.keys()))
518
+ vol = volatilities[vol_name]
519
+
520
+ candles = generate_candles(pattern, vol)
521
+ img = render(candles)
522
+
523
+ filename = f"chart_{i:06d}.png"
524
+ img.save(os.path.join(output_dir, "images", filename))
525
+ labels[filename] = f"{pattern} trend {vol_name} volatility"
526
+
527
+ if i % 500 == 0:
528
+ print(f"Generated {i}/{num_samples}")
529
+
530
+ with open(os.path.join(output_dir, "labels.json"), "w") as f:
531
+ json.dump(labels, f)
532
+
533
+ return f"βœ… Generated {num_samples} samples in ./dataset"
534
+
535
+ except Exception as e:
536
+ return f"❌ Failed: {str(e)}"
537
+
538
+
539
  # ============== Gradio Interface ==============
540
 
541
  def generate_chart(prompt, num_steps, guidance_scale, seed):
 
692
  """)
693
 
694
  with gr.Tabs():
695
+ # Data Generation Tab
696
+ with gr.TabItem("πŸ“Š Generate Data"):
697
+ gr.Markdown("""
698
+ ### Generate Training Dataset
699
+
700
+ Create synthetic candlestick chart images for training.
701
+ **Run this first before training!**
702
+ """)
703
+
704
+ with gr.Row():
705
+ with gr.Column():
706
+ num_samples = gr.Slider(1000, 50000, value=10000, step=1000, label="Number of Samples")
707
+ data_image_size = gr.Slider(64, 256, value=128, step=32, label="Image Size")
708
+ generate_data_btn = gr.Button("πŸ“Š Generate Dataset", variant="primary")
709
+
710
+ with gr.Column():
711
+ data_status = gr.Textbox(label="Status", lines=5, interactive=False)
712
+
713
+ generate_data_btn.click(
714
+ generate_dataset_ui,
715
+ inputs=[num_samples, data_image_size],
716
+ outputs=[data_status]
717
+ )
718
+
719
  # Generation Tab
720
  with gr.TabItem("🎨 Generate"):
721
  with gr.Row():