Jayashree Sridhar committed on
Commit
f8a9066
·
1 Parent(s): 94ae986

replaced mistral with TinyGPT2Model

Browse files
Files changed (1) hide show
  1. models/mistral_model.py +4 -5
models/mistral_model.py CHANGED
@@ -24,9 +24,9 @@ class MistralModel:
24
 
25
  def _initialize_model(self):
26
  """Initialize Mistral model with optimizations"""
27
- print("Loading Mistral model...")
28
 
29
- model_id = "mistralai/Mistral-7B-Instruct-v0.2"
30
 
31
  # Load tokenizer
32
  MistralModel._tokenizer = AutoTokenizer.from_pretrained(model_id, token=HUGGINGFACE_TOKEN,use_fast=False)
@@ -35,12 +35,11 @@ class MistralModel:
35
  MistralModel._model = AutoModelForCausalLM.from_pretrained(
36
  model_id,
37
  token=HUGGINGFACE_TOKEN,
38
- torch_dtype=torch.float16,
39
- device_map="auto",
40
  load_in_8bit=True # Use 8-bit quantization for memory efficiency
41
  )
42
 
43
- print("Mistral model loaded successfully!")
44
 
45
  def generate(
46
  self,
 
24
 
25
  def _initialize_model(self):
26
  """Initialize Mistral model with optimizations"""
27
+ print("Loading TinyGPT2Model model...")
28
 
29
+ model_id = "sshleifer/tiny-gpt2"
30
 
31
  # Load tokenizer
32
  MistralModel._tokenizer = AutoTokenizer.from_pretrained(model_id, token=HUGGINGFACE_TOKEN,use_fast=False)
 
35
  MistralModel._model = AutoModelForCausalLM.from_pretrained(
36
  model_id,
37
  token=HUGGINGFACE_TOKEN,
38
+ torch_dtype=torch.float32,
 
39
  load_in_8bit=True # Use 8-bit quantization for memory efficiency
40
  )
41
 
42
+ print("TinyGPT2Model loaded successfully!")
43
 
44
  def generate(
45
  self,