| | import torch
|
| | import argparse
|
| | from __init__ import QuillanSOTA, Config
|
| | import os
|
| |
|
def load_model(checkpoint_path: str, device: str = 'cuda'):
    """Build a QuillanSOTA model and load weights from *checkpoint_path*.

    Falls back to randomly initialized weights (with a printed warning) when
    the checkpoint file does not exist.  The model is moved to *device* and
    switched to eval mode before being returned.

    Args:
        checkpoint_path: Path to a ``.pt`` file containing a state dict.
        device: Target device string, e.g. ``'cuda'`` or ``'cpu'``.

    Returns:
        The QuillanSOTA model, on *device*, in eval mode.
    """
    print(f"Loading model from {checkpoint_path}...")

    config = Config()
    model = QuillanSOTA(config)

    if os.path.exists(checkpoint_path):
        # SECURITY: torch.load uses pickle, which can execute arbitrary code
        # from an untrusted checkpoint.  weights_only=True restricts loading
        # to tensors/containers; a plain state dict loads fine either way.
        state_dict = torch.load(checkpoint_path, map_location=device,
                                weights_only=True)
        model.load_state_dict(state_dict)
        print("Checkpoint loaded successfully.")
    else:
        print(f"Warning: Checkpoint {checkpoint_path} not found. Using random weights.")

    model.to(device)
    model.eval()
    return model
|
| |
|
def generate_text(model, prompt_text, max_new_tokens=50, temperature=0.7, device='cuda'):
    """Run one forward pass with mock multimodal inputs and sample one token.

    Demonstration only: *prompt_text* and *max_new_tokens* are accepted for
    interface compatibility, but the model is driven with random mock tensors
    and a single token id is sampled from the resulting text logits.

    Args:
        model: Callable taking ``(text, image, audio, video)`` tensors and
            returning either a logits tensor or a dict with a ``'text'`` or
            ``'logits'`` entry.
        prompt_text: Prompt string (only echoed; no tokenizer is wired up).
        max_new_tokens: Unused in this mock demo; kept for API compatibility.
        temperature: Softmax temperature applied before sampling.
        device: Device on which the mock inputs are created.

    Returns:
        A list with one sampled token id, or ``[42]`` as a placeholder when
        the model produced no usable text logits.
    """
    print(f"Generating text with prompt: '{prompt_text}'")

    batch_size = 1
    seq_len = 10

    # BUGFIX: the original set grid_size = 256 and derived
    # img_size = 256 * 16 = 4096, so the mock video tensor was
    # (1, 3, 16, 4096, 4096) ≈ 3.2 GB.  A 256-pixel mock image (a 16x16
    # grid of 16-pixel patches) is the tractable, intended configuration.
    # NOTE(review): confirm 256 matches the model's expected input size.
    img_size = 256

    text_input = torch.randint(0, 1000, (batch_size, seq_len)).to(device)
    img_input = torch.randn(batch_size, 3, img_size, img_size).to(device)
    aud_input = torch.randn(batch_size, 1, 1024).to(device)
    vid_input = torch.randn(batch_size, 3, 16, img_size, img_size).to(device)

    with torch.no_grad():
        outputs = model(text_input, img_input, aud_input, vid_input)

    # Models may return a dict of per-modality outputs or a bare tensor.
    if isinstance(outputs, dict):
        text_logits = outputs.get('text', outputs.get('logits', None))
    else:
        text_logits = outputs

    if text_logits is not None and isinstance(text_logits, torch.Tensor):
        if text_logits.dim() == 3:
            # (batch, seq, vocab): sample from the last position of batch 0.
            last_logits = text_logits[0, -1, :]
            scaled_logits = last_logits / temperature
            probs = torch.softmax(scaled_logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            return [next_token.item()]
        else:
            # Non-3D tensor: fall back to a greedy argmax over all entries.
            return [text_logits.argmax().item()]
    else:
        # No recognizable text logits; return a placeholder token id.
        return [42]
|
| |
|
| | if __name__ == "__main__":
|
| | parser = argparse.ArgumentParser(description="Quillan Inference")
|
| | parser.add_argument("--checkpoint", type=str, default="checkpoints/quillan_final.pt", help="Path to model checkpoint")
|
| | parser.add_argument("--prompt", type=str, default="Hello, world!", help="Input prompt")
|
| | parser.add_argument("--max_tokens", type=int, default=50, help="Max new tokens to generate")
|
| | parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
|
| |
|
| | args = parser.parse_args()
|
| |
|
| | device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| | model = load_model(args.checkpoint, device)
|
| |
|
| |
|
| | output_ids = generate_text(model, args.prompt, args.max_tokens, args.temperature, device)
|
| | print(f"Generated token IDs: {output_ids}")
|
| | print("Note: This is a demonstration with mock multimodal inputs.")
|
| | print("For actual text generation, you'll need to implement proper tokenization and handle the multimodal outputs.")
|
| |
|