satishpednekar committed on
Commit
6cc0654
·
verified ·
1 Parent(s): 2bc71d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -7,12 +7,13 @@ MODEL_NAME = "satishpednekar/sbxcertqueryhelper"
7
 
8
  def load_model():
9
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
 
10
  model = AutoModelForCausalLM.from_pretrained(
11
  MODEL_NAME,
12
- torch_dtype=torch.float16,
13
  device_map="auto",
14
  trust_remote_code=True,
15
- load_in_8bit=True # Enable 8-bit quantization for memory efficiency
16
  )
17
  return model, tokenizer
18
 
@@ -27,7 +28,9 @@ def generate_response(prompt, max_length=512, temperature=0.7, top_p=0.95):
27
  """
28
  try:
29
  # Prepare the input
30
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
 
31
 
32
  # Generate
33
  outputs = model.generate(
 
7
 
8
  def load_model():
9
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
10
+ # Modified model loading without 8-bit quantization
11
  model = AutoModelForCausalLM.from_pretrained(
12
  MODEL_NAME,
13
+ torch_dtype=torch.float32, # Use float32 instead of float16 for better compatibility
14
  device_map="auto",
15
  trust_remote_code=True,
16
+ # Removed load_in_8bit parameter
17
  )
18
  return model, tokenizer
19
 
 
28
  """
29
  try:
30
  # Prepare the input
31
+ inputs = tokenizer(prompt, return_tensors="pt")
32
+ if torch.cuda.is_available():
33
+ inputs = inputs.to(model.device)
34
 
35
  # Generate
36
  outputs = model.generate(