import argparse

import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw

from constants import *
from dataset import create_test_dataloader
from utils import image_to_tensor, tensor_to_image, tokenizer, vocab_size
from vision_language_model import VisionLanguageModel

model = VisionLanguageModel(
    n_embd=HIDDEN_DIM,
    vocab_size=vocab_size,
    img_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    num_heads=NUM_HEADS,
    num_blks_vit=NUM_LAYERS,  # Or a specific value for ViT layers
    num_blks_dec=NUM_LAYERS,  # Or a specific value for decoder layers
    emb_dropout=DROPOUT,
    blk_dropout=DROPOUT,
    max_context=CONTEXT_LENGTH,
    shared_embed_dim=SHARED_EMBED_DIM,
    lambda_contrastive=LAMBDA_CONTRASTIVE,
    lambda_regression=LAMBDA_REGRESSION,  # Pass the regression weight
).to(DEVICE)

MODEL_PATH = "model_regression_multi_first_100.pth"  # "model_regression_multi_16.pth"

if DEVICE == "cuda":
    model.load_state_dict(torch.load(MODEL_PATH, weights_only=True))
else:
    model.load_state_dict(
        torch.load(MODEL_PATH, weights_only=True, map_location=torch.device("cpu"))
    )
model.eval()
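# Optional sanity check (an addition, not part of the original pipeline): confirm
# the checkpoint actually populated the model by reporting its parameter count.
n_params = sum(p.numel() for p in model.parameters())
print(f"Loaded {MODEL_PATH}: {n_params / 1e6:.1f}M parameters")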
def generate_sample_from_image_text(
    model,
    image_path,
    prompt_label,
    tokenizer,
    device,
    max_new_tokens=70,
    temperature=0.8,
    top_k=10,
    output_path="generated_output.png",
):
    """
    Generates a prediction for an image and prompt text and saves it to a file.
    The generation loop is implemented *within* this function.

    Args:
        model: The trained VisionLanguageModel.
        image_path: Path to the input image.
        prompt_label: Text prompt/label to use.
        tokenizer: The tokenizer used for training.
        device: The computation device ('cuda' or 'cpu').
        max_new_tokens (int): Max tokens to generate after the prompt.
        temperature (float): Softmax temperature for sampling.
        top_k (int): K for top-k sampling (0 or None to disable).
        output_path (str): Path where the output image is saved.

    Returns:
        None. Saves the image with the prompt and generated output to a file.
    """
    model.eval()  # Set the model to evaluation mode
    try:
        with torch.no_grad():  # No need to track gradients during inference
            # --- 1. Prepare Initial Inputs ---
            # Load and preprocess the image
            image = Image.open(image_path)
            image_tensor = image_to_tensor(image).unsqueeze(0).to(device)  # Add batch dim

            # Tokenize the prompt. If training wrapped the prompt in special tokens,
            # restore them here (the original wrapper strings were lost).
            prompt_text = prompt_label
            prompt_tokens = tokenizer(prompt_text, return_tensors="pt", truncation=True, padding=False)
            prompt_ids = prompt_tokens.input_ids.to(device)
            prompt_attention_mask = prompt_tokens.attention_mask.to(device)

            B = 1  # We process one sample at a time

            print("--- Generating Sample (Manual Loop) ---")
            print(f"Original Label/Prompt Hint: {prompt_label}")
            print(f"Input Prompt Tokens Decoded: {prompt_text}")

            # --- 2. Pre-compute Image & Prompt Embeddings (Part of VLM Forward Logic) ---
            image_embeds_raw = model.vision_encoder(image_tensor)  # (1, N_img, C)
            image_embeds_decoder = model.multimodal_projector(image_embeds_raw)  # (1, N_img, C)
            prompt_embeds_decoder = model.decoder.token_embedding_table(prompt_ids)  # (1, T_prompt, C)

            # NOTE: the special-token string below is an assumption; the original
            # angle-bracket token was stripped from this file. It must match the
            # result-start token used during training.
            result_start_token_id = tokenizer.encode("<result>", add_special_tokens=False)[0]
            result_start_embed = model.decoder.token_embedding_table(
                torch.tensor([[result_start_token_id]], device=device)
            )  # Shape (1, 1, C)

            # The initial sequence fed to the decoder blocks: image + prompt + result-start
            current_embeds = torch.cat(
                [
                    image_embeds_decoder,
                    prompt_embeds_decoder,
                    result_start_embed,  # Embedding of the first expected output token
                ],
                dim=1,
            )

            generated_ids = []  # Store newly generated IDs

            # --- 3. Autoregressive Generation Loop ---
            for _ in range(max_new_tokens):
                T_current = current_embeds.shape[1]

                # Truncate if necessary (keep the most recent context)
                if T_current > model.decoder.max_context:
                    print(f"Warning: Truncating context from {T_current} to {model.decoder.max_context}")
                    current_embeds = current_embeds[:, -model.decoder.max_context:, :]
                    T_current = model.decoder.max_context

                # Positional embeddings for the current length
                pos = torch.arange(0, T_current, dtype=torch.long, device=device)
                pos = pos.clamp(max=model.decoder.max_context - 1)  # Clamp indices
                pos_emb = model.decoder.position_embedding_table(pos).unsqueeze(0)  # (1, T_current, C)
                x = current_embeds + pos_emb

                # Attention mask (all ones; the causal mask handles the future).
                # No padding mask is needed: we handle one sequence without padding.
                attention_mask = torch.ones(B, T_current, device=device, dtype=torch.long)

                # Pass through the decoder blocks
                for block in model.decoder.blocks:
                    # We assume each block's forward takes (x, attention_mask)
                    x = block(x, attention_mask=attention_mask)

                # Final layer norm and LM head for the *last* token prediction
                x = model.decoder.ln_f(x[:, -1:, :])  # (1, 1, C)
                logits = model.decoder.lm_head(x)  # (1, 1, V)
                logits = logits.squeeze(1)  # (1, V)

                # Sampling
                logits = logits / temperature
                if top_k is not None and top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("Inf")
                probs = F.softmax(logits, dim=-1)
                # idx_next = torch.multinomial(probs, num_samples=1)  # (1, 1)  # sample from distribution
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # deterministic (greedy)

                # Store the generated ID
                generated_ids.append(idx_next)

                # Stop if the EOS token is generated
                if idx_next.item() == tokenizer.eos_token_id:
                    print("EOS token generated.")
                    break

                # Prepare for the next iteration: append the new token's embedding
                next_token_embed = model.decoder.token_embedding_table(idx_next)  # (1, 1, C)
                current_embeds = torch.cat([current_embeds, next_token_embed], dim=1)

            # --- 4. Combine and Decode Results ---
            if generated_ids:
                generated_ids_tensor = torch.cat(generated_ids, dim=1)  # (1, T_generated)
                initial_target_ids = torch.tensor([[result_start_token_id]], device=device)
                full_generated_sequence_ids = torch.cat(
                    [prompt_ids, initial_target_ids, generated_ids_tensor], dim=1
                )
            else:
                full_generated_sequence_ids = prompt_ids  # Nothing was generated

            full_decoded_text = tokenizer.decode(full_generated_sequence_ids[0], skip_special_tokens=False)
            print(f"\nFull Generated Sequence (Manual Loop):\n{full_decoded_text}")

            # --- 5. Save visualization to file ---
            save_coords_visualization(
                image_tensor=image_tensor[0],  # Remove batch dim for visualization
                full_decoded_text=full_decoded_text,
                tokenizer=tokenizer,
                image_size=IMAGE_SIZE,  # Assumes IMAGE_SIZE is defined in constants
                num_bins=NUM_BINS,  # Assumes NUM_BINS is defined in constants
                output_path=output_path,
            )
            print(f"Visualization saved to: {output_path}")

    except Exception as e:
        print(f"An error occurred during sample generation: {e}")
        import traceback

        traceback.print_exc()
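# Illustrative call (image path and prompt taken from the CLI example at the bottom
# of this file; uncomment to run directly from Python):
# generate_sample_from_image_text(
#     model=model,
#     image_path="./data/test_images/image_1.png",
#     prompt_label="a red apple",
#     tokenizer=tokenizer,
#     device=DEVICE,
#     output_path="model_prediction.png",
# )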
def generate_sample_from_test_loader(
    model,
    test_loader,
    tokenizer,
    device,
    max_new_tokens=70,
    temperature=0.8,
    top_k=10,
    output_path="generated_output.png",
    TEST_BATCH=8,
    TEST_IDX=1,
):
    """
    Generates a prediction for one sample from the test loader and saves it to a file.
    The generation loop is implemented *within* this function.

    Args:
        model: The trained VisionLanguageModel.
        test_loader: DataLoader for the test set.
        tokenizer: The tokenizer used for training.
        device: The computation device ('cuda' or 'cpu').
        max_new_tokens (int): Max tokens to generate after the prompt.
        temperature (float): Softmax temperature for sampling.
        top_k (int): K for top-k sampling (0 or None to disable).
        output_path (str): Path where the output image is saved.
        TEST_BATCH (int): Number of batches to skip before taking a sample.
        TEST_IDX (int): Index of the sample within the selected batch.

    Returns:
        None. Saves the image with the prompt and generated output to a file.
    """
    if not test_loader or len(test_loader.dataset) == 0:
        print("Test loader is empty or not available.")
        return

    model.eval()  # Set the model to evaluation mode
    try:
        with torch.no_grad():  # No need to track gradients during inference
            # Skip TEST_BATCH batches, then take the next one
            my_iter = iter(test_loader)
            for i in range(TEST_BATCH):
                _ = next(my_iter)
            batch = next(my_iter)

            if batch is None:
                print("Test loader yielded an empty batch.")
                return
            if batch['image'].shape[0] == 0:
                print("Test loader yielded a batch with 0 items.")
                return

            # --- 1. Prepare Initial Inputs ---
            image_tensor = batch['image'][TEST_IDX:TEST_IDX + 1].to(device)  # (1, 3, H, W)
            prompt_ids = batch['prompt_ids'][TEST_IDX:TEST_IDX + 1].to(device)  # (1, T_prompt)
            prompt_attention_mask = batch['prompt_attention_mask'][TEST_IDX:TEST_IDX + 1].to(device)  # (1, T_prompt)
            label = batch['label'][TEST_IDX]

            B = 1  # We process one sample at a time

            print("--- Generating Sample (Manual Loop) ---")
            print(f"Original Label/Prompt Hint: {label}")
            prompt_text = tokenizer.decode(prompt_ids[0], skip_special_tokens=False)
            print(f"Input Prompt Tokens Decoded: {prompt_text}")

            # --- 2. Pre-compute Image & Prompt Embeddings (Part of VLM Forward Logic) ---
            image_embeds_raw = model.vision_encoder(image_tensor)  # (1, N_img, C)
            image_embeds_decoder = model.multimodal_projector(image_embeds_raw)  # (1, N_img, C)
            prompt_embeds_decoder = model.decoder.token_embedding_table(prompt_ids)  # (1, T_prompt, C)

            # NOTE: assumed result-start token, as above; the original string was lost
            # and must match the token used during training.
            result_start_token_id = tokenizer.encode("<result>", add_special_tokens=False)[0]
            result_start_embed = model.decoder.token_embedding_table(
                torch.tensor([[result_start_token_id]], device=device)
            )  # Shape (1, 1, C)

            # The initial sequence fed to the decoder blocks: image + prompt + result-start
            current_embeds = torch.cat(
                [
                    image_embeds_decoder,
                    prompt_embeds_decoder,
                    result_start_embed,  # Embedding of the first expected output token
                ],
                dim=1,
            )
            # current_embeds = torch.cat([image_embeds_decoder, prompt_embeds_decoder], dim=1)  # (1, T_initial, C)

            generated_ids = []  # Store newly generated IDs
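            # Optional shape sanity check (an addition, not part of the original loop):
            # catches projector/embedding dimension mismatches before decoding starts.
            assert current_embeds.dim() == 3 and current_embeds.shape[0] == B, (
                f"Unexpected decoder input shape: {tuple(current_embeds.shape)}"
            )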
            # --- 3. Autoregressive Generation Loop ---
            for _ in range(max_new_tokens):
                T_current = current_embeds.shape[1]

                # Truncate if necessary (keep the most recent context)
                if T_current > model.decoder.max_context:
                    print(f"Warning: Truncating context from {T_current} to {model.decoder.max_context}")
                    current_embeds = current_embeds[:, -model.decoder.max_context:, :]
                    T_current = model.decoder.max_context

                # Positional embeddings for the current length
                pos = torch.arange(0, T_current, dtype=torch.long, device=device)
                pos = pos.clamp(max=model.decoder.max_context - 1)  # Clamp indices
                pos_emb = model.decoder.position_embedding_table(pos).unsqueeze(0)  # (1, T_current, C)
                x = current_embeds + pos_emb

                # Attention mask (all ones; the causal mask handles the future).
                # No padding mask is needed: we handle one sequence without padding.
                attention_mask = torch.ones(B, T_current, device=device, dtype=torch.long)

                # Pass through the decoder blocks
                for block in model.decoder.blocks:
                    # We assume each block's forward takes (x, attention_mask)
                    x = block(x, attention_mask=attention_mask)

                # Final layer norm and LM head for the *last* token prediction
                x = model.decoder.ln_f(x[:, -1:, :])  # (1, 1, C)
                logits = model.decoder.lm_head(x)  # (1, 1, V)
                logits = logits.squeeze(1)  # (1, V)

                # Sampling
                logits = logits / temperature
                if top_k is not None and top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("Inf")
                probs = F.softmax(logits, dim=-1)
                # idx_next = torch.multinomial(probs, num_samples=1)  # (1, 1)  # sample from distribution
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # deterministic (greedy)

                # Store the generated ID
                generated_ids.append(idx_next)

                # Stop if the EOS token is generated
                if idx_next.item() == tokenizer.eos_token_id:
                    print("EOS token generated.")
                    break

                # Prepare for the next iteration: append the new token's embedding
                next_token_embed = model.decoder.token_embedding_table(idx_next)  # (1, 1, C)
                current_embeds = torch.cat([current_embeds, next_token_embed], dim=1)

            # --- 4. Combine and Decode Results ---
            if generated_ids:
                generated_ids_tensor = torch.cat(generated_ids, dim=1)  # (1, T_generated)
                initial_target_ids = torch.tensor([[result_start_token_id]], device=device)
                full_generated_sequence_ids = torch.cat(
                    [prompt_ids, initial_target_ids, generated_ids_tensor], dim=1
                )
            else:
                full_generated_sequence_ids = prompt_ids  # Nothing was generated

            full_decoded_text = tokenizer.decode(full_generated_sequence_ids[0], skip_special_tokens=False)
            print(f"\nFull Generated Sequence (Manual Loop):\n{full_decoded_text}")

            # --- 5. Save visualization to file ---
            save_coords_visualization(
                image_tensor=image_tensor[0],  # Remove batch dim for visualization
                full_decoded_text=full_decoded_text,
                tokenizer=tokenizer,
                image_size=IMAGE_SIZE,  # Assumes IMAGE_SIZE is defined in constants
                num_bins=NUM_BINS,  # Assumes NUM_BINS is defined in constants
                output_path=output_path,
            )
            print(f"Visualization saved to: {output_path}")

    except StopIteration:
        print("Test loader is exhausted.")
    except Exception as e:
        print(f"An error occurred during sample generation: {e}")
        import traceback

        traceback.print_exc()
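# Illustrative call (hypothetical indices: TEST_BATCH batches are skipped first, so
# TEST_IDX must be smaller than the batch size; uncomment to run directly):
# test_loader = create_test_dataloader(batch_size=2, num_workers=0)
# generate_sample_from_test_loader(
#     model=model, test_loader=test_loader, tokenizer=tokenizer, device=DEVICE,
#     TEST_BATCH=0, TEST_IDX=0, output_path="model_prediction.png",
# )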
""" coords = [] try: # Basic parsing - look for the pattern x_start_token = "" x_end_token = "" y_start_token = "" y_end_token = "" result_end_token = "" # Find where the actual results start try: start_index = text.index("") + len("") except ValueError: print("Warning: not found in generated text.") return None # Find where results end try: end_index = text.index(result_end_token, start_index) except ValueError: end_index = len(text) # Use end of string if is missing print(f"Warning: {result_end_token} not found. Parsing until end of string.") current_pos = start_index while current_pos < end_index: # Find next X coordinate x_start_idx = text.find(x_start_token, current_pos) if x_start_idx == -1 or x_start_idx >= end_index: break # No more x points found x_start_idx += len(x_start_token) x_end_idx = text.find(x_end_token, x_start_idx) if x_end_idx == -1 or x_end_idx >= end_index: break # Malformed x_token_str = text[x_start_idx:x_end_idx].strip() # Find next Y coordinate (must follow X) y_start_idx = text.find(y_start_token, x_end_idx) if y_start_idx == -1 or y_start_idx >= end_index: break # No corresponding y point y_start_idx += len(y_start_token) y_end_idx = text.find(y_end_token, y_start_idx) if y_end_idx == -1 or y_end_idx >= end_index: break # Malformed y_token_str = text[y_start_idx:y_end_idx].strip() x_token_str = x_token_str[:-1] y_token_str = y_token_str[:-1] # Convert token strings to bin numbers try: x_bin = int(x_token_str.split("_")[-1]) y_bin = int(y_token_str.split("_")[-1]) if 0 <= x_bin < num_bins and 0 <= y_bin < num_bins: coords.append((x_bin, y_bin)) else: print(f"Warning: Parsed bin indices out of range ({x_bin}, {y_bin}). Skipping.") except (ValueError, IndexError): print(f"Warning: Could not parse bins from tokens '{x_token_str}', '{y_token_str}'. 
Skipping.") # Move search position past the found Y token current_pos = y_end_idx + len(y_end_token) return coords if coords else None except Exception as e: print(f"Error during coordinate parsing: {e}") return None def save_coords_visualization(image_tensor, full_decoded_text, tokenizer, image_size, num_bins, output_path): """Parses coords, draws them on the image, and saves to a file.""" parsed_bins = parse_coordinate_tokens(full_decoded_text, tokenizer, num_bins) # Convert tensor to PIL image for drawing try: pil_image = tensor_to_image(image_tensor.cpu()) # Ensure tensor is on CPU except Exception as e: print(f"Error converting tensor to image: {e}") # Create a placeholder image if conversion fails pil_image = Image.new('RGB', (image_size, image_size), color='white') draw = ImageDraw.Draw(pil_image) draw.text((10, 10), "Image conversion failed", fill="black") pil_image.save(output_path) return draw = ImageDraw.Draw(pil_image) radius = 5 # Radius of the drawn point if parsed_bins: print(f"\nParsed Coordinate Bins: {parsed_bins}") bin_size_pixels = image_size / num_bins for x_bin, y_bin in parsed_bins: # Calculate center of the bin in pixels center_x = (x_bin + 0.5) * bin_size_pixels center_y = (y_bin + 0.5) * bin_size_pixels # Draw a circle bbox = [center_x - radius, center_y - radius, center_x + radius, center_y + radius] draw.ellipse(bbox, outline="red", width=3) # Optional: Draw bin boundaries for debugging # draw.rectangle([x_bin*bin_size_pixels, y_bin*bin_size_pixels, (x_bin+1)*bin_size_pixels, (y_bin+1)*bin_size_pixels], outline="blue", width=1) # Add a text label with the coordinates at the top of the image coord_text = f"Generated Point(s): {parsed_bins}" draw.text((10, 10), coord_text, fill="red") else: print("\nCould not parse valid coordinates from the generated text.") # Add a text label indicating no coordinates were found draw.text((10, 10), "No Coordinates Parsed", fill="red") # Save the image to file pil_image.save(output_path) import argparse # --- Example Usage --- # python infer.py --image ./data/test_images/image_1.png --prompt "a red apple" if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--image', type=str, help='Path to input image') parser.add_argument('--prompt', type=str, help='Prompt label for generation') args = parser.parse_args() if args.image and args.prompt: # Use image and prompt based generation if 'model' in locals() and 'tokenizer' in locals(): generate_sample_from_image_text( model=model, image_path=args.image, prompt_label=args.prompt, tokenizer=tokenizer, device=DEVICE, output_path="model_prediction.png" ) else: print("Please ensure 'model' and 'tokenizer' are loaded before running generation.") else: # Use test loader based generation if 'model' in locals() and 'test_loader' in locals() and 'tokenizer' in locals(): test_loader = create_test_dataloader(batch_size=2, num_workers=0) generate_sample_from_test_loader( model=model, test_loader=test_loader, tokenizer=tokenizer, device=DEVICE, output_path="model_prediction.png" ) else: print("Please ensure 'model', 'test_loader', and 'tokenizer' are loaded before running generation.")