autumnssuns committed on
Commit
21bfda5
·
1 Parent(s): 22af552

✨ Add Gemma 4 E2B model integration and update service to support multiple models

Browse files
Files changed (6) hide show
  1. app.py +2 -2
  2. models/__init__.py +38 -8
  3. models/config.py +39 -0
  4. models/gemma4_e2b.py +72 -0
  5. models/llama.py +57 -88
  6. service.py +16 -5
app.py CHANGED
@@ -1,10 +1,10 @@
1
  from typing import Any
2
 
3
  import spaces
4
- from models.llama import LlamaModel
5
  import gradio
6
 
7
  from service import generate, list_models
 
8
 
9
  app = gradio.Server()
10
 
@@ -13,7 +13,7 @@ app = gradio.Server()
13
  @spaces.GPU(duration=10)
14
  def generate_endpoint(
15
  messages: list[dict[str, str]],
16
- model: str = LlamaModel.MODEL_ID,
17
  max_tokens: int = 512,
18
  temperature: float = 0.7,
19
  top_p: float = 0.9,
 
1
  from typing import Any
2
 
3
  import spaces
 
4
  import gradio
5
 
6
  from service import generate, list_models
7
+ from models import gemma4_e2b
8
 
9
  app = gradio.Server()
10
 
 
13
  @spaces.GPU(duration=10)
14
  def generate_endpoint(
15
  messages: list[dict[str, str]],
16
+ model: str = gemma4_e2b.MODEL_ID,
17
  max_tokens: int = 512,
18
  temperature: float = 0.7,
19
  top_p: float = 0.9,
models/__init__.py CHANGED
@@ -1,14 +1,44 @@
1
  from typing import Any
 
2
 
3
 
4
- AVAILABLE_MODELS: list[dict[str, Any]] = [
5
- {
6
- "id": "meta-llama/Llama-3.2-3B-Instruct",
7
- "type": "text-generation",
8
- "backend": "local",
9
- "max_tokens": 4096,
10
- },
11
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
 
14
  def get_available_models() -> list[dict[str, Any]]:
 
1
  from typing import Any
2
+ from enum import Enum
3
 
4
 
5
class Model(Enum):
    """Registry of the models this service can serve.

    Each member's value is a ``(model_id, model_type, backend, max_tokens)``
    tuple; ``Enum`` unpacks that tuple into ``__init__`` so the four fields
    are available as attributes on the member.
    """

    LLAMA_3_2_3B_INSTRUCT = (
        "meta-llama/Llama-3.2-3B-Instruct",
        "text-generation",
        "local",
        4096,
    )
    GEMMA_4_E2B = ("google/gemma-4-E2B-it", "text-generation", "local", 4096)

    def __init__(
        self,
        model_id: str,
        model_type: str,
        backend: str,
        max_tokens: int,
    ) -> None:
        self.model_id = model_id
        self.model_type = model_type
        self.backend = backend
        self.max_tokens = max_tokens

    @classmethod
    def from_id(cls, model_id: str) -> "Model":
        """Return the member whose ``model_id`` equals *model_id*.

        Raises:
            ValueError: if no member has that identifier.
        """
        for member in cls:
            if member.model_id == model_id:
                return member
        raise ValueError(f"Unknown model id: {model_id!r}")

    def __str__(self) -> str:
        # A member renders as its bare model identifier.
        return self.model_id

    def __repr__(self) -> str:
        return (
            f"Model(id={self.model_id}, type={self.model_type}, "
            f"backend={self.backend}, max_tokens={self.max_tokens})"
        )

    def to_dict(self) -> dict[str, Any]:
        """Return the member as a plain dict with keys id/type/backend/max_tokens."""
        return {
            "id": self.model_id,
            "type": self.model_type,
            "backend": self.backend,
            "max_tokens": self.max_tokens,
        }
39
+
40
+
41
+ AVAILABLE_MODELS = [model.to_dict() for model in Model]
42
 
43
 
44
  def get_available_models() -> list[dict[str, Any]]:
models/config.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Common configuration for all models, including device and dtype settings.
2
+
3
+ import os
4
+ import torch
5
+
6
+ TOKEN = os.getenv("HF_TOKEN")
7
+ QUANTIZE_4_BIT = os.getenv("QUANTIZE_4_BIT", "false").lower() == "true"
8
+
9
+ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ torch_dtype = torch.bfloat16 if torch_device in ["cuda", "mps"] else torch.float32
11
+
12
+ print(f"Using {torch_device} with dtype {torch_dtype}...")
13
+
14
+ model_config = {
15
+ "torch_dtype": torch_dtype,
16
+ "device_map": torch_device,
17
+ "token": TOKEN,
18
+ }
19
+
20
+ tokenizer_config = {
21
+ "token": TOKEN,
22
+ }
23
+
24
+ pipeline_config = {
25
+ "torch_dtype": torch_dtype,
26
+ "device_map": "auto",
27
+ }
28
+
29
+
30
def enable_quantization():
    """Switch the shared model configuration to 4-bit quantized loading.

    Stores a ``BitsAndBytesConfig(load_in_4bit=True)`` under the
    ``"quantization_config"`` key of the module-level ``model_config`` dict,
    mutating that dict in place.
    """
    print("Enabling 4-bit quantization for compatible models...")

    # Imported inside the function, so the import only runs when this is called.
    from transformers import BitsAndBytesConfig

    model_config.update(
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )
36
+
37
+
38
+ if QUANTIZE_4_BIT:
39
+ enable_quantization()
models/gemma4_e2b.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ import torch
4
+ from transformers import AutoProcessor, AutoModelForCausalLM, TextStreamer
5
+ from . import Model
6
+
7
+ MODEL_ID = Model.GEMMA_4_E2B.model_id
8
+
9
+ # Load model
10
+ processor = AutoProcessor.from_pretrained(MODEL_ID)
11
+
12
+ model = AutoModelForCausalLM.from_pretrained(
13
+ MODEL_ID, torch_dtype="auto", device_map="auto"
14
+ )
15
+
16
+ print(f"{MODEL_ID} loaded successfully.")
17
+ print(f"Model device: {model.device}")
18
+
19
+
20
def generate(
    messages: list[dict[str, str]],
    max_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    stop: list[str] | None = None,
) -> dict[str, Any]:
    """Generate a chat completion with the module-level Gemma model.

    Args:
        messages: Chat messages passed to the processor's chat template.
        max_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature; values <= 0 select greedy decoding.
        top_p: Nucleus-sampling threshold, used only when sampling.
        stop: Optional strings at which generation should stop.

    Returns:
        A dict with ``model``, ``content`` and a ``usage`` dict holding
        ``prompt_tokens``, ``completion_tokens`` and ``total_tokens``.
    """
    print(f"Generating with {MODEL_ID}...")

    # Render the conversation to text, then tokenize it onto the model device.
    prompt = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    inputs = processor(text=prompt, return_tensors="pt").to(model.device)
    input_len = inputs["input_ids"].shape[-1]

    # Streams generated text to the console while generation runs.
    streamer = TextStreamer(
        processor.tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    do_sample = temperature > 0
    generation_kwargs: dict[str, Any] = {
        "max_new_tokens": max_tokens,
        "do_sample": do_sample,
        "streamer": streamer,
    }
    if do_sample:
        # Sampling knobs are only meaningful (and only validated cleanly)
        # when sampling is enabled, so omit them for greedy decoding.
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p
    if stop:
        # `stop_strings` needs the tokenizer to match strings against tokens.
        generation_kwargs["stop_strings"] = stop
        generation_kwargs["tokenizer"] = processor.tokenizer

    with torch.inference_mode():
        outputs = model.generate(**inputs, **generation_kwargs)  # type: ignore

    # Keep only the tokens produced after the prompt.
    new_tokens = outputs[0][input_len:]
    response = processor.decode(new_tokens, skip_special_tokens=False)
    parsed = processor.parse_response(response)
    # NOTE(review): parse_response appears to return a message dict in current
    # transformers releases; accept both a dict and a plain string -- confirm.
    if isinstance(parsed, dict):
        content = parsed.get("content") or ""
    else:
        content = parsed

    # Report real token counts rather than whitespace-separated word counts.
    prompt_tokens = int(input_len)
    completion_tokens = int(new_tokens.shape[-1])

    print(
        f"Generation complete. Prompt tokens: {prompt_tokens}, Completion tokens: {completion_tokens}"
    )
    print(f"Generated content: {content}")

    return {
        "model": MODEL_ID,
        "content": content,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }
models/llama.py CHANGED
@@ -1,90 +1,59 @@
1
  from typing import Any
2
 
3
- import spaces
4
- import torch
5
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
6
-
7
-
8
- class LlamaModel:
9
- _instance: "LlamaModel | None" = None
10
- _pipe: Any = None
11
-
12
- MODEL_ID: str = "meta-llama/Llama-3.2-3B-Instruct"
13
-
14
- @classmethod
15
- def get_instance(cls) -> "LlamaModel":
16
- if cls._instance is None:
17
- cls._instance = cls()
18
- return cls._instance
19
-
20
- def __init__(self) -> None:
21
- if LlamaModel._pipe is not None:
22
- return
23
-
24
- torch_device = "cuda" if torch.cuda.is_available() else "cpu"
25
- torch_dtype = (
26
- torch.bfloat16 if torch_device in ["cuda", "mps"] else torch.float32
27
- )
28
-
29
- model = AutoModelForCausalLM.from_pretrained(
30
- self.MODEL_ID,
31
- torch_dtype=torch_dtype,
32
- device_map=torch_device,
33
- )
34
- tokenizer = AutoTokenizer.from_pretrained(self.MODEL_ID)
35
-
36
- LlamaModel._pipe = pipeline(
37
- "text-generation",
38
- model=model,
39
- tokenizer=tokenizer,
40
- torch_dtype=torch_dtype,
41
- device_map="auto",
42
- )
43
-
44
- @staticmethod
45
- def get_pipe() -> Any:
46
- if LlamaModel._pipe is None:
47
- LlamaModel.get_instance()
48
- return LlamaModel._pipe
49
-
50
- @staticmethod
51
- def generate(
52
- messages: list[dict[str, str]],
53
- max_tokens: int = 512,
54
- temperature: float = 0.7,
55
- top_p: float = 0.9,
56
- stop: list[str] | None = None,
57
- ) -> dict[str, Any]:
58
- print(f"Generating with {LlamaModel.MODEL_ID}...")
59
- pipe = LlamaModel.get_pipe()
60
- outputs = pipe(
61
- messages,
62
- max_new_tokens=max_tokens,
63
- temperature=temperature,
64
- top_p=top_p,
65
- do_sample=temperature > 0,
66
- )
67
- content = outputs[0]["generated_text"][-1]["content"]
68
-
69
- prompt_tokens = sum(len(msg["content"].split()) for msg in messages)
70
- completion_tokens = len(content.split())
71
-
72
- print(
73
- f"Generation complete. Prompt tokens: {prompt_tokens}, Completion tokens: {completion_tokens}"
74
- )
75
- print(f"Generated content: {content}")
76
-
77
- return {
78
- "model": LlamaModel.MODEL_ID,
79
- "content": content,
80
- "usage": {
81
- "prompt_tokens": prompt_tokens,
82
- "completion_tokens": completion_tokens,
83
- "total_tokens": prompt_tokens + completion_tokens,
84
- },
85
- }
86
-
87
-
88
- # Load the model immediately
89
- LlamaModel.get_instance()
90
- print(f"{LlamaModel.MODEL_ID} loaded and ready to generate.")
 
1
  from typing import Any
2
 
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextStreamer
4
+ from . import config, Model
5
+
6
+ MODEL_ID = Model.LLAMA_3_2_3B_INSTRUCT.model_id
7
+
8
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **config.model_config)
9
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, **config.tokenizer_config)
10
+
11
+ pipe = pipeline(
12
+ "text-generation",
13
+ model=model,
14
+ tokenizer=tokenizer,
15
+ **config.pipeline_config,
16
+ )
17
+ print(f"{MODEL_ID} loaded successfully.")
18
+ print(f"Model device: {pipe.model.device}")
19
+
20
+
21
def generate(
    messages: list[dict[str, str]],
    max_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    stop: list[str] | None = None,
) -> dict[str, Any]:
    """Generate a chat completion with the module-level Llama pipeline.

    Args:
        messages: Chat messages passed straight to the text-generation pipeline.
        max_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature; values <= 0 select greedy decoding.
        top_p: Nucleus-sampling threshold, used only when sampling.
        stop: Optional strings at which generation should stop.

    Returns:
        A dict with ``model``, ``content`` and a ``usage`` dict holding
        ``prompt_tokens``, ``completion_tokens`` and ``total_tokens``.

    Raises:
        RuntimeError: if the pipeline has no tokenizer.
    """
    tokenizer = pipe.tokenizer
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if tokenizer is None:
        raise RuntimeError("Tokenizer is not loaded.")

    print(f"Generating with {MODEL_ID}...")

    do_sample = temperature > 0
    generation_kwargs: dict[str, Any] = {
        "max_new_tokens": max_tokens,
        "do_sample": do_sample,
        # Enable streaming output to console
        "streamer": TextStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        ),
    }
    if do_sample:
        # Sampling knobs are only passed when sampling is actually enabled.
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p
    if stop:
        # `stop_strings` needs the tokenizer to match strings against tokens.
        generation_kwargs["stop_strings"] = stop
        generation_kwargs["tokenizer"] = tokenizer

    outputs = pipe(messages, **generation_kwargs)
    # For chat input the pipeline returns the conversation; the last message
    # is the newly generated assistant turn.
    content = outputs[0]["generated_text"][-1]["content"]

    # Count tokens with the tokenizer rather than whitespace-separated words.
    prompt_encoding = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
    )
    prompt_tokens = len(prompt_encoding["input_ids"])
    completion_tokens = len(tokenizer.encode(content, add_special_tokens=False))

    print(
        f"Generation complete. Prompt tokens: {prompt_tokens}, Completion tokens: {completion_tokens}"
    )
    print(f"Generated content: {content}")

    return {
        "model": MODEL_ID,
        "content": content,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
service.py CHANGED
@@ -1,19 +1,30 @@
1
  from typing import Any
2
 
3
- from models import get_available_models
4
- from models.llama import LlamaModel
5
 
6
 
7
  def generate(
8
  messages: list[dict[str, str]],
9
- model: str = LlamaModel.MODEL_ID,
10
  max_tokens: int = 512,
11
  temperature: float = 0.7,
12
  top_p: float = 0.9,
13
  stop: list[str] | None = None,
14
  ) -> dict[str, Any]:
15
- if model == LlamaModel.MODEL_ID:
16
- return LlamaModel.generate(
 
 
 
 
 
 
 
 
 
 
 
 
17
  messages=messages,
18
  max_tokens=max_tokens,
19
  temperature=temperature,
 
1
  from typing import Any
2
 
3
+ from models import get_available_models, Model
 
4
 
5
 
6
  def generate(
7
  messages: list[dict[str, str]],
8
+ model: str,
9
  max_tokens: int = 512,
10
  temperature: float = 0.7,
11
  top_p: float = 0.9,
12
  stop: list[str] | None = None,
13
  ) -> dict[str, Any]:
14
+ if model == Model.LLAMA_3_2_3B_INSTRUCT.model_id:
15
+ from models import llama
16
+
17
+ return llama.generate(
18
+ messages=messages,
19
+ max_tokens=max_tokens,
20
+ temperature=temperature,
21
+ top_p=top_p,
22
+ stop=stop,
23
+ )
24
+ if model == Model.GEMMA_4_E2B.model_id:
25
+ from models import gemma4_e2b
26
+
27
+ return gemma4_e2b.generate(
28
  messages=messages,
29
  max_tokens=max_tokens,
30
  temperature=temperature,