autumnssuns committed on
Commit
e4b3020
·
1 Parent(s): 21bfda5

✨ Implement lazy loading for models and correct tokens counting

Browse files
models/gemma4_e2b.py CHANGED
@@ -2,21 +2,31 @@ from typing import Any
2
 
3
  import torch
4
  from transformers import AutoProcessor, AutoModelForCausalLM, TextStreamer
5
- from . import Model
 
6
 
7
  MODEL_ID = Model.GEMMA_4_E2B.model_id
 
8
 
9
- # Load model
10
- processor = AutoProcessor.from_pretrained(MODEL_ID)
11
 
12
- model = AutoModelForCausalLM.from_pretrained(
13
- MODEL_ID, torch_dtype="auto", device_map="auto"
14
- )
15
 
16
- print(f"{MODEL_ID} loaded successfully.")
17
- print(f"Model device: {model.device}")
 
 
 
18
 
19
 
 
 
 
 
 
 
 
 
20
  def generate(
21
  messages: list[dict[str, str]],
22
  max_tokens: int = 512,
@@ -24,7 +34,9 @@ def generate(
24
  top_p: float = 0.9,
25
  stop: list[str] | None = None,
26
  ) -> dict[str, Any]:
27
- print(f"Generating with {MODEL_ID}...")
 
 
28
 
29
  # Process input
30
  text = processor.apply_chat_template(
@@ -52,9 +64,13 @@ def generate(
52
 
53
  response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
54
  content = processor.parse_response(response)
 
 
55
 
56
- prompt_tokens = sum(len(msg["content"].split()) for msg in messages)
57
- completion_tokens = len(content.split())
 
 
58
 
59
  print(
60
  f"Generation complete. Prompt tokens: {prompt_tokens}, Completion tokens: {completion_tokens}"
 
2
 
3
  import torch
4
  from transformers import AutoProcessor, AutoModelForCausalLM, TextStreamer
5
+ from . import config, Model
6
+ from .lazy_model import LazyModel
7
 
8
  MODEL_ID = Model.GEMMA_4_E2B.model_id
9
+ lazy = LazyModel(MODEL_ID)
10
 
11
+ processor = None
12
+ model = None
13
 
 
 
 
14
 
15
@lazy.unload()
def clean_up():
    """Release this module's references to the processor and model.

    Registered as the unload callback for this model's ``LazyModel``.
    """
    global processor, model
    # Rebind to None instead of `del`: deleting the global names would make
    # the `is not None` guards in generate() (and a second clean_up() call)
    # raise NameError rather than report the "not loaded" state.
    processor = None
    model = None


@lazy.load()
def load():
    """Instantiate the processor and model and store them in module globals."""
    global processor, model
    processor = AutoProcessor.from_pretrained(MODEL_ID, **config.tokenizer_config)
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **config.model_config)
27
+
28
+
29
+ @lazy.entry()
30
  def generate(
31
  messages: list[dict[str, str]],
32
  max_tokens: int = 512,
 
34
  top_p: float = 0.9,
35
  stop: list[str] | None = None,
36
  ) -> dict[str, Any]:
37
+ global processor, model
38
+ assert processor is not None, "Processor is not initialized."
39
+ assert model is not None, "Model is not loaded."
40
 
41
  # Process input
42
  text = processor.apply_chat_template(
 
64
 
65
  response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
66
  content = processor.parse_response(response)
67
+ if isinstance(content, dict) and "content" in content:
68
+ content = content["content"]
69
 
70
+ prompt_tokens = len(processor.tokenizer.apply_chat_template(messages))
71
+ completion_tokens = len(
72
+ processor.tokenizer.encode(content, add_special_tokens=False)
73
+ )
74
 
75
  print(
76
  f"Generation complete. Prompt tokens: {prompt_tokens}, Completion tokens: {completion_tokens}"
models/lazy_model.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable
2
+ import gc
3
+ import torch
4
+ import os
5
+
6
+ LAZY_LOAD_ENABLED = os.getenv("LAZY_LOAD", "false").lower() == "true"
7
+
8
+
9
+ class LazyModel:
10
+ unload_func = None
11
+ init_func: Callable | None = None
12
+ is_loaded = False
13
+
14
+ def __init__(self, model_id: str):
15
+ self.model_id = model_id
16
+
17
+ def load(self):
18
+ def decorator(init_func):
19
+ if not LAZY_LOAD_ENABLED:
20
+ # Even if eager loading, the model should only be initialized once.
21
+ if not self.is_loaded:
22
+ init_func()
23
+ self.is_loaded = True
24
+ self.init_func = init_func
25
+ return init_func
26
+
27
+ def wrapper():
28
+ global current_model
29
+ if current_model is not None and current_model != self.model_id:
30
+ print(
31
+ f"Unloading currently loaded model '{current_model}' before loading '{self.model_id}'..."
32
+ )
33
+ _unload()
34
+
35
+ if current_model == self.model_id and self.is_loaded:
36
+ print(
37
+ f"Model '{self.model_id}' is already loaded. Skipping initialization."
38
+ )
39
+ return
40
+
41
+ print(f"Loading model '{self.model_id}'...")
42
+ init_func()
43
+ self.is_loaded = True
44
+ current_model = self
45
+ print(f"Model '{self.model_id}' loaded successfully.")
46
+
47
+ # Ensure the init_func also loads lazily
48
+ self.init_func = wrapper
49
+ return wrapper
50
+
51
+ return decorator
52
+
53
+ def unload(self):
54
+ # Create a decorator to set the unload callback function for this model. This allows the lazy loading mechanism to call the specified function when unloading the model, ensuring proper cleanup of resources.
55
+ def decorator(func):
56
+ self.unload_func = func
57
+ return func
58
+
59
+ return decorator
60
+
61
+ def entry(self):
62
+ def decorator(func):
63
+ def wrapper(*args, **kwargs):
64
+ if not self.init_func:
65
+ raise RuntimeError(
66
+ f"Model '{self.model_id}' does not have an initialization function defined."
67
+ )
68
+
69
+ # Ensure the model is loaded before executing the main function
70
+ if self.init_func and not self.is_loaded:
71
+ print(f"Model '{self.model_id}' is not loaded. Loading now...")
72
+ self.init_func()
73
+
74
+ print(f"Executing main function for model '{self.model_id}'...")
75
+ return func(*args, **kwargs)
76
+
77
+ return wrapper
78
+
79
+ return decorator
80
+
81
+
82
+ def _unload():
83
+ global current_model
84
+ if current_model and current_model.unload_func:
85
+ current_model.unload_func()
86
+ current_model = None
87
+ # Ensure garbage collection and CUDA cache clearing
88
+ gc.collect()
89
+ if torch.cuda.is_available():
90
+ torch.cuda.empty_cache()
91
+
92
+
93
+ # Global variaable to keep track of the currently loaded LazyModel instance. This allows the lazy loading mechanism to determine if a model is already loaded and manage unloading of other models when necessary.
94
+ current_model: LazyModel | None = None
models/{llama.py → llama3_2_3b_instruct.py} RENAMED
@@ -2,22 +2,39 @@ from typing import Any
2
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextStreamer
4
  from . import config, Model
 
5
 
6
  MODEL_ID = Model.LLAMA_3_2_3B_INSTRUCT.model_id
 
7
 
8
- model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **config.model_config)
9
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, **config.tokenizer_config)
 
10
 
11
- pipe = pipeline(
12
- "text-generation",
13
- model=model,
14
- tokenizer=tokenizer,
15
- **config.pipeline_config,
16
- )
17
- print(f"{MODEL_ID} loaded successfully.")
18
- print(f"Model device: {pipe.model.device}")
19
 
 
 
 
 
 
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def generate(
22
  messages: list[dict[str, str]],
23
  max_tokens: int = 512,
@@ -25,6 +42,8 @@ def generate(
25
  top_p: float = 0.9,
26
  stop: list[str] | None = None,
27
  ) -> dict[str, Any]:
 
 
28
  assert pipe.tokenizer is not None, "Tokenizer is not loaded."
29
 
30
  print(f"Generating with {MODEL_ID}...")
@@ -40,8 +59,8 @@ def generate(
40
  )
41
  content = outputs[0]["generated_text"][-1]["content"]
42
 
43
- prompt_tokens = sum(len(msg["content"].split()) for msg in messages)
44
- completion_tokens = len(content.split())
45
 
46
  print(
47
  f"Generation complete. Prompt tokens: {prompt_tokens}, Completion tokens: {completion_tokens}"
 
2
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextStreamer
4
  from . import config, Model
5
+ from .lazy_model import LazyModel
6
 
7
  MODEL_ID = Model.LLAMA_3_2_3B_INSTRUCT.model_id
8
+ lazy = LazyModel(MODEL_ID)
9
 
10
+ model = None
11
+ tokenizer = None
12
+ pipe = None
13
 
 
 
 
 
 
 
 
 
14
 
15
@lazy.unload()
def clean_up():
    """Release this module's references to the model, tokenizer and pipeline.

    Registered as the unload callback for this model's ``LazyModel``.
    """
    global model, tokenizer, pipe
    # Rebind to None instead of `del`: deleting the global names would make
    # the `pipe is not None` guard in generate() (and a second clean_up()
    # call) raise NameError rather than report the "not initialized" state.
    model = None
    tokenizer = None
    pipe = None


@lazy.load()
def init():
    """Build the model, tokenizer and text-generation pipeline into module globals."""
    global model, tokenizer, pipe
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **config.model_config)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, **config.tokenizer_config)

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        **config.pipeline_config,
    )
35
+
36
+
37
+ @lazy.entry()
38
  def generate(
39
  messages: list[dict[str, str]],
40
  max_tokens: int = 512,
 
42
  top_p: float = 0.9,
43
  stop: list[str] | None = None,
44
  ) -> dict[str, Any]:
45
+ global model, tokenizer, pipe
46
+ assert pipe is not None, "Pipeline is not initialized."
47
  assert pipe.tokenizer is not None, "Tokenizer is not loaded."
48
 
49
  print(f"Generating with {MODEL_ID}...")
 
59
  )
60
  content = outputs[0]["generated_text"][-1]["content"]
61
 
62
+ prompt_tokens = len(pipe.tokenizer.apply_chat_template(messages))
63
+ completion_tokens = len(pipe.tokenizer.encode(content, add_special_tokens=False))
64
 
65
  print(
66
  f"Generation complete. Prompt tokens: {prompt_tokens}, Completion tokens: {completion_tokens}"
service.py CHANGED
@@ -11,28 +11,21 @@ def generate(
11
  top_p: float = 0.9,
12
  stop: list[str] | None = None,
13
  ) -> dict[str, Any]:
 
 
 
 
14
  if model == Model.LLAMA_3_2_3B_INSTRUCT.model_id:
15
- from models import llama
16
-
17
- return llama.generate(
18
- messages=messages,
19
- max_tokens=max_tokens,
20
- temperature=temperature,
21
- top_p=top_p,
22
- stop=stop,
23
- )
24
  if model == Model.GEMMA_4_E2B.model_id:
25
- from models import gemma4_e2b
26
-
27
- return gemma4_e2b.generate(
28
- messages=messages,
29
- max_tokens=max_tokens,
30
- temperature=temperature,
31
- top_p=top_p,
32
- stop=stop,
33
- )
34
- msg = f"Unsupported model: {model}"
35
- raise ValueError(msg)
36
 
37
 
38
  def list_models() -> dict[str, list[dict[str, Any]]]:
 
11
  top_p: float = 0.9,
12
  stop: list[str] | None = None,
13
  ) -> dict[str, Any]:
14
+ # Ensure model exists
15
+ if model not in [m["id"] for m in get_available_models()]:
16
+ msg = f"Model '{model}' is not available. Supported models: {[m['id'] for m in get_available_models()]}"
17
+ raise ValueError(msg)
18
  if model == Model.LLAMA_3_2_3B_INSTRUCT.model_id:
19
+ from models.llama3_2_3b_instruct import generate
 
 
 
 
 
 
 
 
20
  if model == Model.GEMMA_4_E2B.model_id:
21
+ from models.gemma4_e2b import generate
22
+ return generate( # type: ignore
23
+ messages=messages,
24
+ max_tokens=max_tokens,
25
+ temperature=temperature,
26
+ top_p=top_p,
27
+ stop=stop,
28
+ )
 
 
 
29
 
30
 
31
  def list_models() -> dict[str, list[dict[str, Any]]]: