import torch

from lcm.utils.card_utils import load_model_from_card
from lcm.inference.two_tower_diffusion_lcm import TwoTowerDiffusionLCMGenerator, DiffusionLCMGeneratorOptions
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline, EmbeddingToTextModelPipeline
from lcm.datasets.batch import EmbeddingsBatch


def main():
    device = torch.device("cuda:0")
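
    # Load the trained Two-Tower Diffusion LCM from the step-250000 checkpoint's
    # model card.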
    model_card = "./_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/model_card.yaml"
    model = load_model_from_card(model_card, device=device, dtype=torch.float32)
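
    # Diffusion sampling options: classifier-free guidance at scale 3.0 and
    # deterministic DDIM sampling (eta=0) over 100 inference timesteps. The
    # thresholding-related values are presumably inert while thresholding=False.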
    options = DiffusionLCMGeneratorOptions(
        guidance_scale=3.0,
        guidance_rescale=0.0,
        ddim_eta=0.0,
        initial_noise_scale=1.0,
        inference_timesteps=100,
        clip_noise=100,
        thresholding=False,
        dynamic_thresholding_ratio=0.995,
        sample_max_value=6.0,
    )

    generator = TwoTowerDiffusionLCMGenerator(model=model, options=options)
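
    # SONAR text pipelines: the encoder maps each sentence to a fixed-size
    # concept embedding, and the decoder maps embeddings back to text.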
    text_encoder = TextToEmbeddingModelPipeline(
        encoder="text_sonar_basic_encoder",
        tokenizer="text_sonar_basic_encoder",
        device=device,
        dtype=torch.float32,
    )

    # The decoder pipeline reuses the encoder's tokenizer card, matching the
    # SONAR usage examples.
    text_decoder = EmbeddingToTextModelPipeline(
        decoder="text_sonar_basic_decoder",
        tokenizer="text_sonar_basic_encoder",
        device=device,
        dtype=torch.float32,
    )
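
    # Give the generator an explicit end-of-text concept vector; generation
    # presumably stops once the model emits an embedding close to it.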
    eos_text = "End of text."
    eos_embedding = text_encoder.predict([eos_text], source_lang="eng_Latn")
    generator.eos_vec = eos_embedding.squeeze(0)
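
    # Each prompt is a list of sentences; every sentence becomes one concept
    # embedding in the LCM's input sequence.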
    prompts = [
        ["Petals fall in the wind.", "They swirl and dance and float away.", "Then all becomes still again."],
        ["Like wisps of light, the moonlight meets the rolling brook.", "Upon seeing it glimmer, she turns and smiles.", "Her friend was glad they could share this tranquil moment."],
        ["Tokyo is the modern capital of Japan.", "Although it is currently the case, historically, cities such as Kyoto and Nara have also served as the capital."],
    ]

    print("\nProcessing prompts:")
    all_prompt_embeddings = []
    max_sentences = max(len(prompt) for prompt in prompts)

    for i, prompt_sentences in enumerate(prompts):
        print(f"\nPrompt {i+1}:")
        for sentence in prompt_sentences:
            print(f"  {sentence}")

        sentence_embeddings = text_encoder.predict(prompt_sentences, source_lang="eng_Latn")
        print(f"  Sentence embeddings shape: {sentence_embeddings.shape}")
        if len(prompt_sentences) < max_sentences:
            padding = torch.zeros(
                (max_sentences - len(prompt_sentences), sentence_embeddings.shape[1]),
                device=device,
                dtype=sentence_embeddings.dtype,
            )
            sentence_embeddings = torch.cat([sentence_embeddings, padding], dim=0)

        all_prompt_embeddings.append(sentence_embeddings)

    prompt_embeddings = torch.stack(all_prompt_embeddings)
    print("\nFinal batch shape:", prompt_embeddings.shape)
    batch = EmbeddingsBatch(prompt_embeddings, None)
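
    # Generate between 10 and 24 new concepts per prompt at low temperature.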
    output = generator(
        batch_input=batch,
        max_gen_len=24,
        min_gen_len=10,
        temperature=0.1,
    )
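
    # Decode each generated embedding sequence back to text with SONAR.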
    print("\nGenerated outputs:")
    for i, hypotheses in enumerate(output.hypotheses):
        print(f"\nOutput for Prompt {i+1}:")
        for hypothesis in hypotheses:
            generated_text = text_decoder.predict(
                hypothesis.seq,
                target_lang="eng_Latn",
                max_seq_len=256,
                temperature=1.0,
            )
            print("Generated text:", generated_text)


if __name__ == "__main__":
    main()