Spaces:
Runtime error
Runtime error
Upload app.py
Browse files
app.py
CHANGED
|
@@ -390,6 +390,42 @@ DEVICE = None
|
|
| 390 |
CONFIG = None
|
| 391 |
|
| 392 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
def load_model(checkpoint_path=None):
|
| 394 |
global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
|
| 395 |
|
|
@@ -686,6 +722,21 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
|
|
| 686 |
logs.append("-" * 40)
|
| 687 |
logs.append(f"✅ Model saved to {save_path}")
|
| 688 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 689 |
return "\n".join(logs)
|
| 690 |
|
| 691 |
except Exception as e:
|
|
|
|
| 390 |
CONFIG = None
|
| 391 |
|
| 392 |
|
| 393 |
+
def save_to_hub(save_name, repo_id=None):
|
| 394 |
+
"""Save model checkpoint to HuggingFace Hub for persistence."""
|
| 395 |
+
global MODEL, TEXT_ENCODER, CONFIG
|
| 396 |
+
|
| 397 |
+
if MODEL is None:
|
| 398 |
+
return "❌ No model loaded to save"
|
| 399 |
+
|
| 400 |
+
try:
|
| 401 |
+
from huggingface_hub import HfApi, upload_file
|
| 402 |
+
import tempfile
|
| 403 |
+
|
| 404 |
+
# Save to temp file
|
| 405 |
+
with tempfile.NamedTemporaryFile(suffix='.pt', delete=False) as f:
|
| 406 |
+
torch.save({
|
| 407 |
+
"model_state_dict": MODEL.state_dict(),
|
| 408 |
+
"text_encoder_state_dict": TEXT_ENCODER.state_dict(),
|
| 409 |
+
"config": CONFIG
|
| 410 |
+
}, f.name)
|
| 411 |
+
temp_path = f.name
|
| 412 |
+
|
| 413 |
+
# Upload to Hub (same Space repo)
|
| 414 |
+
api = HfApi()
|
| 415 |
+
api.upload_file(
|
| 416 |
+
path_or_fileobj=temp_path,
|
| 417 |
+
path_in_repo=f"checkpoints/{save_name}.pt",
|
| 418 |
+
repo_id=repo_id or "Spanicin/candlestick-diffusion",
|
| 419 |
+
repo_type="space"
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
os.unlink(temp_path)
|
| 423 |
+
return f"✅ Model saved to Hub: checkpoints/{save_name}.pt"
|
| 424 |
+
|
| 425 |
+
except Exception as e:
|
| 426 |
+
return f"❌ Failed to save to Hub: {str(e)}"
|
| 427 |
+
|
| 428 |
+
|
| 429 |
def load_model(checkpoint_path=None):
|
| 430 |
global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
|
| 431 |
|
|
|
|
| 722 |
logs.append("-" * 40)
|
| 723 |
logs.append(f"✅ Model saved to {save_path}")
|
| 724 |
|
| 725 |
+
# Also try to save to Hub for persistence
|
| 726 |
+
try:
|
| 727 |
+
from huggingface_hub import HfApi
|
| 728 |
+
api = HfApi()
|
| 729 |
+
api.upload_file(
|
| 730 |
+
path_or_fileobj=save_path,
|
| 731 |
+
path_in_repo=f"checkpoints/{save_name}.pt",
|
| 732 |
+
repo_id="Spanicin/candlestick-diffusion",
|
| 733 |
+
repo_type="space"
|
| 734 |
+
)
|
| 735 |
+
logs.append(f"☁️ Model uploaded to Hub (persistent)")
|
| 736 |
+
except Exception as hub_error:
|
| 737 |
+
logs.append(f"⚠️ Could not upload to Hub: {hub_error}")
|
| 738 |
+
logs.append(" Model saved locally but may be lost on restart")
|
| 739 |
+
|
| 740 |
return "\n".join(logs)
|
| 741 |
|
| 742 |
except Exception as e:
|