Spanicin commited on
Commit
f5fea08
·
verified ·
1 Parent(s): f64e771

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -52
app.py CHANGED
@@ -478,25 +478,33 @@ def generate_dataset_ui(num_samples, image_size):
478
  import os
479
  import json
480
  import random
481
- from PIL import Image
482
- import matplotlib.pyplot as plt
483
- from matplotlib.patches import Rectangle
484
- import io
485
 
486
  output_dir = "./dataset"
487
  os.makedirs(output_dir, exist_ok=True)
488
  os.makedirs(os.path.join(output_dir, "images"), exist_ok=True)
489
 
490
- bg_color = "#1a1a2e"
491
- bullish_color = "#00ff88"
492
- bearish_color = "#ff4466"
493
  num_candles = 20
 
 
 
 
 
 
 
 
 
494
 
495
- def generate_candles(pattern, vol):
 
 
 
 
 
496
  candles = []
497
  price = 100 if pattern != "bearish" else 150
498
 
499
- for i in range(num_candles):
500
  if pattern == "bullish":
501
  trend = random.uniform(0.5, 2.0)
502
  o = price + random.gauss(0, vol)
@@ -505,7 +513,7 @@ def generate_dataset_ui(num_samples, image_size):
505
  trend = random.uniform(0.5, 2.0)
506
  o = price + random.gauss(0, vol)
507
  c = o - random.uniform(0, vol*2) - trend
508
- else: # sideways
509
  o = price + random.gauss(0, vol)
510
  c = o + random.gauss(0, vol)
511
 
@@ -513,63 +521,62 @@ def generate_dataset_ui(num_samples, image_size):
513
  l = min(o, c) - random.uniform(0, vol)
514
  candles.append({"o": o, "h": h, "l": l, "c": c})
515
  price = c
516
- return candles
517
-
518
- def render(candles):
519
- fig, ax = plt.subplots(figsize=(image_size/100, image_size/100), dpi=100)
520
- fig.patch.set_facecolor(bg_color)
521
- ax.set_facecolor(bg_color)
522
 
523
  highs = [c["h"] for c in candles]
524
  lows = [c["l"] for c in candles]
525
- price_min, price_max = min(lows)*0.98, max(highs)*1.02
526
-
527
- for i, c in enumerate(candles):
528
- color = bullish_color if c["c"] >= c["o"] else bearish_color
529
- ax.plot([i, i], [c["l"], c["h"]], color=color, linewidth=1)
530
- body_bottom = min(c["o"], c["c"])
531
- body_height = abs(c["c"] - c["o"]) or 0.1
532
- rect = Rectangle((i-0.3, body_bottom), 0.6, body_height, facecolor=color)
533
- ax.add_patch(rect)
534
-
535
- ax.set_xlim(-1, len(candles))
536
- ax.set_ylim(price_min, price_max)
537
- ax.axis("off")
538
 
539
- buf = io.BytesIO()
540
- plt.savefig(buf, format="png", facecolor=bg_color, bbox_inches="tight", pad_inches=0.1)
541
- plt.close(fig)
542
- buf.seek(0)
543
-
544
- img = Image.open(buf).convert("RGB")
545
- return img.resize((image_size, image_size), Image.Resampling.LANCZOS)
546
-
547
- patterns = ["bullish", "bearish", "sideways"]
548
- volatilities = {"low": 1.0, "medium": 3.0, "high": 6.0}
549
- labels = {}
550
-
551
- for i in range(int(num_samples)):
552
- pattern = random.choice(patterns)
553
- vol_name = random.choice(list(volatilities.keys()))
554
- vol = volatilities[vol_name]
555
 
556
- candles = generate_candles(pattern, vol)
557
- img = render(candles)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
558
 
 
559
  filename = f"chart_{i:06d}.png"
560
  img.save(os.path.join(output_dir, "images", filename))
561
  labels[filename] = f"{pattern} trend {vol_name} volatility"
562
-
563
- if i % 500 == 0:
564
- print(f"Generated {i}/{num_samples}")
565
 
 
566
  with open(os.path.join(output_dir, "labels.json"), "w") as f:
567
  json.dump(labels, f)
568
 
569
- return f"✅ Generated {num_samples} samples in ./dataset"
570
 
571
  except Exception as e:
