import warnings

import tiktoken
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from ssllm_hf import SSLLMForCausalLM, SSLLMConfig
|
|
| |
# Build the model from the configuration published in the Hub repository.
config = SSLLMConfig.from_pretrained('sausheong/ssllm_hf')
model = SSLLMForCausalLM(config)

# Download the safetensors checkpoint and copy its tensors into the model.
model_path = hf_hub_download(repo_id='sausheong/ssllm_hf', filename='model.safetensors')
state_dict = load_file(model_path)
# strict=False tolerates key mismatches between checkpoint and model; report
# them instead of discarding the result, because a silently skipped key
# leaves that parameter at its freshly initialised value.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    warnings.warn(
        "Checkpoint keys do not fully match the model: "
        f"missing={list(load_result.missing_keys)}, "
        f"unexpected={list(load_result.unexpected_keys)}"
    )

# Run on the GPU when one is available, and switch to evaluation mode.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device).eval()

# Tokenizer used for both encoding prompts and decoding generated ids.
tokenizer = tiktoken.get_encoding('cl100k_base')
|
|
def generate_text(prompt, max_new_tokens=128, temperature=0.7, top_p=0.9, top_k=40):
    """Sample a continuation of ``prompt``, print it, and return it.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        prompt: Text to continue. Must encode to at least one token.
        max_new_tokens: Forwarded to ``model.generate`` as ``max_new_tokens``.
        temperature: Forwarded to ``model.generate`` as ``temperature``.
        top_p: Forwarded to ``model.generate`` as ``top_p``.
        top_k: Forwarded to ``model.generate`` as ``top_k``.

    Returns:
        The decoded newly generated text, without the prompt.

    Raises:
        ValueError: If ``prompt`` encodes to zero tokens.
    """
    prompt_ids = tokenizer.encode(prompt)
    if not prompt_ids:
        # An empty id list would become an empty float tensor, which is not a
        # valid token-id input for generation; fail early with a clear message.
        raise ValueError("prompt must encode to at least one token")

    input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)
    attention_mask = torch.ones_like(input_ids)

    # Take the end-of-text id from the tokenizer rather than hard-coding it,
    # so pad/eos stay correct if the encoding is ever changed.
    eot_id = tokenizer.eot_token

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            pad_token_id=eot_id,
            eos_token_id=eot_id,
        )

    # Keep only the ids that follow the prompt in the first (only) sequence.
    new_tokens = outputs[0][input_ids.shape[1]:].tolist()
    generated = tokenizer.decode(new_tokens)

    print(f"{prompt}{generated}")
    print(f"\nTokens generated: {len(new_tokens)}")
    return generated
|
|
if __name__ == "__main__":
    # Demo run: echo the prompt, then print one sampled continuation.
    prompt = "In a small village nestled between mountains,"
    print(f"PROMPT: {prompt}\n--")
    generate_text(prompt)
|
|