File size: 4,182 Bytes
3d79eb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
from lcm.utils.card_utils import load_model_from_card
from lcm.inference.two_tower_diffusion_lcm import TwoTowerDiffusionLCMGenerator, DiffusionLCMGeneratorOptions
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline, EmbeddingToTextModelPipeline
from lcm.datasets.batch import EmbeddingsBatch

_DEFAULT_MODEL_CARD = (
    "./_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/model_card.yaml"
)


def main(model_card=_DEFAULT_MODEL_CARD):
    """Run two-tower diffusion LCM generation on a few example prompts.

    Each prompt is a list of sentences. Every sentence is encoded with the
    SONAR text encoder. The resulting sequence of sentence embeddings is fed
    to ``TwoTowerDiffusionLCMGenerator`` as context. Each generated embedding
    sequence is decoded back to English text with the SONAR decoder and
    printed.

    Args:
        model_card: Path to the LCM checkpoint's ``model_card.yaml``.
            Defaults to the step-250000 checkpoint this script was written for.
    """
    # Prefer the first GPU, but fall back to CPU so the script still runs
    # (slowly) on machines without CUDA instead of failing at model load.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load model
    model = load_model_from_card(model_card, device=device, dtype=torch.float32)

    # Diffusion sampling options for the generator
    options = DiffusionLCMGeneratorOptions(
        guidance_scale=3.0,  # Increased from 1.0 to make generation more focused
        guidance_rescale=0.0,
        ddim_eta=0.0,
        initial_noise_scale=1.0,
        inference_timesteps=100,
        clip_noise=100,
        thresholding=False,
        dynamic_thresholding_ratio=0.995,
        sample_max_value=6.0,
    )
    generator = TwoTowerDiffusionLCMGenerator(model=model, options=options)

    # SONAR pipelines: text -> sentence embeddings and back
    text_encoder = TextToEmbeddingModelPipeline(
        encoder="text_sonar_basic_encoder",
        tokenizer="text_sonar_basic_encoder",
        device=device,
        dtype=torch.float32,
    )
    text_decoder = EmbeddingToTextModelPipeline(
        decoder="text_sonar_basic_decoder",
        tokenizer="text_sonar_basic_decoder",
        device=device,
        dtype=torch.float32,
    )

    # Attach the embedding of "End of text." to the generator as `eos_vec`
    # (presumably its end-of-document marker -- confirm in the LCM generator).
    # predict() returns one row per input sentence, so squeeze(0) drops that
    # leading dimension.
    eos_embedding = text_encoder.predict(["End of text."], source_lang="eng_Latn")
    generator.eos_vec = eos_embedding.squeeze(0)

    # Example prompts (each inner list is a multi-sentence prompt)
    prompts = [
        ["Petals fall in the wind.", "They swirl and dance and float away.", "Then all becomes still again."],
        ["Like whisps of light, the moonlight meets the rolling brook.", "Upon seeing it glimmer, she turns and smiles.", "Her friend was glad they could share this tranquil moment."],
        ["Tokyo is the modern capital of Japan.", "Although it is currently the case, historically, cities such as Kyoto and Nara have also served as the capital."]
    ]

    print("\nProcessing prompts:")
    encoded_prompts = []
    for i, prompt_sentences in enumerate(prompts):
        print(f"\nPrompt {i+1}:")
        for sentence in prompt_sentences:
            print(f"  {sentence}")

        # One embedding row per sentence: (num_sentences, embed_dim).
        sentence_embeddings = text_encoder.predict(prompt_sentences, source_lang="eng_Latn")
        print(f"  Sentence embeddings shape: {sentence_embeddings.shape}")
        encoded_prompts.append(sentence_embeddings)

    # Why one batch per prompt: the prompts have different sentence counts.
    # Zero-padding them into one stacked batch and wrapping it as
    # EmbeddingsBatch(..., None) gives the generator no padding mask.
    # Nothing would mark the zero rows as padding, so they would be consumed
    # as real context sentences. A batch of one needs no padding at all.
    # Re-batching would require building a proper padding mask.
    print("\nGenerated outputs:")
    for i, sentence_embeddings in enumerate(encoded_prompts):
        print(f"\nOutput for Prompt {i+1}:")
        batch = EmbeddingsBatch(sentence_embeddings.unsqueeze(0), None)

        output = generator(
            batch_input=batch,
            max_gen_len=24,
            min_gen_len=10,
            temperature=0.1,
        )

        # Decode every hypothesis' embedding sequence back to text.
        for hypotheses in output.hypotheses:
            for hypothesis in hypotheses:
                generated_text = text_decoder.predict(
                    hypothesis.seq,
                    target_lang="eng_Latn",
                    max_seq_len=256,
                    temperature=1.0,
                )
                print("Generated text:", generated_text)

if __name__ == "__main__":
    # Run the generation demo only when executed as a script, not on import.
    main()