# Mistral_Test / inference.py
# Author: eesfeg — commit 1e639fb ("Add application file")
# inference.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
class MistralChat:
    """Thin wrapper around a causal LM for [INST]-style chat inference.

    Loads a Hugging Face causal language model (default: TinyLlama chat) and
    exposes single-shot generation plus stdout streaming.
    """

    def __init__(self, model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
        """Load tokenizer and model, placing the model on GPU when available.

        Args:
            model_path: Hub id or local path of the model to load.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print("Loading model...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            # fp16 only on GPU; CPU inference stays in fp32.
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            # device_map="auto" lets accelerate place the weights itself.
            device_map="auto" if self.device == "cuda" else None,
            trust_remote_code=True,
        )
        # BUG FIX: removed the follow-up `self.model.to(self.device)` call.
        # With device_map="auto" the weights are already dispatched by
        # accelerate, and calling .to() on a dispatched model raises a
        # RuntimeError in recent accelerate versions.
        print("Model loaded successfully!")

    def generate(self, prompt, max_length=500, temperature=0.7):
        """Generate a full response for *prompt* and return it as a string.

        Args:
            prompt: User message (plain text, no special tokens).
            max_length: Maximum number of NEW tokens to sample.
            temperature: Sampling temperature (higher = more random).

        Returns:
            The model's reply with the prompt portion stripped.
        """
        # NOTE(review): the [INST] wrapper is the Mistral-instruct format;
        # TinyLlama-Chat ships its own chat template — confirm this matches
        # the deployed model, or switch to tokenizer.apply_chat_template.
        formatted_prompt = f"[INST] {prompt} [/INST]"
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
        if self.device == "cuda":
            inputs = inputs.to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_length,
                temperature=temperature,
                do_sample=True,
                top_p=0.95,
                # Silence the "pad_token_id not set" warning during generation.
                pad_token_id=self.tokenizer.eos_token_id,
            )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only the text after the LAST [/INST]: using [-1] instead of [1]
        # stays correct even when the user's prompt itself contains "[/INST]".
        if "[/INST]" in response:
            response = response.split("[/INST]")[-1].strip()
        return response

    def chat_stream(self, prompt, max_length=500):
        """Stream the response token by token to stdout.

        Args:
            prompt: User message (plain text, no special tokens).
            max_length: Maximum number of NEW tokens to sample (default
                matches generate() for consistency).
        """
        formatted_prompt = f"[INST] {prompt} [/INST]"
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
        # skip_prompt avoids echoing the [INST] wrapper back to the console.
        streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
        if self.device == "cuda":
            inputs = inputs.to(self.device)
        _ = self.model.generate(**inputs, streamer=streamer, max_new_tokens=max_length)
# Demo entry point: one-shot generation followed by a streamed reply.
if __name__ == "__main__":
    bot = MistralChat()

    # Single, fully-materialized response.
    answer = bot.generate("Explain quantum computing in simple terms")
    print("Response:", answer)

    # Token-by-token streaming straight to stdout.
    print("\nStreaming response:")
    bot.chat_stream("Write a short poem about AI")