Spanicin commited on
Commit
f64e771
·
verified ·
1 Parent(s): 476d0fb

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -0
app.py CHANGED
@@ -390,6 +390,42 @@ DEVICE = None
390
  CONFIG = None
391
 
392
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  def load_model(checkpoint_path=None):
394
  global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
395
 
@@ -686,6 +722,21 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
686
  logs.append("-" * 40)
687
  logs.append(f"✅ Model saved to {save_path}")
688
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
689
  return "\n".join(logs)
690
 
691
  except Exception as e:
 
390
  CONFIG = None
391
 
392
 
393
+ def save_to_hub(save_name, repo_id=None):
394
+ """Save model checkpoint to HuggingFace Hub for persistence."""
395
+ global MODEL, TEXT_ENCODER, CONFIG
396
+
397
+ if MODEL is None:
398
+ return "❌ No model loaded to save"
399
+
400
+ try:
401
+ from huggingface_hub import HfApi, upload_file
402
+ import tempfile
403
+
404
+ # Save to temp file
405
+ with tempfile.NamedTemporaryFile(suffix='.pt', delete=False) as f:
406
+ torch.save({
407
+ "model_state_dict": MODEL.state_dict(),
408
+ "text_encoder_state_dict": TEXT_ENCODER.state_dict(),
409
+ "config": CONFIG
410
+ }, f.name)
411
+ temp_path = f.name
412
+
413
+ # Upload to Hub (same Space repo)
414
+ api = HfApi()
415
+ api.upload_file(
416
+ path_or_fileobj=temp_path,
417
+ path_in_repo=f"checkpoints/{save_name}.pt",
418
+ repo_id=repo_id or "Spanicin/candlestick-diffusion",
419
+ repo_type="space"
420
+ )
421
+
422
+ os.unlink(temp_path)
423
+ return f"✅ Model saved to Hub: checkpoints/{save_name}.pt"
424
+
425
+ except Exception as e:
426
+ return f"❌ Failed to save to Hub: {str(e)}"
427
+
428
+
429
  def load_model(checkpoint_path=None):
430
  global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
431
 
 
722
  logs.append("-" * 40)
723
  logs.append(f"✅ Model saved to {save_path}")
724
 
725
+ # Also try to save to Hub for persistence
726
+ try:
727
+ from huggingface_hub import HfApi
728
+ api = HfApi()
729
+ api.upload_file(
730
+ path_or_fileobj=save_path,
731
+ path_in_repo=f"checkpoints/{save_name}.pt",
732
+ repo_id="Spanicin/candlestick-diffusion",
733
+ repo_type="space"
734
+ )
735
+ logs.append(f"☁️ Model uploaded to Hub (persistent)")
736
+ except Exception as hub_error:
737
+ logs.append(f"⚠️ Could not upload to Hub: {hub_error}")
738
+ logs.append(" Model saved locally but may be lost on restart")
739
+
740
  return "\n".join(logs)
741
 
742
  except Exception as e: