Solarum Asteridion commited on
Commit
ac360a6
·
verified ·
1 Parent(s): 6f9b73a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -2
app.py CHANGED
@@ -44,10 +44,19 @@ class LocalLLMHandler:
44
  torch.cuda.empty_cache()
45
  gc.collect()
46
 
 
 
 
 
 
 
 
 
47
  model_kwargs = {
48
- "device_map": "cpu",
49
  "torch_dtype": torch.bfloat16,
50
  "low_cpu_mem_usage": True,
 
51
  }
52
 
53
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -63,7 +72,7 @@ class LocalLLMHandler:
63
 
64
  def generate_response(self, prompt, max_length=500):
65
  try:
66
- inputs = self.tokenizer(prompt, return_tensors="pt")
67
  outputs = self.model.generate(
68
  inputs["input_ids"],
69
  max_length=max_length,
@@ -78,6 +87,7 @@ class LocalLLMHandler:
78
  logger.error(f"Error generating response: {e}")
79
  return f"Error generating response: {str(e)}"
80
 
 
81
  def get_current_local_time(timezone_str='UTC'):
82
  try:
83
  timezone = pytz.timezone(timezone_str)
 
44
  torch.cuda.empty_cache()
45
  gc.collect()
46
 
47
+ # Quantization for faster inference
48
+ quantization_config = BitsAndBytesConfig(
49
+ load_in_4bit=True,
50
+ bnb_4bit_use_double_quant=True,
51
+ bnb_4bit_quant_type="nf4",
52
+ bnb_4bit_compute_dtype=torch.bfloat16
53
+ )
54
+
55
  model_kwargs = {
56
+ "device_map": "auto", # Use GPU if available, otherwise CPU
57
  "torch_dtype": torch.bfloat16,
58
  "low_cpu_mem_usage": True,
59
+ "quantization_config": quantization_config
60
  }
61
 
62
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
 
72
 
73
  def generate_response(self, prompt, max_length=500):
74
  try:
75
+ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) # Move inputs to the same device as the model
76
  outputs = self.model.generate(
77
  inputs["input_ids"],
78
  max_length=max_length,
 
87
  logger.error(f"Error generating response: {e}")
88
  return f"Error generating response: {str(e)}"
89
 
90
+
91
  def get_current_local_time(timezone_str='UTC'):
92
  try:
93
  timezone = pytz.timezone(timezone_str)