Spaces:
Runtime error
Runtime error
Upload 4 files
Browse files- app.py +4 -8
- requirements.txt +1 -2
app.py
CHANGED
|
@@ -438,7 +438,7 @@ def load_model(checkpoint_path=None):
|
|
| 438 |
|
| 439 |
# ============== Gradio Interface ==============
|
| 440 |
|
| 441 |
-
def generate_chart(prompt, num_steps, guidance_scale, seed
|
| 442 |
global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
|
| 443 |
|
| 444 |
if MODEL is None:
|
|
@@ -453,9 +453,6 @@ def generate_chart(prompt, num_steps, guidance_scale, seed, progress=gr.Progress
|
|
| 453 |
if DEVICE.type == "cuda":
|
| 454 |
torch.cuda.manual_seed(seed)
|
| 455 |
|
| 456 |
-
def update_progress(p):
|
| 457 |
-
progress(p, desc="Generating...")
|
| 458 |
-
|
| 459 |
with torch.no_grad():
|
| 460 |
context = TEXT_ENCODER([prompt], DEVICE)
|
| 461 |
context_uncond = TEXT_ENCODER.get_uncond(1, DEVICE)
|
|
@@ -465,7 +462,7 @@ def generate_chart(prompt, num_steps, guidance_scale, seed, progress=gr.Progress
|
|
| 465 |
shape=(1, 3, CONFIG["image_size"], CONFIG["image_size"]),
|
| 466 |
steps=num_steps,
|
| 467 |
guidance_scale=guidance_scale,
|
| 468 |
-
progress_callback=
|
| 469 |
)
|
| 470 |
|
| 471 |
# Convert to image
|
|
@@ -481,7 +478,7 @@ def generate_chart(prompt, num_steps, guidance_scale, seed, progress=gr.Progress
|
|
| 481 |
return None, f"❌ Error: {str(e)}"
|
| 482 |
|
| 483 |
|
| 484 |
-
def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_name
|
| 485 |
global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
|
| 486 |
|
| 487 |
try:
|
|
@@ -545,7 +542,6 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
|
|
| 545 |
|
| 546 |
epoch_loss += loss.item()
|
| 547 |
current_step += 1
|
| 548 |
-
progress(current_step / total_steps, desc=f"Epoch {epoch+1}/{epochs}")
|
| 549 |
|
| 550 |
avg_loss = epoch_loss / len(train_loader)
|
| 551 |
logs.append(f"Epoch {epoch+1}/{epochs}: loss = {avg_loss:.4f}")
|
|
@@ -583,7 +579,7 @@ def load_checkpoint(checkpoint_file):
|
|
| 583 |
# ============== Gradio UI ==============
|
| 584 |
|
| 585 |
def create_demo():
|
| 586 |
-
with gr.Blocks(title="Candlestick Chart Generator"
|
| 587 |
gr.Markdown("""
|
| 588 |
# 📈 Candlestick Chart Diffusion Generator
|
| 589 |
|
|
|
|
| 438 |
|
| 439 |
# ============== Gradio Interface ==============
|
| 440 |
|
| 441 |
+
def generate_chart(prompt, num_steps, guidance_scale, seed):
|
| 442 |
global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
|
| 443 |
|
| 444 |
if MODEL is None:
|
|
|
|
| 453 |
if DEVICE.type == "cuda":
|
| 454 |
torch.cuda.manual_seed(seed)
|
| 455 |
|
|
|
|
|
|
|
|
|
|
| 456 |
with torch.no_grad():
|
| 457 |
context = TEXT_ENCODER([prompt], DEVICE)
|
| 458 |
context_uncond = TEXT_ENCODER.get_uncond(1, DEVICE)
|
|
|
|
| 462 |
shape=(1, 3, CONFIG["image_size"], CONFIG["image_size"]),
|
| 463 |
steps=num_steps,
|
| 464 |
guidance_scale=guidance_scale,
|
| 465 |
+
progress_callback=None
|
| 466 |
)
|
| 467 |
|
| 468 |
# Convert to image
|
|
|
|
| 478 |
return None, f"❌ Error: {str(e)}"
|
| 479 |
|
| 480 |
|
| 481 |
+
def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_name):
|
| 482 |
global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
|
| 483 |
|
| 484 |
try:
|
|
|
|
| 542 |
|
| 543 |
epoch_loss += loss.item()
|
| 544 |
current_step += 1
|
|
|
|
| 545 |
|
| 546 |
avg_loss = epoch_loss / len(train_loader)
|
| 547 |
logs.append(f"Epoch {epoch+1}/{epochs}: loss = {avg_loss:.4f}")
|
|
|
|
| 579 |
# ============== Gradio UI ==============
|
| 580 |
|
| 581 |
def create_demo():
|
| 582 |
+
with gr.Blocks(title="Candlestick Chart Generator") as demo:
|
| 583 |
gr.Markdown("""
|
| 584 |
# 📈 Candlestick Chart Diffusion Generator
|
| 585 |
|
requirements.txt
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
torch>=2.0.0
|
| 2 |
torchvision>=0.15.0
|
| 3 |
-
gradio==
|
| 4 |
-
huggingface_hub>=0.22.0
|
| 5 |
Pillow>=9.5.0
|
| 6 |
numpy>=1.24.0
|
| 7 |
einops>=0.6.1
|
|
|
|
| 1 |
torch>=2.0.0
|
| 2 |
torchvision>=0.15.0
|
| 3 |
+
gradio==3.50.2
|
|
|
|
| 4 |
Pillow>=9.5.0
|
| 5 |
numpy>=1.24.0
|
| 6 |
einops>=0.6.1
|