Commit 366fd1c
Parent(s): 27e1ebb
Enable bf16 in load_infinity function and enhance transform function with type hints and error handling; refactor joint_vi_vae_encode_decode for improved performance and error management

app.py CHANGED

@@ -102,7 +102,7 @@ def load_infinity(
     text_channels=2048,
     apply_spatial_patchify=0,
     use_flex_attn=False,
-    bf16=
+    bf16=True,
 ):
     print('[Loading Infinity]')
 
@@ -156,45 +156,137 @@ def load_infinity(
 
     # Initialize random number generator on the correct device
     infinity_test.rng = torch.Generator(device=device)
+
     return infinity_test
 
-def transform(pil_img, tgt_h, tgt_w):
+def transform(pil_img: PImage.Image, tgt_h: int, tgt_w: int) -> torch.Tensor:
+    """
+    Transform a PIL image to a tensor with target dimensions while preserving aspect ratio.
+
+    Args:
+        pil_img: PIL Image to transform
+        tgt_h: Target height
+        tgt_w: Target width
+
+    Returns:
+        torch.Tensor: Normalized tensor image in range [-1, 1]
+    """
+    if not isinstance(pil_img, PImage.Image):
+        raise TypeError("Input must be a PIL Image")
+
+    if tgt_h <= 0 or tgt_w <= 0:
+        raise ValueError("Target dimensions must be positive")
+
+    # Calculate resize dimensions preserving aspect ratio
     width, height = pil_img.size
-    #
-    arr = np.array(pil_img)
+    scale = min(tgt_w / width, tgt_h / height)
+    new_width = int(width * scale)
+    new_height = int(height * scale)
+
+    # Resize using LANCZOS for best quality
+    pil_img = pil_img.resize((new_width, new_height), resample=PImage.LANCZOS)
+
+    # Create center crop
+    arr = np.array(pil_img, dtype=np.uint8)
+
+    # Calculate crop coordinates
+    y1 = max(0, (new_height - tgt_h) // 2)
+    x1 = max(0, (new_width - tgt_w) // 2)
+    y2 = y1 + tgt_h
+    x2 = x1 + tgt_w
+
+    # Crop and convert to tensor
+    arr = arr[y1:y2, x1:x2]
+
+    # Convert to normalized tensor in one step
+    return torch.from_numpy(arr.transpose(2, 0, 1)).float().div_(127.5).sub_(1)
+
+def joint_vi_vae_encode_decode(
+    vae: 'VAEModel',  # Type hint would be more specific with actual VAE class
+    image_path: str | Path,
+    scale_schedule: List[tuple],
+    device: torch.device | str,
+    tgt_h: int,
+    tgt_w: int
+) -> tuple[np.ndarray, np.ndarray, torch.Tensor]:
+    """
+    Encode and decode an image using a VAE model with joint visual-infinity processing.
+
+    Args:
+        vae: The VAE model instance
+        image_path: Path to input image
+        scale_schedule: List of scale tuples for processing
+        device: Target device for computation
+        tgt_h: Target height for the image
+        tgt_w: Target width for the image
+
+    Returns:
+        tuple containing:
+            - Original image as numpy array (uint8)
+            - Reconstructed image as numpy array (uint8)
+            - Bit indices tensor
+
+    Raises:
+        FileNotFoundError: If image file doesn't exist
+        RuntimeError: If VAE processing fails
+    """
+    try:
+        # Validate input path
+        if not Path(image_path).exists():
+            raise FileNotFoundError(f"Image not found at {image_path}")
+
+        # Load and preprocess image
+        pil_image = Image.open(image_path).convert('RGB')
+        inp = transform(pil_image, tgt_h, tgt_w)
+        inp = inp.unsqueeze(0).to(device)
+
+        # Normalize scale schedule
+        scale_schedule = [(s[0], s[1], s[2]) for s in scale_schedule]
+
+        # Decide whether to use CPU or GPU
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+        # Time the encoding/decoding operations
+        with torch.amp.autocast(device, dtype=torch.bfloat16):
+            encode_start = time.perf_counter()
+            h, z, _, all_bit_indices, _, _ = vae.encode(
+                inp,
+                scale_schedule=scale_schedule
+            )
+            encode_time = time.perf_counter() - encode_start
+
+            decode_start = time.perf_counter()
+            recons_img = vae.decode(z)[0]
+            decode_time = time.perf_counter() - decode_start
+
+        # Process reconstruction
+        if recons_img.dim() == 4:
+            recons_img = recons_img.squeeze(1)
+
+        # Log performance metrics
+        print(f'VAE encode: {encode_time:.2f}s, decode: {decode_time:.2f}s')
+        print(f'Reconstruction shape: {recons_img.shape}, z shape: {z.shape}')
+
+        # Convert to numpy arrays efficiently
+        recons_img = (recons_img.add(1).div(2)
+                      .permute(1, 2, 0)
+                      .mul(255)
+                      .cpu()
+                      .numpy()
+                      .astype(np.uint8))
+
+        gt_img = (inp[0].add(1).div(2)
+                  .permute(1, 2, 0)
+                  .mul(255)
+                  .cpu()
+                  .numpy()
+                  .astype(np.uint8))
+
+        return gt_img, recons_img, all_bit_indices
+
+    except Exception as e:
+        print(f"Error in VAE processing: {str(e)}")
+        raise RuntimeError("VAE processing failed") from e
 
 def load_visual_tokenizer(args):
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -219,29 +311,26 @@ def load_visual_tokenizer(args):
     return vae
 
 def load_transformer(vae, args):
+    device = "cuda" if torch.cuda.is_available() else "cpu"
     model_path = args.model_path
+
+    if args.checkpoint_type == 'torch':
         if osp.exists(args.cache_dir):
             local_model_path = osp.join(args.cache_dir, 'tmp', model_path.replace('/', '_'))
         else:
             local_model_path = model_path
+
         if args.enable_model_cache:
             slim_model_path = model_path.replace('ar-', 'slim-')
             local_slim_model_path = local_model_path.replace('ar-', 'slim-')
             os.makedirs(osp.dirname(local_slim_model_path), exist_ok=True)
-            print(f'model_path: {model_path}, slim_model_path: {slim_model_path}')
-            print(f'local_model_path: {local_model_path}, local_slim_model_path: {local_slim_model_path}')
             if not osp.exists(local_slim_model_path):
                 if osp.exists(slim_model_path):
-                    print(f'copy {slim_model_path} to {local_slim_model_path}')
                     shutil.copyfile(slim_model_path, local_slim_model_path)
                 else:
                     if not osp.exists(local_model_path):
-                        print(f'copy {model_path} to {local_model_path}')
                         shutil.copyfile(model_path, local_model_path)
                     save_slim_model(local_model_path, save_file=local_slim_model_path, device=device)
-                    print(f'copy {local_slim_model_path} to {slim_model_path}')
                     if not osp.exists(slim_model_path):
                         shutil.copyfile(local_slim_model_path, slim_model_path)
                     os.remove(local_model_path)
@@ -249,33 +338,35 @@ def load_transformer(vae, args):
             slim_model_path = local_slim_model_path
         else:
             slim_model_path = model_path
-        print(f'
+        print(f'Loading checkpoint from {slim_model_path}')
+    else:
+        raise ValueError(f"Unsupported checkpoint_type: {args.checkpoint_type}")
+
+    model_configs = {
+        'infinity_2b': dict(depth=32, embed_dim=2048, num_heads=16, drop_path_rate=0.1, mlp_ratio=4, block_chunks=8),
+        'infinity_layer12': dict(depth=12, embed_dim=768, num_heads=8, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
+        'infinity_layer16': dict(depth=16, embed_dim=1152, num_heads=12, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
+        'infinity_layer24': dict(depth=24, embed_dim=1536, num_heads=16, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
+        'infinity_layer32': dict(depth=32, embed_dim=2080, num_heads=20, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
+        'infinity_layer40': dict(depth=40, embed_dim=2688, num_heads=24, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
+        'infinity_layer48': dict(depth=48, embed_dim=3360, num_heads=28, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
+    }
+
+    kwargs_model = model_configs.get(args.model_type)
+    if kwargs_model is None:
+        raise ValueError(f"Unsupported model_type: {args.model_type}")
+
     infinity = load_infinity(
-        rope2d_each_sa_layer=args.rope2d_each_sa_layer,
+        rope2d_each_sa_layer=args.rope2d_each_sa_layer,
         rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
         use_scale_schedule_embedding=args.use_scale_schedule_embedding,
         pn=args.pn,
-        use_bit_label=args.use_bit_label,
-        add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
-        model_path=slim_model_path,
-        scale_schedule=None,
-        vae=vae,
-        device=
+        use_bit_label=args.use_bit_label,
+        add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
+        model_path=slim_model_path,
+        scale_schedule=None,
+        vae=vae,
+        device=device,
         model_kwargs=kwargs_model,
         text_channels=args.text_channels,
         apply_spatial_patchify=args.apply_spatial_patchify,