yipengsun committed on
Commit
700aa8b
·
verified ·
1 Parent(s): 67651ce

Upload models/medgemma_client.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/medgemma_client.py +42 -8
models/medgemma_client.py CHANGED
@@ -14,6 +14,7 @@ from config import (
14
  USE_27B, QUANTIZE_4B, HF_TOKEN, DEVICE,
15
  MEDGEMMA_4B_MODEL_ID, MEDGEMMA_27B_MODEL_ID,
16
  MAX_NEW_TOKENS_4B, MAX_NEW_TOKENS_27B, TEMPERATURE, REPETITION_PENALTY,
 
17
  )
18
  from models.utils import strip_thinking_tokens, resize_for_medgemma, apply_prompt_repetition
19
 
@@ -68,9 +69,18 @@ def load_4b():
68
  from transformers import AutoModelForImageTextToText, AutoProcessor
69
 
70
  is_local = _is_local_path(MEDGEMMA_4B_MODEL_ID)
 
 
 
 
 
 
 
 
 
71
  logger.info(
72
  "Loading MedGemma 4B-IT (%s) from %s...",
73
- "4-bit" if QUANTIZE_4B else "bf16",
74
  "local" if is_local else "HF Hub",
75
  )
76
 
@@ -82,9 +92,19 @@ def load_4b():
82
  else:
83
  kwargs["dtype"] = torch.bfloat16
84
 
 
 
 
 
85
  _processor_4b = AutoProcessor.from_pretrained(MEDGEMMA_4B_MODEL_ID, **_token_arg(MEDGEMMA_4B_MODEL_ID))
86
  _model_4b = AutoModelForImageTextToText.from_pretrained(MEDGEMMA_4B_MODEL_ID, **kwargs)
87
  _model_4b.eval()
 
 
 
 
 
 
88
  logger.info("MedGemma 4B loaded.")
89
  return _model_4b, _processor_4b
90
 
@@ -103,19 +123,33 @@ def load_27b():
103
  from transformers import AutoModelForCausalLM, AutoTokenizer
104
 
105
  is_local = _is_local_path(MEDGEMMA_27B_MODEL_ID)
 
 
 
 
 
106
  logger.info(
107
- "Loading MedGemma 27B Text-IT (bf16) from %s...",
 
108
  "local" if is_local else "HF Hub",
109
  )
110
 
111
- _tokenizer_27b = AutoTokenizer.from_pretrained(MEDGEMMA_27B_MODEL_ID, **_token_arg(MEDGEMMA_27B_MODEL_ID))
112
- _model_27b = AutoModelForCausalLM.from_pretrained(
113
- MEDGEMMA_27B_MODEL_ID,
114
  **_token_arg(MEDGEMMA_27B_MODEL_ID),
115
- dtype=torch.bfloat16,
116
- device_map="auto",
117
- )
 
 
 
 
 
118
  _model_27b.eval()
 
 
 
 
 
119
  logger.info("MedGemma 27B loaded.")
120
  return _model_27b, _tokenizer_27b
121
 
 
14
  USE_27B, QUANTIZE_4B, HF_TOKEN, DEVICE,
15
  MEDGEMMA_4B_MODEL_ID, MEDGEMMA_27B_MODEL_ID,
16
  MAX_NEW_TOKENS_4B, MAX_NEW_TOKENS_27B, TEMPERATURE, REPETITION_PENALTY,
17
+ ENABLE_TORCH_COMPILE, ENABLE_SDPA,
18
  )
19
  from models.utils import strip_thinking_tokens, resize_for_medgemma, apply_prompt_repetition
20
 
 
69
  from transformers import AutoModelForImageTextToText, AutoProcessor
70
 
71
  is_local = _is_local_path(MEDGEMMA_4B_MODEL_ID)
72
+ opts = []
73
+ if QUANTIZE_4B:
74
+ opts.append("4-bit")
75
+ else:
76
+ opts.append("bf16")
77
+ if ENABLE_SDPA:
78
+ opts.append("SDPA")
79
+ if ENABLE_TORCH_COMPILE:
80
+ opts.append("compiled")
81
  logger.info(
82
  "Loading MedGemma 4B-IT (%s) from %s...",
83
+ "+".join(opts),
84
  "local" if is_local else "HF Hub",
85
  )
86
 
 
92
  else:
93
  kwargs["dtype"] = torch.bfloat16
94
 
95
+ # SDPA: 优化注意力计算
96
+ if ENABLE_SDPA:
97
+ kwargs["attn_implementation"] = "sdpa"
98
+
99
  _processor_4b = AutoProcessor.from_pretrained(MEDGEMMA_4B_MODEL_ID, **_token_arg(MEDGEMMA_4B_MODEL_ID))
100
  _model_4b = AutoModelForImageTextToText.from_pretrained(MEDGEMMA_4B_MODEL_ID, **kwargs)
101
  _model_4b.eval()
102
+
103
+ # torch.compile: JIT 编译加速(首次推理会编译,耐心等待)
104
+ if ENABLE_TORCH_COMPILE:
105
+ logger.info("Compiling model with torch.compile (first inference will be slow)...")
106
+ _model_4b = torch.compile(_model_4b, mode="reduce-overhead")
107
+
108
  logger.info("MedGemma 4B loaded.")
109
  return _model_4b, _processor_4b
110
 
 
123
  from transformers import AutoModelForCausalLM, AutoTokenizer
124
 
125
  is_local = _is_local_path(MEDGEMMA_27B_MODEL_ID)
126
+ opts = ["bf16"]
127
+ if ENABLE_SDPA:
128
+ opts.append("SDPA")
129
+ if ENABLE_TORCH_COMPILE:
130
+ opts.append("compiled")
131
  logger.info(
132
+ "Loading MedGemma 27B Text-IT (%s) from %s...",
133
+ "+".join(opts),
134
  "local" if is_local else "HF Hub",
135
  )
136
 
137
+ kwargs = {
 
 
138
  **_token_arg(MEDGEMMA_27B_MODEL_ID),
139
+ "torch_dtype": torch.bfloat16,
140
+ "device_map": "auto",
141
+ }
142
+ if ENABLE_SDPA:
143
+ kwargs["attn_implementation"] = "sdpa"
144
+
145
+ _tokenizer_27b = AutoTokenizer.from_pretrained(MEDGEMMA_27B_MODEL_ID, **_token_arg(MEDGEMMA_27B_MODEL_ID))
146
+ _model_27b = AutoModelForCausalLM.from_pretrained(MEDGEMMA_27B_MODEL_ID, **kwargs)
147
  _model_27b.eval()
148
+
149
+ if ENABLE_TORCH_COMPILE:
150
+ logger.info("Compiling model with torch.compile (first inference will be slow)...")
151
+ _model_27b = torch.compile(_model_27b, mode="reduce-overhead")
152
+
153
  logger.info("MedGemma 27B loaded.")
154
  return _model_27b, _tokenizer_27b
155