# TatTwamAI/models/mistral_model.py
# Author: Jayashree Sridhar
"""
Mistral Model Wrapper for easy integration
"""
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Token for gated model access; read from the environment, never hard-coded.
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
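
# NOTE (assumption, not stated in the original): 8-bit loading in
# _initialize_model requires the `bitsandbytes` and `accelerate` packages
# and a CUDA device; without them, drop the quantization_config and load
# the model in float16 only.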


class MistralModel:
    """Wrapper for Mistral model with caching and optimization"""

    # Singleton state: the weights and tokenizer are loaded once per process
    # and shared by every MistralModel() instance.
    _instance = None
    _model = None
    _tokenizer = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Only the first instantiation pays the loading cost.
        if MistralModel._model is None:
            self._initialize_model()
    def _initialize_model(self):
        """Initialize Mistral model with optimizations"""
        print("Loading Mistral model...")
        model_id = "mistralai/Mistral-7B-Instruct-v0.2"

        # Load tokenizer
        MistralModel._tokenizer = AutoTokenizer.from_pretrained(
            model_id, token=HUGGINGFACE_TOKEN
        )

        # Load model with 8-bit quantization for memory efficiency.
        # Recent transformers versions expect the quantization flag inside a
        # BitsAndBytesConfig rather than as a bare load_in_8bit argument.
        MistralModel._model = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=HUGGINGFACE_TOKEN,
            torch_dtype=torch.float16,  # dtype for the non-quantized modules
            device_map="auto",
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        )
        print("Mistral model loaded successfully!")
    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.95,
    ) -> str:
        """Generate a response from Mistral for a single instruction."""
        # Format the prompt for the Mistral instruction format. The tokenizer
        # adds the leading <s> (BOS) token itself, so it is omitted here to
        # avoid encoding it twice.
        formatted_prompt = f"[INST] {prompt} [/INST]"

        # Tokenize, truncating overly long prompts
        inputs = MistralModel._tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048,
        )

        # Move inputs to the same device as the model
        device = next(MistralModel._model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Generate with sampling
        with torch.no_grad():
            outputs = MistralModel._model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=MistralModel._tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens, skipping the echoed prompt
        response = MistralModel._tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )
        return response.strip()
    def generate_embedding(self, text: str) -> torch.Tensor:
        """Generate an embedding for text by mean-pooling the last hidden state."""
        inputs = MistralModel._tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )
        device = next(MistralModel._model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = MistralModel._model(**inputs, output_hidden_states=True)

        # Mean-pool the final hidden layer over the sequence dimension,
        # yielding one vector of shape (1, hidden_size) per input text.
        embeddings = outputs.hidden_states[-1].mean(dim=1)
        return embeddings
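

# Minimal usage sketch (an illustrative addition, not part of the original
# module): assumes HUGGINGFACE_TOKEN is set in the environment and enough
# GPU memory is available for the 8-bit 7B model.
if __name__ == "__main__":
    model = MistralModel()
    # The singleton returns the same instance on repeat construction.
    assert model is MistralModel()
    print(model.generate("Summarize what this wrapper class does in one sentence."))
    emb = model.generate_embedding("TatTwamAI")
    print(emb.shape)  # e.g. torch.Size([1, 4096]) for Mistral-7B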