# inference.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
class MistralChat:
    """Chat wrapper around an instruction-tuned causal LM.

    NOTE(review): the class name is historical — the default checkpoint is
    TinyLlama-1.1B-Chat, and any causal LM that understands the
    ``[INST] ... [/INST]`` prompt format should work.
    """

    def __init__(self, model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
        """Load tokenizer and model, placing the model on GPU when available.

        Args:
            model_path: Hugging Face model id or local path.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print("Loading model...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            # fp16 on GPU halves memory; fp32 on CPU (fp16 CPU ops are poorly supported)
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            device_map="auto" if self.device == "cuda" else None,
            trust_remote_code=True,
        )
        # BUG FIX: the original additionally called self.model.to(self.device)
        # on CUDA, but device_map="auto" already dispatches the model onto the
        # GPU via accelerate; calling .to() on a dispatched model is redundant
        # and raises a RuntimeError in current transformers/accelerate.
        print("Model loaded successfully!")

    def generate(self, prompt, max_length=500, temperature=0.7):
        """Generate a single response for *prompt*.

        Args:
            prompt: User message (plain text; will be wrapped in [INST] tags).
            max_length: Maximum number of NEW tokens to generate
                (passed as ``max_new_tokens``; name kept for compatibility).
            temperature: Sampling temperature; must be > 0 since sampling is on.

        Returns:
            The assistant's reply with the prompt portion stripped.
        """
        # Format for instruct models
        formatted_prompt = f"[INST] {prompt} [/INST]"
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
        if self.device == "cuda":
            inputs = inputs.to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_length,
                temperature=temperature,
                do_sample=True,
                top_p=0.95,
                # models without a pad token would warn otherwise
                pad_token_id=self.tokenizer.eos_token_id,
            )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Extract only the assistant's response (decode includes the prompt)
        if "[/INST]" in response:
            response = response.split("[/INST]")[1].strip()
        return response

    def chat_stream(self, prompt):
        """Stream the response token by token to stdout.

        Returns nothing; TextStreamer prints tokens as they are generated.
        """
        formatted_prompt = f"[INST] {prompt} [/INST]"
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
        # skip_prompt avoids echoing the [INST] prefix back to the user
        streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
        if self.device == "cuda":
            inputs = inputs.to(self.device)
        _ = self.model.generate(**inputs, streamer=streamer, max_new_tokens=500)
# Usage demo: one-shot generation followed by streamed output.
if __name__ == "__main__":
    bot = MistralChat()

    # Single (non-streaming) response
    answer = bot.generate("Explain quantum computing in simple terms")
    print("Response:", answer)

    # Token-by-token streaming response
    print("\nStreaming response:")
    bot.chat_stream("Write a short poem about AI")