Enzo8930302 commited on
Commit
165a196
·
verified ·
1 Parent(s): d3516f4

Upload quick_start.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. quick_start.py +188 -90
quick_start.py CHANGED
@@ -1,124 +1,222 @@
1
  """
2
- Quick Start Script
3
- Setup and test Byte Dream in one command
4
  """
5
 
6
- import subprocess
7
  import sys
8
  from pathlib import Path
9
 
10
 
11
- def check_requirements():
12
- """Check if requirements are installed"""
13
- print("Checking requirements...")
14
-
15
- required = [
16
- 'torch',
17
- 'transformers',
18
- 'diffusers',
19
- 'pillow',
20
- 'numpy',
21
- 'gradio',
22
- ]
23
-
24
- missing = []
25
-
26
- for package in required:
27
- try:
28
- __import__(package.replace('-', '_'))
29
- print(f" ✓ {package}")
30
- except ImportError:
31
- print(f" ✗ {package} - MISSING")
32
- missing.append(package)
33
-
34
- if missing:
35
- print(f"\nMissing packages: {', '.join(missing)}")
36
- print("\nInstall with:")
37
- print(" pip install -r requirements.txt")
38
  return False
39
 
40
- print("\n✓ All requirements satisfied!")
41
  return True
42
 
43
 
44
- def test_model():
45
- """Test model generation"""
46
- print("\n" + "="*60)
47
- print("Testing Byte Dream Model")
48
- print("="*60)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  try:
51
  from bytedream.generator import ByteDreamGenerator
52
 
53
- print("\nInitializing generator...")
54
- generator = ByteDreamGenerator(device="cpu")
 
 
 
 
 
55
 
56
- print("\nModel info:")
57
- info = generator.get_model_info()
58
- for key, value in info.items():
59
- print(f" {key}: {value}")
 
 
 
60
 
61
- print("\nGenerating test image...")
62
- print("Prompt: A simple test pattern, geometric shapes")
 
 
 
 
63
 
64
- image = generator.generate(
65
- prompt="A simple test pattern, geometric shapes, abstract art",
66
- width=256,
67
- height=256,
68
- num_inference_steps=20,
69
- seed=42,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  )
71
 
72
- output_path = Path("test_output.png")
73
- image.save(output_path)
74
 
75
- print(f"\n✓ Test successful!")
76
- print(f" Image saved to: {output_path.absolute()}")
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- return True
 
 
 
79
 
80
  except Exception as e:
81
- print(f"\n Error: {e}")
82
- print("\nThe model needs to be trained or pretrained weights downloaded.")
83
- return False
84
 
85
 
86
- def download_pretrained():
87
- """Download pretrained model from Hugging Face"""
88
- print("\n" + "="*60)
89
- print("Downloading Pretrained Model")
90
- print("="*60)
91
-
92
- print("\nTo download a pretrained model:")
93
- print("1. Visit https://huggingface.co/models")
94
- print("2. Search for 'stable-diffusion' or similar")
95
- print("3. Download using:")
96
- print("\n from huggingface_hub import snapshot_download")
97
- print(" snapshot_download(repo_id='username/model', local_dir='./models/bytedream')")
98
- print("\nOr train your own model with:")
99
- print(" python train.py --train_data ./dataset --output_dir ./models/bytedream")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
 
102
  def main():
103
- print("="*60)
104
- print("Byte Dream - Quick Start")
105
- print("="*60)
106
-
107
- # Check requirements
108
- if not check_requirements():
109
- print("\n⚠ Please install requirements first")
110
- sys.exit(1)
111
-
112
- # Test model
113
- if test_model():
114
- print("\n✓ Byte Dream is ready to use!")
115
- print("\nNext steps:")
116
- print(" - Run: python infer.py --prompt 'Your prompt here'")
117
- print(" - Run: python app.py (for web interface)")
118
- print(" - Run: python main.py --demo (for demo)")
 
 
 
 
 
119
  else:
120
- download_pretrained()
 
 
 
121
 
122
 
123
  if __name__ == "__main__":
124
- main()
 
 
 
 
 
1
  """
2
+ Quick Start Script for Hugging Face Integration
3
+ Helps you upload/download models easily
4
  """
5
 
 
6
  import sys
7
  from pathlib import Path
8
 
9
 
10
+ def print_banner(text):
11
+ """Print formatted banner"""
12
+ print("\n" + "="*60)
13
+ print(text.center(60))
14
+ print("="*60 + "\n")
15
+
16
+
17
+ def check_model_exists():
18
+ """Check if trained model exists"""
19
+ model_path = Path("./models/bytedream")
20
+
21
+ if not model_path.exists():
22
+ print("❌ Model directory not found!")
23
+ print("\nPlease train the model first:")
24
+ print(" python train.py")
25
+ print("\nOr download from Hugging Face:")
26
+ print(" python infer.py --hf_repo username/repo --prompt 'test'")
27
+ return False
28
+
29
+ # Check for weights
30
+ unet_weights = model_path / "unet" / "pytorch_model.bin"
31
+ vae_weights = model_path / "vae" / "pytorch_model.bin"
32
+
33
+ if not (unet_weights.exists() or (model_path / "pytorch_model.bin").exists()):
34
+ print(" Model directory exists but no weights found!")
35
+ print("Please train the model first.")
 
36
  return False
37
 
 
38
  return True
39
 
40
 
