"""Inference script: generate sentence-level continuations with a Two-Tower
Diffusion LCM checkpoint, using SONAR to encode prompts and decode outputs."""

import torch

from lcm.datasets.batch import EmbeddingsBatch
from lcm.inference.two_tower_diffusion_lcm import (
    DiffusionLCMGeneratorOptions,
    TwoTowerDiffusionLCMGenerator,
)
from lcm.utils.card_utils import load_model_from_card
from sonar.inference_pipelines.text import (
    EmbeddingToTextModelPipeline,
    TextToEmbeddingModelPipeline,
)


def main():
    # Setup device
    device = torch.device("cuda:0")

    # Load model
    model_card = "./_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/model_card.yaml"
    model = load_model_from_card(model_card, device=device, dtype=torch.float32)

    # Setup generator options
    options = DiffusionLCMGeneratorOptions(
        guidance_scale=3.0,  # Increased from 1.0 to make generation more focused
        guidance_rescale=0.0,
        ddim_eta=0.0,
        initial_noise_scale=1.0,
        inference_timesteps=100,
        clip_noise=100,
        thresholding=False,
        dynamic_thresholding_ratio=0.995,
        sample_max_value=6.0,
    )

    # Create generator
    generator = TwoTowerDiffusionLCMGenerator(model=model, options=options)

    # Setup SONAR text encoder/decoder pipelines
    text_encoder = TextToEmbeddingModelPipeline(
        encoder="text_sonar_basic_encoder",
        tokenizer="text_sonar_basic_encoder",
        device=device,
        dtype=torch.float32,
    )
    text_decoder = EmbeddingToTextModelPipeline(
        decoder="text_sonar_basic_decoder",
        tokenizer="text_sonar_basic_decoder",
        device=device,
        dtype=torch.float32,
    )

    # Get EOS embedding, used to detect end-of-generation
    eos_text = "End of text."
    eos_embedding = text_encoder.predict([eos_text], source_lang="eng_Latn")
    generator.eos_vec = eos_embedding.squeeze(0)  # Remove batch dimension

    # Example prompts (each inner list is a multi-sentence prompt)
    prompts = [
        [
            "Petals fall in the wind.",
            "They swirl and dance and float away.",
            "Then all becomes still again.",
        ],
        [
            "Like wisps of light, the moonlight meets the rolling brook.",
            "Upon seeing it glimmer, she turns and smiles.",
            "Her friend was glad they could share this tranquil moment.",
        ],
        [
            "Tokyo is the modern capital of Japan.",
            "Although it is currently the case, historically, cities such as Kyoto and Nara have also served as the capital.",
        ],
    ]

    print("\nProcessing prompts:")
    all_prompt_embeddings = []
    max_sentences = max(len(prompt) for prompt in prompts)

    # Process each multi-sentence prompt
    for i, prompt_sentences in enumerate(prompts):
        print(f"\nPrompt {i + 1}:")
        for sentence in prompt_sentences:
            print(f"  {sentence}")

        # Encode each sentence separately: one SONAR embedding per sentence
        sentence_embeddings = text_encoder.predict(prompt_sentences, source_lang="eng_Latn")
        print(f"  Sentence embeddings shape: {sentence_embeddings.shape}")

        # Pad with zero vectors up to max_sentences so prompts can be stacked into one batch
        if len(prompt_sentences) < max_sentences:
            padding = torch.zeros(
                (max_sentences - len(prompt_sentences), sentence_embeddings.shape[1]),
                device=device,
                dtype=sentence_embeddings.dtype,
            )
            sentence_embeddings = torch.cat([sentence_embeddings, padding], dim=0)

        all_prompt_embeddings.append(sentence_embeddings)

    # Stack all prompts into a batch of shape (n_prompts, max_sentences, embed_dim)
    prompt_embeddings = torch.stack(all_prompt_embeddings)
    print("\nFinal batch shape:", prompt_embeddings.shape)

    batch = EmbeddingsBatch(prompt_embeddings, None)

    # Generate continuation embeddings
    output = generator(
        batch_input=batch,
        max_gen_len=24,
        min_gen_len=10,
        temperature=0.1,
    )

    # Decode generated embeddings back to text
    print("\nGenerated outputs:")
    for i, hypotheses in enumerate(output.hypotheses):
        print(f"\nOutput for Prompt {i + 1}:")
        for hypothesis in hypotheses:
            generated_text = text_decoder.predict(
                hypothesis.seq,
                target_lang="eng_Latn",
                max_seq_len=256,
                temperature=1.0,
            )
            print("Generated text:", generated_text)


if __name__ == "__main__":
    main()