# LexaLCM_Pre0/scripts/run_inference.py
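"""Run inference with the LexaLCM two-tower diffusion LCM.

Encodes multi-sentence prompts into SONAR sentence embeddings, generates
continuation embeddings with the diffusion generator, and decodes them back
into text. Usage (assuming the checkpoint path below exists):

    python scripts/run_inference.py
"""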
import torch

from lcm.datasets.batch import EmbeddingsBatch
from lcm.inference.two_tower_diffusion_lcm import (
    DiffusionLCMGeneratorOptions,
    TwoTowerDiffusionLCMGenerator,
)
from lcm.utils.card_utils import load_model_from_card
from sonar.inference_pipelines.text import (
    EmbeddingToTextModelPipeline,
    TextToEmbeddingModelPipeline,
)


def main():
    # Setup device
    device = torch.device("cuda:0")

    # Load model from its model card
    model_card = "./_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/model_card.yaml"
    model = load_model_from_card(model_card, device=device, dtype=torch.float32)

    # Setup generator options
    options = DiffusionLCMGeneratorOptions(
        guidance_scale=3.0,  # increased from 1.0 to make generation more focused
        guidance_rescale=0.0,
        ddim_eta=0.0,
        initial_noise_scale=1.0,
        inference_timesteps=100,
        clip_noise=100,
        thresholding=False,
        dynamic_thresholding_ratio=0.995,
        sample_max_value=6.0,
    )
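    # These are standard diffusion-sampling knobs: guidance_scale sets the
    # classifier-free guidance strength, ddim_eta the DDIM stochasticity, and
    # inference_timesteps the number of denoising steps. (Interpretation
    # assumes the usual diffusion conventions; see DiffusionLCMGeneratorOptions
    # for the authoritative meanings.)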

    # Create generator
    generator = TwoTowerDiffusionLCMGenerator(model=model, options=options)

    # Set up the SONAR text encoder/decoder pipelines
    text_encoder = TextToEmbeddingModelPipeline(
        encoder="text_sonar_basic_encoder",
        tokenizer="text_sonar_basic_encoder",
        device=device,
        dtype=torch.float32,
    )
    text_decoder = EmbeddingToTextModelPipeline(
        decoder="text_sonar_basic_decoder",
        tokenizer="text_sonar_basic_decoder",
        device=device,
        dtype=torch.float32,
    )

    # Get EOS embedding
    eos_text = "End of text."
    eos_embedding = text_encoder.predict([eos_text], source_lang="eng_Latn")
    generator.eos_vec = eos_embedding.squeeze(0)  # remove the batch dimension
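    # Assumption: the generator compares each newly generated embedding
    # against eos_vec to decide when to stop, so this "End of text." sentence
    # serves as the stop signal in embedding space.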

    # Example prompts (each inner list is a multi-sentence prompt)
    prompts = [
        [
            "Petals fall in the wind.",
            "They swirl and dance and float away.",
            "Then all becomes still again.",
        ],
        [
            "Like wisps of light, the moonlight meets the rolling brook.",
            "Upon seeing it glimmer, she turns and smiles.",
            "Her friend was glad they could share this tranquil moment.",
        ],
        [
            "Tokyo is the modern capital of Japan.",
            "Although it is currently the case, historically, cities such as Kyoto and Nara have also served as the capital.",
        ],
    ]
print("\nProcessing prompts:")
all_prompt_embeddings = []
max_sentences = max(len(prompt) for prompt in prompts)

    # Process each multi-sentence prompt
    for i, prompt_sentences in enumerate(prompts):
        print(f"\nPrompt {i+1}:")
        for sentence in prompt_sentences:
            print(f" {sentence}")

        # Encode each sentence separately into a (num_sentences, embed_dim) tensor
        sentence_embeddings = text_encoder.predict(prompt_sentences, source_lang="eng_Latn")
        print(f" Sentence embeddings shape: {sentence_embeddings.shape}")

        # Zero-pad shorter prompts to max_sentences so they stack into one batch
        if len(prompt_sentences) < max_sentences:
            padding = torch.zeros(
                (max_sentences - len(prompt_sentences), sentence_embeddings.shape[1]),
                device=device,
                dtype=sentence_embeddings.dtype,
            )
            sentence_embeddings = torch.cat([sentence_embeddings, padding], dim=0)
        all_prompt_embeddings.append(sentence_embeddings)

    # Stack all prompts into a (batch, max_sentences, embed_dim) tensor
    prompt_embeddings = torch.stack(all_prompt_embeddings)
    print("\nFinal batch shape:", prompt_embeddings.shape)
    batch = EmbeddingsBatch(prompt_embeddings, None)
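    # NOTE: the second argument of EmbeddingsBatch is the padding mask. With
    # None, the zero rows padded into shorter prompts are treated as real
    # content. A sketch of masking the padding instead (untested assumption,
    # using fairseq2's PaddingMask, which the lcm package builds on):
    #
    #     from fairseq2.nn.padding import PaddingMask
    #     seq_lens = torch.tensor([len(p) for p in prompts], device=device)
    #     batch = EmbeddingsBatch(prompt_embeddings, PaddingMask(seq_lens, max_sentences))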

    # Generate continuation embeddings from the prompt batch
    output = generator(
        batch_input=batch,
        max_gen_len=24,
        min_gen_len=10,
        temperature=0.1,
    )
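    # Generation lengths above count concepts (sentence embeddings), not
    # tokens: the LCM emits one SONAR embedding per step. Assumption:
    # output.hypotheses holds one list of hypotheses per prompt, each with its
    # generated embedding sequence in .seq (as the loop below consumes it).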

    # Decode generated embeddings
    print("\nGenerated outputs:")
    for i, hypotheses in enumerate(output.hypotheses):
        print(f"\nOutput for Prompt {i+1}:")
        for hypothesis in hypotheses:
            generated_text = text_decoder.predict(
                hypothesis.seq,
                target_lang="eng_Latn",
                max_seq_len=256,
                temperature=1.0,
            )
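            # predict returns one decoded string per input embedding, so
            # generated_text is a list of sentences.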
print("Generated text:", generated_text)


if __name__ == "__main__":
    main()