Spaces:
Sleeping
Sleeping
Commit ·
e4b3020
1
Parent(s): 21bfda5
✨ Implement lazy loading for models and correct tokens counting
Browse files- models/gemma4_e2b.py +27 -11
- models/lazy_model.py +94 -0
- models/{llama.py → llama3_2_3b_instruct.py} +31 -12
- service.py +13 -20
models/gemma4_e2b.py
CHANGED
|
@@ -2,21 +2,31 @@ from typing import Any
|
|
| 2 |
|
| 3 |
import torch
|
| 4 |
from transformers import AutoProcessor, AutoModelForCausalLM, TextStreamer
|
| 5 |
-
from . import Model
|
|
|
|
| 6 |
|
| 7 |
MODEL_ID = Model.GEMMA_4_E2B.model_id
|
|
|
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
|
| 12 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 13 |
-
MODEL_ID, torch_dtype="auto", device_map="auto"
|
| 14 |
-
)
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
def generate(
|
| 21 |
messages: list[dict[str, str]],
|
| 22 |
max_tokens: int = 512,
|
|
@@ -24,7 +34,9 @@ def generate(
|
|
| 24 |
top_p: float = 0.9,
|
| 25 |
stop: list[str] | None = None,
|
| 26 |
) -> dict[str, Any]:
|
| 27 |
-
|
|
|
|
|
|
|
| 28 |
|
| 29 |
# Process input
|
| 30 |
text = processor.apply_chat_template(
|
|
@@ -52,9 +64,13 @@ def generate(
|
|
| 52 |
|
| 53 |
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
|
| 54 |
content = processor.parse_response(response)
|
|
|
|
|
|
|
| 55 |
|
| 56 |
-
prompt_tokens =
|
| 57 |
-
completion_tokens = len(
|
|
|
|
|
|
|
| 58 |
|
| 59 |
print(
|
| 60 |
f"Generation complete. Prompt tokens: {prompt_tokens}, Completion tokens: {completion_tokens}"
|
|
|
|
| 2 |
|
| 3 |
import torch
|
| 4 |
from transformers import AutoProcessor, AutoModelForCausalLM, TextStreamer
|
| 5 |
+
from . import config, Model
|
| 6 |
+
from .lazy_model import LazyModel
|
| 7 |
|
| 8 |
MODEL_ID = Model.GEMMA_4_E2B.model_id
|
| 9 |
+
lazy = LazyModel(MODEL_ID)
|
| 10 |
|
| 11 |
+
processor = None
|
| 12 |
+
model = None
|
| 13 |
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
+
@lazy.unload()
|
| 16 |
+
def clean_up():
|
| 17 |
+
global processor, model
|
| 18 |
+
del processor
|
| 19 |
+
del model
|
| 20 |
|
| 21 |
|
| 22 |
+
@lazy.load()
|
| 23 |
+
def load():
|
| 24 |
+
global processor, model
|
| 25 |
+
processor = AutoProcessor.from_pretrained(MODEL_ID, **config.tokenizer_config)
|
| 26 |
+
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **config.model_config)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@lazy.entry()
|
| 30 |
def generate(
|
| 31 |
messages: list[dict[str, str]],
|
| 32 |
max_tokens: int = 512,
|
|
|
|
| 34 |
top_p: float = 0.9,
|
| 35 |
stop: list[str] | None = None,
|
| 36 |
) -> dict[str, Any]:
|
| 37 |
+
global processor, model
|
| 38 |
+
assert processor is not None, "Processor is not initialized."
|
| 39 |
+
assert model is not None, "Model is not loaded."
|
| 40 |
|
| 41 |
# Process input
|
| 42 |
text = processor.apply_chat_template(
|
|
|
|
| 64 |
|
| 65 |
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
|
| 66 |
content = processor.parse_response(response)
|
| 67 |
+
if isinstance(content, dict) and "content" in content:
|
| 68 |
+
content = content["content"]
|
| 69 |
|
| 70 |
+
prompt_tokens = len(processor.tokenizer.apply_chat_template(messages))
|
| 71 |
+
completion_tokens = len(
|
| 72 |
+
processor.tokenizer.encode(content, add_special_tokens=False)
|
| 73 |
+
)
|
| 74 |
|
| 75 |
print(
|
| 76 |
f"Generation complete. Prompt tokens: {prompt_tokens}, Completion tokens: {completion_tokens}"
|
models/lazy_model.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable
|
| 2 |
+
import gc
|
| 3 |
+
import torch
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
LAZY_LOAD_ENABLED = os.getenv("LAZY_LOAD", "false").lower() == "true"
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class LazyModel:
|
| 10 |
+
unload_func = None
|
| 11 |
+
init_func: Callable | None = None
|
| 12 |
+
is_loaded = False
|
| 13 |
+
|
| 14 |
+
def __init__(self, model_id: str):
|
| 15 |
+
self.model_id = model_id
|
| 16 |
+
|
| 17 |
+
def load(self):
|
| 18 |
+
def decorator(init_func):
|
| 19 |
+
if not LAZY_LOAD_ENABLED:
|
| 20 |
+
# Even if eager loading, the model should only be initialized once.
|
| 21 |
+
if not self.is_loaded:
|
| 22 |
+
init_func()
|
| 23 |
+
self.is_loaded = True
|
| 24 |
+
self.init_func = init_func
|
| 25 |
+
return init_func
|
| 26 |
+
|
| 27 |
+
def wrapper():
|
| 28 |
+
global current_model
|
| 29 |
+
if current_model is not None and current_model != self.model_id:
|
| 30 |
+
print(
|
| 31 |
+
f"Unloading currently loaded model '{current_model}' before loading '{self.model_id}'..."
|
| 32 |
+
)
|
| 33 |
+
_unload()
|
| 34 |
+
|
| 35 |
+
if current_model == self.model_id and self.is_loaded:
|
| 36 |
+
print(
|
| 37 |
+
f"Model '{self.model_id}' is already loaded. Skipping initialization."
|
| 38 |
+
)
|
| 39 |
+
return
|
| 40 |
+
|
| 41 |
+
print(f"Loading model '{self.model_id}'...")
|
| 42 |
+
init_func()
|
| 43 |
+
self.is_loaded = True
|
| 44 |
+
current_model = self
|
| 45 |
+
print(f"Model '{self.model_id}' loaded successfully.")
|
| 46 |
+
|
| 47 |
+
# Ensure the init_func also loads lazily
|
| 48 |
+
self.init_func = wrapper
|
| 49 |
+
return wrapper
|
| 50 |
+
|
| 51 |
+
return decorator
|
| 52 |
+
|
| 53 |
+
def unload(self):
|
| 54 |
+
# Create a decorator to set the unload callback function for this model. This allows the lazy loading mechanism to call the specified function when unloading the model, ensuring proper cleanup of resources.
|
| 55 |
+
def decorator(func):
|
| 56 |
+
self.unload_func = func
|
| 57 |
+
return func
|
| 58 |
+
|
| 59 |
+
return decorator
|
| 60 |
+
|
| 61 |
+
def entry(self):
|
| 62 |
+
def decorator(func):
|
| 63 |
+
def wrapper(*args, **kwargs):
|
| 64 |
+
if not self.init_func:
|
| 65 |
+
raise RuntimeError(
|
| 66 |
+
f"Model '{self.model_id}' does not have an initialization function defined."
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
# Ensure the model is loaded before executing the main function
|
| 70 |
+
if self.init_func and not self.is_loaded:
|
| 71 |
+
print(f"Model '{self.model_id}' is not loaded. Loading now...")
|
| 72 |
+
self.init_func()
|
| 73 |
+
|
| 74 |
+
print(f"Executing main function for model '{self.model_id}'...")
|
| 75 |
+
return func(*args, **kwargs)
|
| 76 |
+
|
| 77 |
+
return wrapper
|
| 78 |
+
|
| 79 |
+
return decorator
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _unload():
|
| 83 |
+
global current_model
|
| 84 |
+
if current_model and current_model.unload_func:
|
| 85 |
+
current_model.unload_func()
|
| 86 |
+
current_model = None
|
| 87 |
+
# Ensure garbage collection and CUDA cache clearing
|
| 88 |
+
gc.collect()
|
| 89 |
+
if torch.cuda.is_available():
|
| 90 |
+
torch.cuda.empty_cache()
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# Global variaable to keep track of the currently loaded LazyModel instance. This allows the lazy loading mechanism to determine if a model is already loaded and manage unloading of other models when necessary.
|
| 94 |
+
current_model: LazyModel | None = None
|
models/{llama.py → llama3_2_3b_instruct.py}
RENAMED
|
@@ -2,22 +2,39 @@ from typing import Any
|
|
| 2 |
|
| 3 |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextStreamer
|
| 4 |
from . import config, Model
|
|
|
|
| 5 |
|
| 6 |
MODEL_ID = Model.LLAMA_3_2_3B_INSTRUCT.model_id
|
|
|
|
| 7 |
|
| 8 |
-
model =
|
| 9 |
-
tokenizer =
|
|
|
|
| 10 |
|
| 11 |
-
pipe = pipeline(
|
| 12 |
-
"text-generation",
|
| 13 |
-
model=model,
|
| 14 |
-
tokenizer=tokenizer,
|
| 15 |
-
**config.pipeline_config,
|
| 16 |
-
)
|
| 17 |
-
print(f"{MODEL_ID} loaded successfully.")
|
| 18 |
-
print(f"Model device: {pipe.model.device}")
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
def generate(
|
| 22 |
messages: list[dict[str, str]],
|
| 23 |
max_tokens: int = 512,
|
|
@@ -25,6 +42,8 @@ def generate(
|
|
| 25 |
top_p: float = 0.9,
|
| 26 |
stop: list[str] | None = None,
|
| 27 |
) -> dict[str, Any]:
|
|
|
|
|
|
|
| 28 |
assert pipe.tokenizer is not None, "Tokenizer is not loaded."
|
| 29 |
|
| 30 |
print(f"Generating with {MODEL_ID}...")
|
|
@@ -40,8 +59,8 @@ def generate(
|
|
| 40 |
)
|
| 41 |
content = outputs[0]["generated_text"][-1]["content"]
|
| 42 |
|
| 43 |
-
prompt_tokens =
|
| 44 |
-
completion_tokens = len(
|
| 45 |
|
| 46 |
print(
|
| 47 |
f"Generation complete. Prompt tokens: {prompt_tokens}, Completion tokens: {completion_tokens}"
|
|
|
|
| 2 |
|
| 3 |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextStreamer
|
| 4 |
from . import config, Model
|
| 5 |
+
from .lazy_model import LazyModel
|
| 6 |
|
| 7 |
MODEL_ID = Model.LLAMA_3_2_3B_INSTRUCT.model_id
|
| 8 |
+
lazy = LazyModel(MODEL_ID)
|
| 9 |
|
| 10 |
+
model = None
|
| 11 |
+
tokenizer = None
|
| 12 |
+
pipe = None
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
+
@lazy.unload()
|
| 16 |
+
def clean_up():
|
| 17 |
+
global model, tokenizer, pipe
|
| 18 |
+
del model
|
| 19 |
+
del tokenizer
|
| 20 |
+
del pipe
|
| 21 |
|
| 22 |
+
|
| 23 |
+
@lazy.load()
|
| 24 |
+
def init():
|
| 25 |
+
global model, tokenizer, pipe
|
| 26 |
+
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **config.model_config)
|
| 27 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, **config.tokenizer_config)
|
| 28 |
+
|
| 29 |
+
pipe = pipeline(
|
| 30 |
+
"text-generation",
|
| 31 |
+
model=model,
|
| 32 |
+
tokenizer=tokenizer,
|
| 33 |
+
**config.pipeline_config,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@lazy.entry()
|
| 38 |
def generate(
|
| 39 |
messages: list[dict[str, str]],
|
| 40 |
max_tokens: int = 512,
|
|
|
|
| 42 |
top_p: float = 0.9,
|
| 43 |
stop: list[str] | None = None,
|
| 44 |
) -> dict[str, Any]:
|
| 45 |
+
global model, tokenizer, pipe
|
| 46 |
+
assert pipe is not None, "Pipeline is not initialized."
|
| 47 |
assert pipe.tokenizer is not None, "Tokenizer is not loaded."
|
| 48 |
|
| 49 |
print(f"Generating with {MODEL_ID}...")
|
|
|
|
| 59 |
)
|
| 60 |
content = outputs[0]["generated_text"][-1]["content"]
|
| 61 |
|
| 62 |
+
prompt_tokens = len(pipe.tokenizer.apply_chat_template(messages))
|
| 63 |
+
completion_tokens = len(pipe.tokenizer.encode(content, add_special_tokens=False))
|
| 64 |
|
| 65 |
print(
|
| 66 |
f"Generation complete. Prompt tokens: {prompt_tokens}, Completion tokens: {completion_tokens}"
|
service.py
CHANGED
|
@@ -11,28 +11,21 @@ def generate(
|
|
| 11 |
top_p: float = 0.9,
|
| 12 |
stop: list[str] | None = None,
|
| 13 |
) -> dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
if model == Model.LLAMA_3_2_3B_INSTRUCT.model_id:
|
| 15 |
-
from models import
|
| 16 |
-
|
| 17 |
-
return llama.generate(
|
| 18 |
-
messages=messages,
|
| 19 |
-
max_tokens=max_tokens,
|
| 20 |
-
temperature=temperature,
|
| 21 |
-
top_p=top_p,
|
| 22 |
-
stop=stop,
|
| 23 |
-
)
|
| 24 |
if model == Model.GEMMA_4_E2B.model_id:
|
| 25 |
-
from models import
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
)
|
| 34 |
-
msg = f"Unsupported model: {model}"
|
| 35 |
-
raise ValueError(msg)
|
| 36 |
|
| 37 |
|
| 38 |
def list_models() -> dict[str, list[dict[str, Any]]]:
|
|
|
|
| 11 |
top_p: float = 0.9,
|
| 12 |
stop: list[str] | None = None,
|
| 13 |
) -> dict[str, Any]:
|
| 14 |
+
# Ensure model exists
|
| 15 |
+
if model not in [m["id"] for m in get_available_models()]:
|
| 16 |
+
msg = f"Model '{model}' is not available. Supported models: {[m['id'] for m in get_available_models()]}"
|
| 17 |
+
raise ValueError(msg)
|
| 18 |
if model == Model.LLAMA_3_2_3B_INSTRUCT.model_id:
|
| 19 |
+
from models.llama3_2_3b_instruct import generate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
if model == Model.GEMMA_4_E2B.model_id:
|
| 21 |
+
from models.gemma4_e2b import generate
|
| 22 |
+
return generate( # type: ignore
|
| 23 |
+
messages=messages,
|
| 24 |
+
max_tokens=max_tokens,
|
| 25 |
+
temperature=temperature,
|
| 26 |
+
top_p=top_p,
|
| 27 |
+
stop=stop,
|
| 28 |
+
)
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
def list_models() -> dict[str, list[dict[str, Any]]]:
|