41
+ def upload_to_hf():
42
+ """Upload model to Hugging Face"""
43
+ print_banner("UPLOAD TO HUGGING FACE HUB")
44
+
45
+ # Check model exists
46
+ if not check_model_exists():
47
+ return
48
+
49
+ # Get token
50
+ token = input("Enter your Hugging Face token (hf_...): ").strip()
51
+ if not token:
52
+ print("❌ Token is required!")
53
+ return
54
+
55
+ # Get repo ID
56
+ repo_id = input("Enter repository ID (e.g., username/ByteDream): ").strip()
57
+ if not repo_id:
58
+ print("❌ Repository ID is required!")
59
+ return
60
+
61
+ print(f"\n📤 Uploading to {repo_id}...")
62
 
63
  try:
64
  from bytedream.generator import ByteDreamGenerator
65
 
66
+ # Load model
67
+ print("\nLoading model...")
68
+ generator = ByteDreamGenerator(
69
+ model_path="./models/bytedream",
70
+ config_path="config.yaml",
71
+ device="cpu",
72
+ )
73
 
74
+ # Upload
75
+ generator.push_to_hub(
76
+ repo_id=repo_id,
77
+ token=token,
78
+ private=False,
79
+ commit_message="Upload Byte Dream model",
80
+ )
81
 
82
+ print("\n✅ SUCCESS!")
83
+ print(f"\n📦 Your model is available at:")
84
+ print(f"https://huggingface.co/{repo_id}")
85
+ print(f"\nTo use this model:")
86
+ print(f" python infer.py --prompt 'your prompt' --hf_repo '{repo_id}'")
87
+ print("="*60)
88
 
89
+ except Exception as e:
90
+ print(f"\n❌ Error: {e}")
91
+ import traceback
92
+ traceback.print_exc()
93
+
94
+
95
+ def download_from_hf():
96
+ """Download model from Hugging Face"""
97
+ print_banner("DOWNLOAD FROM HUGGING FACE HUB")
98
+
99
+ # Get repo ID
100
+ repo_id = input("Enter repository ID (e.g., username/ByteDream): ").strip()
101
+ if not repo_id:
102
+ print("❌ Repository ID is required!")
103
+ return
104
+
105
+ print(f"\n📥 Downloading from {repo_id}...")
106
+
107
+ try:
108
+ from bytedream.generator import ByteDreamGenerator
109
+
110
+ # Load from HF
111
+ generator = ByteDreamGenerator(
112
+ hf_repo_id=repo_id,
113
+ config_path="config.yaml",
114
+ device="cpu",
115
  )
116
 
117
+ print("\n✅ Model loaded successfully!")
 
118
 
119
+ # Test generation
120
+ test = input("\nGenerate test image? (y/n): ").strip().lower()
121
+ if test == 'y':
122
+ print("\nGenerating test image...")
123
+ image = generator.generate(
124
+ prompt="test pattern, simple colors",
125
+ width=256,
126
+ height=256,
127
+ num_inference_steps=10,
128
+ )
129
+
130
+ output = "test_output.png"
131
+ image.save(output)
132
+ print(f"✓ Test image saved to: {output}")
133
 
134
+ print("\nTo generate images:")
135
+ print(f" python infer.py --prompt 'your prompt' --hf_repo '{repo_id}'")
136
+ print(f" HF_REPO_ID={repo_id} python app.py")
137
+ print("="*60)
138
 
139
  except Exception as e:
140
+ print(f"\n Error: {e}")
141
+ import traceback
142
+ traceback.print_exc()
143
 
144
 
145
+ def test_local_model():
146
+ """Test local model"""
147
+ print_banner("TEST LOCAL MODEL")
148
+
149
+ if not check_model_exists():
150
+ return
151
+
152
+ print("Loading local model...")
153
+
154
+ try:
155
+ from bytedream.generator import ByteDreamGenerator
156
+
157
+ generator = ByteDreamGenerator(
158
+ model_path="./models/bytedream",
159
+ config_path="config.yaml",
160
+ device="cpu",
161
+ )
162
+
163
+ print("\n✅ Model loaded successfully!")
164
+
165
+ # Generate test image
166
+ print("\nGenerating test image...")
167
+ image = generator.generate(
168
+ prompt="test pattern, simple colors",
169
+ width=256,
170
+ height=256,
171
+ num_inference_steps=10,
172
+ )
173
+
174
+ output = "test_output.png"
175
+ image.save(output)
176
+ print(f"✓ Test image saved to: {output}")
177
+
178
+ print("\nModel ready for upload!")
179
+ print("To upload: python quick_start.py upload")
180
+ print("="*60)
181
+
182
+ except Exception as e:
183
+ print(f"\n❌ Error: {e}")
184
+ import traceback
185
+ traceback.print_exc()
186
 
187
 
188
  def main():
189
+ """Main function"""
190
+ print_banner("BYTE DREAM - QUICK START")
191
+
192
+ print("What would you like to do?")
193
+ print("1. Upload model to Hugging Face")
194
+ print("2. Download model from Hugging Face")
195
+ print("3. Test local model")
196
+ print("4. Exit")
197
+ print()
198
+
199
+ choice = input("Enter choice (1-4): ").strip()
200
+
201
+ if choice == "1":
202
+ upload_to_hf()
203
+ elif choice == "2":
204
+ download_from_hf()
205
+ elif choice == "3":
206
+ test_local_model()
207
+ elif choice == "4":
208
+ print("\nGoodbye!")
209
+ return
210
  else:
211
+ print("❌ Invalid choice!")
212
+ return
213
+
214
+ print("\nDone!")
215
 
216
 
217
  if __name__ == "__main__":
218
+ try:
219
+ main()
220
+ except KeyboardInterrupt:
221
+ print("\n\nInterrupted!")
222
+ sys.exit(0)