Upload app.py
app.py CHANGED
@@ -206,6 +206,7 @@ def load_models():
     print("Loading CLIP")
     clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
     clip_model = AutoModel.from_pretrained(CLIP_PATH).vision_model
+    assert (CHECKPOINT_PATH / "clip_model.pt").exists()
     if (CHECKPOINT_PATH / "clip_model.pt").exists():
         print("Loading VLM's custom vision model")
         checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location='cpu', weights_only=False)
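The first hunk tightens checkpoint handling: with the new assert in place, a missing clip_model.pt aborts model loading immediately instead of silently falling through to the base CLIP weights (the if on the following line is now always true). A minimal sketch of that fail-fast pattern in isolation, with a hypothetical checkpoint directory standing in for the CHECKPOINT_PATH that app.py defines elsewhere:

    from pathlib import Path
    import torch

    CHECKPOINT_PATH = Path("checkpoint")  # hypothetical; the real value is set elsewhere in app.py

    ckpt_file = CHECKPOINT_PATH / "clip_model.pt"
    assert ckpt_file.exists(), f"Missing {ckpt_file}"  # fail fast at startup
    checkpoint = torch.load(ckpt_file, map_location="cpu", weights_only=False)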
@@ -312,88 +313,88 @@ def stream_chat(input_images: List[Image.Image], caption_type: str, caption_leng
 
     for i in range(0, len(input_images), batch_size):
         batch = input_images[i:i+batch_size]
- [... 6 removed lines; content not captured in this view ...]
+
+        for input_image in input_images:
+            try:
+                # Preprocess image
+                # NOTE: I found the default processor for so400M to have worse results than just using PIL directly
+                #image = clip_processor(images=input_image, return_tensors='pt').pixel_values
                 image = input_image.resize((384, 384), Image.LANCZOS)
                 pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
                 pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
- [... 71 removed lines; content not captured in this view ...]
-            caption =
-            all_captions.append(caption)
+                pixel_values = pixel_values.to(device)
+            except ValueError as e:
+                print(f"Error processing image: {e}")
+                print("Skipping this image and continuing...")
+                continue
+
+            # Embed image
+            # This results in Batch x Image Tokens x Features
+            with torch.amp.autocast_mode.autocast(device, enabled=True):
+                vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
+                image_features = vision_outputs.hidden_states
+                embedded_images = image_adapter(image_features).to(device)
+
+            # Build the conversation
+            convo = [
+                {
+                    "role": "system",
+                    "content": "You are a helpful image captioner.",
+                },
+                {
+                    "role": "user",
+                    "content": prompt_str,
+                },
+            ]
+
+            # Format the conversation
+            convo_string = tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = True)
+            assert isinstance(convo_string, str)
+
+            # Tokenize the conversation
+            # prompt_str is tokenized separately so we can do the calculations below
+            convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False)
+            prompt_tokens = tokenizer.encode(prompt_str, return_tensors="pt", add_special_tokens=False, truncation=False)
+            assert isinstance(convo_tokens, torch.Tensor) and isinstance(prompt_tokens, torch.Tensor)
+            convo_tokens = convo_tokens.squeeze(0)   # Squeeze just to make the following easier
+            prompt_tokens = prompt_tokens.squeeze(0)
+
+            # Calculate where to inject the image
+            eot_id_indices = (convo_tokens == tokenizer.convert_tokens_to_ids("<|eot_id|>")).nonzero(as_tuple=True)[0].tolist()
+            assert len(eot_id_indices) == 2, f"Expected 2 <|eot_id|> tokens, got {len(eot_id_indices)}"
+
+            preamble_len = eot_id_indices[1] - prompt_tokens.shape[0]   # Number of tokens before the prompt
+
+            # Embed the tokens
+            convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to(device))
+
+            # Construct the input
+            input_embeds = torch.cat([
+                convo_embeds[:, :preamble_len],   # Part before the prompt
+                embedded_images.to(dtype=convo_embeds.dtype),   # Image
+                convo_embeds[:, preamble_len:],   # The prompt and anything after it
+            ], dim=1).to(device)
+
+            input_ids = torch.cat([
+                convo_tokens[:preamble_len].unsqueeze(0),
+                torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),   # Dummy tokens for the image (TODO: Should probably use a special token here so as not to confuse any generation algorithms that might be inspecting the input)
+                convo_tokens[preamble_len:].unsqueeze(0),
+            ], dim=1).to(device)
+            attention_mask = torch.ones_like(input_ids)
+
+            # Debugging
+            #print(f"Input to model: {repr(tokenizer.decode(input_ids[0]))}")
+
+            generate_ids = text_model.generate(input_ids=input_ids, inputs_embeds=input_embeds, attention_mask=attention_mask, do_sample=True,
+                                               suppress_tokens=None, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature)
+
+            # Trim off the prompt
+            generate_ids = generate_ids[:, input_ids.shape[1]:]
+            if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
+                generate_ids = generate_ids[:, :-1]
+
+            caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
+            all_captions.append(caption.strip())
 
         if pbar:
             pbar.update(len(batch))
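The heart of the second hunk is the splice: the conversation is tokenized once, the index of the second <|eot_id|> locates the end of the user turn, and subtracting the prompt length gives preamble_len, the boundary at which the image embeddings are inserted so they sit immediately before the prompt text. A toy worked example of that arithmetic (all shapes and positions are made up; 4096 stands in for the text model's hidden size):

    import torch

    # Hypothetical sizes: 20 conversation tokens, a 5-token prompt, 9 image tokens
    convo_embeds = torch.randn(1, 20, 4096)       # embedded conversation tokens
    embedded_images = torch.randn(1, 9, 4096)     # output of the image adapter
    eot_id_indices = [3, 17]                      # made-up positions of the two <|eot_id|> tokens

    preamble_len = eot_id_indices[1] - 5          # = 12: the prompt occupies indices 12..16

    input_embeds = torch.cat([
        convo_embeds[:, :preamble_len],           # system turn + user header
        embedded_images,                          # image tokens land right before the prompt
        convo_embeds[:, preamble_len:],           # the prompt and the generation header
    ], dim=1)
    assert input_embeds.shape[1] == 20 + 9        # sequence grew by the image token count

The parallel input_ids tensor is padded with zeros over the image span purely so that its length matches input_embeds, and the attention mask is all ones because none of it is padding in the usual sense.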
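One subtlety in the generation step: generate() receives both input_ids and inputs_embeds, so (on transformers versions that accept this combination) the returned sequence begins with the prompt ids, which is why the code slices at input_ids.shape[1] before decoding and then drops a single trailing stop token. A minimal sketch of just that trimming logic (the token ids are hypothetical stand-ins for tokenizer.eos_token_id and the <|eot_id|> id):

    import torch

    eos_id, eot_id = 2, 9                                   # hypothetical stop-token ids
    input_ids = torch.zeros((1, 10), dtype=torch.long)      # the spliced prompt ids
    generate_ids = torch.cat([input_ids, torch.tensor([[42, 7, eot_id]])], dim=1)

    generate_ids = generate_ids[:, input_ids.shape[1]:]     # drop the echoed prompt
    if generate_ids[0][-1] in (eos_id, eot_id):             # drop one trailing stop token
        generate_ids = generate_ids[:, :-1]
    assert generate_ids.tolist() == [[42, 7]]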