Spaces:
Running
on
Zero
Running
on
Zero
Upload models/medgemma_client.py with huggingface_hub
Browse files- models/medgemma_client.py +42 -8
models/medgemma_client.py
CHANGED
|
@@ -14,6 +14,7 @@ from config import (
|
|
| 14 |
USE_27B, QUANTIZE_4B, HF_TOKEN, DEVICE,
|
| 15 |
MEDGEMMA_4B_MODEL_ID, MEDGEMMA_27B_MODEL_ID,
|
| 16 |
MAX_NEW_TOKENS_4B, MAX_NEW_TOKENS_27B, TEMPERATURE, REPETITION_PENALTY,
|
|
|
|
| 17 |
)
|
| 18 |
from models.utils import strip_thinking_tokens, resize_for_medgemma, apply_prompt_repetition
|
| 19 |
|
|
@@ -68,9 +69,18 @@ def load_4b():
|
|
| 68 |
from transformers import AutoModelForImageTextToText, AutoProcessor
|
| 69 |
|
| 70 |
is_local = _is_local_path(MEDGEMMA_4B_MODEL_ID)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
logger.info(
|
| 72 |
"Loading MedGemma 4B-IT (%s) from %s...",
|
| 73 |
-
"
|
| 74 |
"local" if is_local else "HF Hub",
|
| 75 |
)
|
| 76 |
|
|
@@ -82,9 +92,19 @@ def load_4b():
|
|
| 82 |
else:
|
| 83 |
kwargs["dtype"] = torch.bfloat16
|
| 84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
_processor_4b = AutoProcessor.from_pretrained(MEDGEMMA_4B_MODEL_ID, **_token_arg(MEDGEMMA_4B_MODEL_ID))
|
| 86 |
_model_4b = AutoModelForImageTextToText.from_pretrained(MEDGEMMA_4B_MODEL_ID, **kwargs)
|
| 87 |
_model_4b.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
logger.info("MedGemma 4B loaded.")
|
| 89 |
return _model_4b, _processor_4b
|
| 90 |
|
|
@@ -103,19 +123,33 @@ def load_27b():
|
|
| 103 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 104 |
|
| 105 |
is_local = _is_local_path(MEDGEMMA_27B_MODEL_ID)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
logger.info(
|
| 107 |
-
"Loading MedGemma 27B Text-IT (
|
|
|
|
| 108 |
"local" if is_local else "HF Hub",
|
| 109 |
)
|
| 110 |
|
| 111 |
-
|
| 112 |
-
_model_27b = AutoModelForCausalLM.from_pretrained(
|
| 113 |
-
MEDGEMMA_27B_MODEL_ID,
|
| 114 |
**_token_arg(MEDGEMMA_27B_MODEL_ID),
|
| 115 |
-
|
| 116 |
-
device_map
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
_model_27b.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
logger.info("MedGemma 27B loaded.")
|
| 120 |
return _model_27b, _tokenizer_27b
|
| 121 |
|
|
|
|
| 14 |
USE_27B, QUANTIZE_4B, HF_TOKEN, DEVICE,
|
| 15 |
MEDGEMMA_4B_MODEL_ID, MEDGEMMA_27B_MODEL_ID,
|
| 16 |
MAX_NEW_TOKENS_4B, MAX_NEW_TOKENS_27B, TEMPERATURE, REPETITION_PENALTY,
|
| 17 |
+
ENABLE_TORCH_COMPILE, ENABLE_SDPA,
|
| 18 |
)
|
| 19 |
from models.utils import strip_thinking_tokens, resize_for_medgemma, apply_prompt_repetition
|
| 20 |
|
|
|
|
| 69 |
from transformers import AutoModelForImageTextToText, AutoProcessor
|
| 70 |
|
| 71 |
is_local = _is_local_path(MEDGEMMA_4B_MODEL_ID)
|
| 72 |
+
opts = []
|
| 73 |
+
if QUANTIZE_4B:
|
| 74 |
+
opts.append("4-bit")
|
| 75 |
+
else:
|
| 76 |
+
opts.append("bf16")
|
| 77 |
+
if ENABLE_SDPA:
|
| 78 |
+
opts.append("SDPA")
|
| 79 |
+
if ENABLE_TORCH_COMPILE:
|
| 80 |
+
opts.append("compiled")
|
| 81 |
logger.info(
|
| 82 |
"Loading MedGemma 4B-IT (%s) from %s...",
|
| 83 |
+
"+".join(opts),
|
| 84 |
"local" if is_local else "HF Hub",
|
| 85 |
)
|
| 86 |
|
|
|
|
| 92 |
else:
|
| 93 |
kwargs["dtype"] = torch.bfloat16
|
| 94 |
|
| 95 |
+
# SDPA: optimized attention computation
|
| 96 |
+
if ENABLE_SDPA:
|
| 97 |
+
kwargs["attn_implementation"] = "sdpa"
|
| 98 |
+
|
| 99 |
_processor_4b = AutoProcessor.from_pretrained(MEDGEMMA_4B_MODEL_ID, **_token_arg(MEDGEMMA_4B_MODEL_ID))
|
| 100 |
_model_4b = AutoModelForImageTextToText.from_pretrained(MEDGEMMA_4B_MODEL_ID, **kwargs)
|
| 101 |
_model_4b.eval()
|
| 102 |
+
|
| 103 |
+
# torch.compile: JIT compilation for speed (the first inference triggers compilation, so be patient)
|
| 104 |
+
if ENABLE_TORCH_COMPILE:
|
| 105 |
+
logger.info("Compiling model with torch.compile (first inference will be slow)...")
|
| 106 |
+
_model_4b = torch.compile(_model_4b, mode="reduce-overhead")
|
| 107 |
+
|
| 108 |
logger.info("MedGemma 4B loaded.")
|
| 109 |
return _model_4b, _processor_4b
|
| 110 |
|
|
|
|
| 123 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 124 |
|
| 125 |
is_local = _is_local_path(MEDGEMMA_27B_MODEL_ID)
|
| 126 |
+
opts = ["bf16"]
|
| 127 |
+
if ENABLE_SDPA:
|
| 128 |
+
opts.append("SDPA")
|
| 129 |
+
if ENABLE_TORCH_COMPILE:
|
| 130 |
+
opts.append("compiled")
|
| 131 |
logger.info(
|
| 132 |
+
"Loading MedGemma 27B Text-IT (%s) from %s...",
|
| 133 |
+
"+".join(opts),
|
| 134 |
"local" if is_local else "HF Hub",
|
| 135 |
)
|
| 136 |
|
| 137 |
+
kwargs = {
|
|
|
|
|
|
|
| 138 |
**_token_arg(MEDGEMMA_27B_MODEL_ID),
|
| 139 |
+
"torch_dtype": torch.bfloat16,
|
| 140 |
+
"device_map": "auto",
|
| 141 |
+
}
|
| 142 |
+
if ENABLE_SDPA:
|
| 143 |
+
kwargs["attn_implementation"] = "sdpa"
|
| 144 |
+
|
| 145 |
+
_tokenizer_27b = AutoTokenizer.from_pretrained(MEDGEMMA_27B_MODEL_ID, **_token_arg(MEDGEMMA_27B_MODEL_ID))
|
| 146 |
+
_model_27b = AutoModelForCausalLM.from_pretrained(MEDGEMMA_27B_MODEL_ID, **kwargs)
|
| 147 |
_model_27b.eval()
|
| 148 |
+
|
| 149 |
+
if ENABLE_TORCH_COMPILE:
|
| 150 |
+
logger.info("Compiling model with torch.compile (first inference will be slow)...")
|
| 151 |
+
_model_27b = torch.compile(_model_27b, mode="reduce-overhead")
|
| 152 |
+
|
| 153 |
logger.info("MedGemma 27B loaded.")
|
| 154 |
return _model_27b, _tokenizer_27b
|
| 155 |
|