572
- return f"❌ Failed: {str(e)}"
 
573
 
574
 
575
  # ============== Gradio Interface ==============
 
478
  import os
479
  import json
480
  import random
481
+ from PIL import Image, ImageDraw
 
 
 
482
 
483
  output_dir = "./dataset"
484
  os.makedirs(output_dir, exist_ok=True)
485
  os.makedirs(os.path.join(output_dir, "images"), exist_ok=True)
486
 
 
 
 
487
  num_candles = 20
488
+ labels = {}
489
+
490
+ # Colors
491
+ bg_color = (26, 26, 46)
492
+ bullish_color = (0, 255, 136)
493
+ bearish_color = (255, 68, 102)
494
+
495
+ patterns = ["bullish", "bearish", "sideways"]
496
+ volatilities = {"low": 1.0, "medium": 3.0, "high": 6.0}
497
 
498
+ for i in range(int(num_samples)):
499
+ # Generate candle data
500
+ pattern = random.choice(patterns)
501
+ vol_name = random.choice(list(volatilities.keys()))
502
+ vol = volatilities[vol_name]
503
+
504
  candles = []
505
  price = 100 if pattern != "bearish" else 150
506
 
507
+ for j in range(num_candles):
508
  if pattern == "bullish":
509
  trend = random.uniform(0.5, 2.0)
510
  o = price + random.gauss(0, vol)
 
513
  trend = random.uniform(0.5, 2.0)
514
  o = price + random.gauss(0, vol)
515
  c = o - random.uniform(0, vol*2) - trend
516
+ else:
517
  o = price + random.gauss(0, vol)
518
  c = o + random.gauss(0, vol)
519
 
 
521
  l = min(o, c) - random.uniform(0, vol)
522
  candles.append({"o": o, "h": h, "l": l, "c": c})
523
  price = c
524
+
525
+ # Render with PIL (no matplotlib)
526
+ img = Image.new('RGB', (image_size, image_size), bg_color)
527
+ draw = ImageDraw.Draw(img)
 
 
528
 
529
  highs = [c["h"] for c in candles]
530
  lows = [c["l"] for c in candles]
531
+ price_min, price_max = min(lows), max(highs)
532
+ price_range = price_max - price_min or 1
 
 
 
 
 
 
 
 
 
 
 
533
 
534
+ padding = 10
535
+ chart_width = image_size - 2 * padding
536
+ chart_height = image_size - 2 * padding
537
+ candle_width = chart_width // (num_candles + 2)
 
 
 
 
 
 
 
 
 
 
 
 
538
 
539
+ for j, c in enumerate(candles):
540
+ x = padding + (j + 1) * candle_width
541
+
542
+ # Scale prices to pixels
543
+ def scale_y(p):
544
+ return padding + chart_height - int((p - price_min) / price_range * chart_height)
545
+
546
+ y_open = scale_y(c["o"])
547
+ y_close = scale_y(c["c"])
548
+ y_high = scale_y(c["h"])
549
+ y_low = scale_y(c["l"])
550
+
551
+ color = bullish_color if c["c"] >= c["o"] else bearish_color
552
+
553
+ # Draw wick
554
+ draw.line([(x, y_high), (x, y_low)], fill=color, width=1)
555
+
556
+ # Draw body
557
+ body_top = min(y_open, y_close)
558
+ body_bottom = max(y_open, y_close)
559
+ if body_bottom - body_top < 2:
560
+ body_bottom = body_top + 2
561
+ draw.rectangle(
562
+ [(x - candle_width//3, body_top), (x + candle_width//3, body_bottom)],
563
+ fill=color
564
+ )
565
 
566
+ # Save image
567
  filename = f"chart_{i:06d}.png"
568
  img.save(os.path.join(output_dir, "images", filename))
569
  labels[filename] = f"{pattern} trend {vol_name} volatility"
 
 
 
570
 
571
+ # Save labels
572
  with open(os.path.join(output_dir, "labels.json"), "w") as f:
573
  json.dump(labels, f)
574
 
575
+ return f"✅ Generated {int(num_samples)} samples in ./dataset"
576
 
577
  except Exception as e:
578
+ import traceback
579
+ return f"❌ Failed: {str(e)}\n{traceback.format_exc()}"
580
 
581
 
582
  # ============== Gradio Interface ==============