# inference.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
class MistralChat:
    """Chat wrapper around a causal LM using the [INST] ... [/INST] template.

    Defaults to TinyLlama-1.1B-Chat, but works with any instruct model that
    follows the Mistral/Llama instruction-prompt convention.
    """

    def __init__(self, model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
        """Load tokenizer and model, placing the model on GPU when available.

        Args:
            model_path: Hugging Face hub id or local path of the model.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print("Loading model...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            # fp16 on GPU halves memory; fp32 on CPU (fp16 CPU ops are slow/unsupported)
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            # device_map="auto" lets accelerate dispatch weights onto the GPU;
            # None on CPU loads the model normally.
            device_map="auto" if self.device == "cuda" else None,
            trust_remote_code=True,
        )
        # BUGFIX: no .to(self.device) here. With device_map="auto" the weights
        # are already dispatched by accelerate, and calling .to() on a
        # dispatched model raises a RuntimeError in recent accelerate
        # versions (and was redundant at best).
        print("Model loaded successfully!")

    @staticmethod
    def _format_prompt(prompt):
        """Wrap a raw user message in the [INST] instruction template."""
        return f"[INST] {prompt} [/INST]"

    @staticmethod
    def _extract_reply(decoded):
        """Return only the assistant's text from a decoded transcript.

        BUGFIX: takes the text after the *last* "[/INST]" marker, so a user
        prompt that itself contains the marker no longer truncates the reply.
        Returns the text unchanged when no marker is present.
        """
        marker = "[/INST]"
        if marker in decoded:
            return decoded.rsplit(marker, 1)[1].strip()
        return decoded

    def generate(self, prompt, max_length=500, temperature=0.7):
        """Generate a complete reply for `prompt` and return it as a string.

        Args:
            prompt: user message (plain text; the [INST] template is added).
            max_length: maximum number of NEW tokens to generate (name kept
                for backward compatibility; it is passed as max_new_tokens).
            temperature: sampling temperature for nucleus sampling.

        Returns:
            The assistant's reply with the prompt and special tokens removed.
        """
        inputs = self.tokenizer(self._format_prompt(prompt), return_tensors="pt")
        if self.device == "cuda":
            inputs = inputs.to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_length,
                temperature=temperature,
                do_sample=True,
                top_p=0.95,
                # chat models often lack a pad token; fall back to EOS
                pad_token_id=self.tokenizer.eos_token_id,
            )
        decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return self._extract_reply(decoded)

    def chat_stream(self, prompt):
        """Stream the response token by token to stdout (returns nothing)."""
        inputs = self.tokenizer(self._format_prompt(prompt), return_tensors="pt")
        # skip_prompt hides the echoed [INST] prompt from the streamed output
        streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
        if self.device == "cuda":
            inputs = inputs.to(self.device)
        _ = self.model.generate(**inputs, streamer=streamer, max_new_tokens=500)
# Usage example: one-shot generation followed by streaming output.
if __name__ == "__main__":
    chat = MistralChat()

    # One-shot: get the full reply back as a string, then print it.
    print("Response:", chat.generate("Explain quantum computing in simple terms"))

    # Streaming: tokens are written to stdout as they are produced.
    print("\nStreaming response:")
    chat.chat_stream("Write a short poem about AI")