from mistral_inference.model import Transformer
from mistral_inference.generate import generate

from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
def main(
    model_dir: str = "model",
    tokenizer_file: str = "model/tokenizer.model.v3",
    lora_file: str = "lora/lora.safetensors",
    prompt: str = "Explain Machine Learning to me in a nutshell.",
    max_tokens: int = 64,
    temperature: float = 0.0,
) -> str:
    """Run one chat completion on a Mistral model with a LoRA adapter.

    All arguments default to the values that were previously hard-coded,
    so calling ``main()`` with no arguments behaves exactly as before.

    Args:
        model_dir: Folder containing the base model weights.
        tokenizer_file: Path to the ``tokenizer.model.v3`` file.
        lora_file: Path to the LoRA adapter ``.safetensors`` file.
        prompt: User message sent to the model.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature; 0.0 selects greedy
            (deterministic) decoding.

    Returns:
        The decoded completion text (also printed to stdout).
    """
    tokenizer = MistralTokenizer.from_file(tokenizer_file)
    model = Transformer.from_folder(model_dir)
    # Apply the LoRA adapter weights on top of the base model.
    model.load_lora(lora_file)

    completion_request = ChatCompletionRequest(
        messages=[UserMessage(content=prompt)]
    )
    tokens = tokenizer.encode_chat_completion(completion_request).tokens

    out_tokens, _ = generate(
        [tokens],
        model,
        max_tokens=max_tokens,
        temperature=temperature,
        eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
    )
    result = tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])
    print(result)
    return result
# Script entry point: run the example only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()