|
|
|
|
|
""" |
|
|
Example usage of Braille256-v1 model. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
from braille256_model import Braille256Model, Braille256Config |
|
|
from braille256_tokenizer import Braille256Tokenizer |
|
|
|
|
|
|
|
|
def main():
    """Load the Braille256-v1 checkpoint and print a few sample generations.

    Reads the config, weights, and tokenizer from the current directory,
    then samples a continuation for each Braille prompt and prints it.
    """
    checkpoint = "."

    # Pull the three artifacts from the local checkpoint directory.
    config = Braille256Config.from_pretrained(checkpoint)
    model = Braille256Model.from_pretrained(checkpoint, config=config)
    tokenizer = Braille256Tokenizer.from_pretrained(checkpoint)

    model.eval()  # inference mode (e.g. disables dropout)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model loaded: {n_params:,} parameters")

    # (Braille cells, English gloss) pairs used as generation prompts.
    prompts = [
        ("⠞⠓⠑", "the"),
        ("⠁⠝⠙", "and"),
        ("⠋⠕⠗", "for"),
    ]

    print("\nGeneration examples:")
    print("-" * 50)

    for braille_prompt, english in prompts:
        encoded = tokenizer.encode(braille_prompt, add_special_tokens=True)
        batch = torch.tensor([encoded])  # shape: (1, seq_len)

        # No gradients are needed for sampling.
        with torch.no_grad():
            sampled = model.generate(batch, max_length=50, temperature=0.7)

        generated = tokenizer.decode(sampled[0].tolist())
        print(f"Prompt: {english} ({braille_prompt})")
        print(f"Generated: {generated[:60]}...")
        print()
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|