jdesiree committed on
Commit
79d5341
·
verified ·
1 Parent(s): e13b10c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -43
app.py CHANGED
@@ -261,52 +261,52 @@ class Qwen25SmallLLM(Runnable):
261
  """LLM class that properly inherits from Runnable for LangChain compatibility"""
262
 
263
  def __init__(self, model_path: str = "Qwen/Qwen2.5-3B-Instruct", use_4bit: bool = True):
264
- super().__init__()
265
- logger.info(f"Loading model: {model_path} (use_4bit={use_4bit})")
266
- start_Loading_Model_time = time.perf_counter()
267
- current_time = datetime.now()
268
 
269
- try:
270
- # Load tokenizer
271
- self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
272
-
273
- if use_4bit:
274
- quant_config = BitsAndBytesConfig(
275
- load_in_4bit=True,
276
- bnb_4bit_compute_dtype=torch.float16,
277
- bnb_4bit_use_double_quant=True,
278
- bnb_4bit_quant_type="nf4",
279
- llm_int8_threshold=0.0,
280
- llm_int8_skip_modules=["lm_head"]
281
- )
282
 
283
- # Try quantized load with updated dtype parameter
284
- self.model = AutoModelForCausalLM.from_pretrained(
285
- model_path,
286
- quantization_config=quant_config,
287
- device_map="auto",
288
- dtype=torch.bfloat16,
289
- trust_remote_code=True,
290
- low_cpu_mem_usage=True
291
- )
292
- else:
293
- self._load_fallback_model(model_path)
 
 
 
 
 
294
 
295
- # Success path - log timing
296
- end_Loading_Model_time = time.perf_counter()
297
- Loading_Model_time = end_Loading_Model_time - start_Loading_Model_time
298
- log_metric(f"Model Load time: {Loading_Model_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
299
-
300
- except Exception as e:
301
- logger.warning(f"Quantized load failed, falling back: {e}")
302
- self._load_fallback_model(model_path)
303
- end_Loading_Model_time = time.perf_counter()
304
- Loading_Model_time = end_Loading_Model_time - start_Loading_Model_time
305
- log_metric(f"Model Load time (fallback): {Loading_Model_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
306
-
307
- # Ensure pad token
308
- if self.tokenizer.pad_token is None:
309
- self.tokenizer.pad_token = self.tokenizer.eos_token
310
 
311
  def _load_fallback_model(self, model_path: str):
312
  """Fallback if quantization fails."""
 
261
  """LLM class that properly inherits from Runnable for LangChain compatibility"""
262
 
263
  def __init__(self, model_path: str = "Qwen/Qwen2.5-3B-Instruct", use_4bit: bool = True):
264
+ super().__init__()
265
+ logger.info(f"Loading model: {model_path} (use_4bit={use_4bit})")
266
+ start_Loading_Model_time = time.perf_counter()
267
+ current_time = datetime.now()
268
 
269
+ try:
270
+ # Load tokenizer
271
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
272
+
273
+ if use_4bit:
274
+ quant_config = BitsAndBytesConfig(
275
+ load_in_4bit=True,
276
+ bnb_4bit_compute_dtype=torch.float16,
277
+ bnb_4bit_use_double_quant=True,
278
+ bnb_4bit_quant_type="nf4",
279
+ llm_int8_threshold=0.0,
280
+ llm_int8_skip_modules=["lm_head"]
281
+ )
282
 
283
+ # Try quantized load with updated dtype parameter
284
+ self.model = AutoModelForCausalLM.from_pretrained(
285
+ model_path,
286
+ quantization_config=quant_config,
287
+ device_map="auto",
288
+ dtype=torch.bfloat16,
289
+ trust_remote_code=True,
290
+ low_cpu_mem_usage=True
291
+ )
292
+ else:
293
+ self._load_fallback_model(model_path)
294
+
295
+ # Success path - log timing
296
+ end_Loading_Model_time = time.perf_counter()
297
+ Loading_Model_time = end_Loading_Model_time - start_Loading_Model_time
298
+ log_metric(f"Model Load time: {Loading_Model_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
299
 
300
+ except Exception as e:
301
+ logger.warning(f"Quantized load failed, falling back: {e}")
302
+ self._load_fallback_model(model_path)
303
+ end_Loading_Model_time = time.perf_counter()
304
+ Loading_Model_time = end_Loading_Model_time - start_Loading_Model_time
305
+ log_metric(f"Model Load time (fallback): {Loading_Model_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
306
+
307
+ # Ensure pad token
308
+ if self.tokenizer.pad_token is None:
309
+ self.tokenizer.pad_token = self.tokenizer.eos_token
 
 
 
 
 
310
 
311
  def _load_fallback_model(self, model_path: str):
312
  """Fallback if quantization fails."""