Enzo8930302 committed on
Commit
689eaa0
·
verified ·
1 Parent(s): 695a745

Upload bytedream/hf_api.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. bytedream/hf_api.py +346 -0
bytedream/hf_api.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hugging Face API Client for Byte Dream
3
+ Use Byte Dream models directly from Hugging Face Hub
4
+ """
5
+
6
+ import torch
7
+ import requests
8
+ import base64
9
+ from io import BytesIO
10
+ from PIL import Image
11
+ from typing import Optional, List, Union
12
+ import time
13
+
14
+
15
class HuggingFaceAPI:
    """
    Client for the Hugging Face Inference API.

    Allows using Byte Dream models without downloading them: prompts are
    sent to the hosted endpoint and generated images come back as bytes.
    """

    def __init__(
        self,
        repo_id: str,
        token: Optional[str] = None,
        use_gpu: bool = False,
    ):
        """
        Initialize the Hugging Face API client.

        Args:
            repo_id: Repository ID (e.g., "username/ByteDream").
            token: Hugging Face API token (optional but recommended —
                required for private models and higher rate limits).
            use_gpu: Request GPU inference (if available).
                NOTE(review): currently stored but not forwarded to the
                API; kept for interface compatibility.
        """
        self.repo_id = repo_id
        self.token = token
        self.use_gpu = use_gpu

        # Serverless Inference API endpoint for this repository.
        self.inference_api_url = f"https://api-inference.huggingface.co/models/{repo_id}"
        self.headers = {}

        if token:
            self.headers["Authorization"] = f"Bearer {token}"

        print(f"✓ Hugging Face API initialized for: {repo_id}")

    def query(
        self,
        prompt: str,
        negative_prompt: str = "",
        width: int = 512,
        height: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        seed: Optional[int] = None,
        max_retries: int = 60,
    ) -> Image.Image:
        """
        Generate one image via the Inference API.

        Args:
            prompt: Text prompt.
            negative_prompt: Negative prompt.
            width: Image width in pixels.
            height: Image height in pixels.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale.
            seed: Random seed (omitted from the payload when None).
            max_retries: How many times to wait and retry while the model
                is still loading on HF servers (HTTP 503).

        Returns:
            Generated PIL Image.

        Raises:
            requests.HTTPError: If the API returns a non-503 error status,
                or is still returning 503 after ``max_retries`` attempts.
        """
        payload = {
            "inputs": prompt,
            "parameters": {
                "negative_prompt": negative_prompt,
                "width": width,
                "height": height,
                "num_inference_steps": num_inference_steps,
                "guidance_scale": guidance_scale,
            }
        }

        if seed is not None:
            payload["parameters"]["seed"] = seed

        # Bounded retry loop instead of the previous unbounded recursion:
        # a model that never finishes loading would otherwise recurse
        # until RecursionError. 503 means "model is loading" on HF.
        for _attempt in range(max_retries + 1):
            response = requests.post(
                self.inference_api_url,
                headers=self.headers,
                json=payload,
                # Image generation can be slow, but never hang forever on
                # a stalled connection.
                timeout=300,
            )

            if response.status_code != 503:
                break

            print("Model is loading on HF servers. Waiting...")
            time.sleep(5)

        # Raises on any remaining error status (including a final 503).
        response.raise_for_status()

        # Successful responses carry the raw image bytes in the body.
        return Image.open(BytesIO(response.content))

    def query_batch(
        self,
        prompts: List[str],
        negative_prompt: str = "",
        width: int = 512,
        height: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        seeds: Optional[List[int]] = None,
    ) -> List[Image.Image]:
        """
        Generate multiple images, one API call per prompt.

        Args:
            prompts: List of prompts.
            negative_prompt: Negative prompt applied to every image.
            width: Image width in pixels.
            height: Image height in pixels.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale.
            seeds: Optional per-prompt seeds; prompts beyond the length of
                this list fall back to no seed.

        Returns:
            List of PIL Images, in the same order as ``prompts``.
        """
        images = []

        for i, prompt in enumerate(prompts):
            # Tolerate a seeds list shorter than the prompts list.
            seed = seeds[i] if seeds and i < len(seeds) else None

            print(f"Generating image {i+1}/{len(prompts)}...")
            image = self.query(
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                seed=seed,
            )

            images.append(image)

        return images
154
+
155
+
156
class ByteDreamHFClient:
    """
    High-level client for Byte Dream on Hugging Face.

    Depending on ``use_api``, generation is delegated either to the hosted
    Inference API or to a locally loaded model.
    """

    def __init__(
        self,
        repo_id: str,
        token: Optional[str] = None,
        use_api: bool = False,
        device: str = "cpu",
    ):
        """
        Initialize the Byte Dream HF client.

        Args:
            repo_id: Repository ID on Hugging Face.
            token: HF API token.
            use_api: Use the Inference API instead of local inference.
            device: Device for local inference.
        """
        self.repo_id = repo_id
        self.token = token
        self.use_api = use_api
        self.device = device

        if not use_api:
            # Local path: download the model and run it on this machine.
            from bytedream.generator import ByteDreamGenerator
            self.generator = ByteDreamGenerator(
                hf_repo_id=repo_id,
                config_path="config.yaml",
                device=device,
            )
            print("✓ Model loaded locally from Hugging Face")
        else:
            self.api_client = HuggingFaceAPI(repo_id, token)
            print("✓ Using Hugging Face Inference API")

    def generate(
        self,
        prompt: str,
        negative_prompt: str = "",
        width: int = 512,
        height: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        seed: Optional[int] = None,
    ) -> Image.Image:
        """
        Generate a single image from a prompt.

        Args:
            prompt: Text description.
            negative_prompt: Things to avoid.
            width: Image width in pixels.
            height: Image height in pixels.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale.
            seed: Random seed.

        Returns:
            Generated PIL Image.
        """
        # Sampling parameters shared by both back ends.
        sampling = dict(
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )

        if self.use_api:
            return self.api_client.query(
                prompt=prompt,
                negative_prompt=negative_prompt,
                **sampling,
            )

        # The local generator expects None (not "") to mean "no negative
        # prompt".
        return self.generator.generate(
            prompt=prompt,
            negative_prompt=negative_prompt or None,
            **sampling,
        )

    def generate_batch(
        self,
        prompts: List[str],
        negative_prompt: str = "",
        width: int = 512,
        height: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        seeds: Optional[List[int]] = None,
    ) -> List[Image.Image]:
        """
        Generate multiple images, one per prompt.

        Args:
            prompts: List of text descriptions.
            negative_prompt: Things to avoid (applied to every image).
            width: Image width in pixels.
            height: Image height in pixels.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale.
            seeds: List of random seeds.

        Returns:
            List of PIL Images, in prompt order.
        """
        sampling = dict(
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            seeds=seeds,
        )

        if self.use_api:
            return self.api_client.query_batch(
                prompts=prompts,
                negative_prompt=negative_prompt,
                **sampling,
            )

        return self.generator.generate_batch(
            prompts=prompts,
            negative_prompt=negative_prompt or None,
            **sampling,
        )
287
+
288
+
289
# Example usage
if __name__ == "__main__":
    banner = "=" * 60

    # Example 1: Use Inference API
    print(banner)
    print("Example 1: Using Hugging Face Inference API")
    print(banner)

    # You need a token for private models or higher rate limits
    # token = "hf_xxxxxxxxxxxxx"

    try:
        api_client = ByteDreamHFClient(
            repo_id="Enzo8930302/ByteDream",  # Replace with your repo
            # token=token,  # Optional but recommended
            use_api=True,  # Set True to use API
        )

        api_image = api_client.generate(
            prompt="A beautiful sunset over mountains, digital art",
            negative_prompt="ugly, blurry, low quality",
            width=512,
            height=512,
            num_inference_steps=50,
            guidance_scale=7.5,
            seed=42,
        )

        api_image.save("output_api.png")
        print("✓ Image saved to output_api.png")
    except Exception as e:
        print(f"Error: {e}")
        print("Make sure the model exists on Hugging Face")

    # Example 2: Download and run locally
    print("\n" + banner)
    print("Example 2: Download and run locally on CPU")
    print(banner)

    try:
        local_client = ByteDreamHFClient(
            repo_id="Enzo8930302/ByteDream",
            use_api=False,  # Download and run locally
            device="cpu",
        )

        local_image = local_client.generate(
            prompt="A futuristic city at night, cyberpunk style",
            width=512,
            height=512,
            num_inference_steps=30,
        )

        local_image.save("output_local.png")
        print("✓ Image saved to output_local.png")
    except Exception as e:
        print(f"Error: {e}")