ayush2917 commited on
Commit
6b02199
·
verified ·
1 Parent(s): c7793c1

Update src/generation.py

Browse files
Files changed (1) hide show
  1. src/generation.py +6 -0
src/generation.py CHANGED
@@ -1,6 +1,8 @@
1
  import logging
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  from typing import List, Dict
 
4
 
5
  logger = logging.getLogger(__name__)
6
 
@@ -14,6 +16,7 @@ class ResponseGenerator:
14
  cache_folder (str, optional): Directory to cache model files (default: None).
15
  """
16
  logger.info(f"Initializing ResponseGenerator with model: {model_name}, cache_folder: {cache_folder}")
 
17
  try:
18
  # Log cache contents for debugging
19
  if cache_folder and os.path.exists(cache_folder):
@@ -24,11 +27,14 @@ class ResponseGenerator:
24
  cache_dir=cache_folder,
25
  local_files_only=True
26
  )
 
 
27
  self.model = AutoModelForCausalLM.from_pretrained(
28
  model_name,
29
  cache_dir=cache_folder,
30
  local_files_only=True
31
  )
 
32
  except Exception as e:
33
  logger.error(f"Failed to load transformer model: {str(e)}")
34
  raise
 
1
  import logging
2
+ import os
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from typing import List, Dict
5
+ import time
6
 
7
  logger = logging.getLogger(__name__)
8
 
 
16
  cache_folder (str, optional): Directory to cache model files (default: None).
17
  """
18
  logger.info(f"Initializing ResponseGenerator with model: {model_name}, cache_folder: {cache_folder}")
19
+ start_time = time.time()
20
  try:
21
  # Log cache contents for debugging
22
  if cache_folder and os.path.exists(cache_folder):
 
27
  cache_dir=cache_folder,
28
  local_files_only=True
29
  )
30
+ logger.info(f"Tokenizer loaded in {time.time() - start_time:.2f} seconds")
31
+ start_time = time.time()
32
  self.model = AutoModelForCausalLM.from_pretrained(
33
  model_name,
34
  cache_dir=cache_folder,
35
  local_files_only=True
36
  )
37
+ logger.info(f"Model loaded in {time.time() - start_time:.2f} seconds")
38
  except Exception as e:
39
  logger.error(f"Failed to load transformer model: {str(e)}")
40
  raise