Instructions to use google/gemma-2-9b with libraries, inference providers, notebooks, and local apps.
- Libraries
- Transformers
How to use google/gemma-2-9b with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="google/gemma-2-9b")

# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b")
```
- Inference
- Local Apps
- vLLM
How to use google/gemma-2-9b with vLLM:
Install from pip and serve model
```shell
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "google/gemma-2-9b"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "google/gemma-2-9b",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
```
Use Docker
```shell
docker model run hf.co/google/gemma-2-9b
```
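The curl request to the vLLM server can also be sent from Python with only the standard library (json and urllib.request). This is a minimal sketch. The helper names build_completion_request and complete are hypothetical, and the payload copies the curl example. Response parsing assumes the usual OpenAI completions shape, where the generated text sits in choices[0].text.

```python
import json
import urllib.request


def build_completion_request(
    base_url: str,
    model: str,
    prompt: str,
    max_tokens: int = 512,
    temperature: float = 0.5,
) -> urllib.request.Request:
    """Build a POST to an OpenAI-compatible /v1/completions endpoint (hypothetical helper)."""
    payload = {
        "model": model,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    return urllib.request.Request(
        url=f"{base_url.rstrip('/')}/v1/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def complete(request: urllib.request.Request, timeout: float = 60.0) -> str:
    """Send the request and return the first completion's text (assumes choices[0].text)."""
    with urllib.request.urlopen(request, timeout=timeout) as response:
        body = json.load(response)
    return body["choices"][0]["text"]
```

With the server running, complete(build_completion_request("http://localhost:8000", "google/gemma-2-9b", "Once upon a time,")) should return the generated text. The SGLang server described next listens on port 30000 instead.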
- SGLang
How to use google/gemma-2-9b with SGLang:
Install from pip and serve model
```shell
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "google/gemma-2-9b" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "google/gemma-2-9b",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
```
Use Docker images
```shell
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "google/gemma-2-9b" \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "google/gemma-2-9b",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
```
- Docker Model Runner
How to use google/gemma-2-9b with Docker Model Runner:
```shell
docker model run hf.co/google/gemma-2-9b
```
Runtime autograd error due to inplace operations
#4
by xianbin - opened
Error
While fine-tuning Gemma2 models with TRL, I encountered the following error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [CUDABFloat16Type [1, 308, 256000]], which is output 0 of TanhBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
Cause
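This turned out to be caused by in-place operations in the Gemma2 model definition that modify a tensor needed for gradient computation. The tensor in the message has shape [1, 308, 256000] and is the output of a tanh (TanhBackward0). Its last dimension matches the model's 256000-token vocabulary, which points to the final-logit softcapping. torch.tanh saves its output for the backward pass, because the derivative of tanh(x) is 1 - tanh(x)^2. The in-place mul_ applied right after it therefore bumps that saved tensor from version 0 to version 1, and autograd rejects it.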
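The failure can be reproduced outside Gemma2 with a few small tensors. This is a minimal stand-in, not the actual model code. A matrix product plays the role of the LM head, the cap of 30.0 is arbitrary, and the function names are hypothetical. The sketch mirrors the divide, tanh, multiply sequence twice: once with in-place ops and once out of place.

```python
import torch


def softcap_logits_inplace(hidden: torch.Tensor, weight: torch.Tensor, cap: float) -> torch.Tensor:
    logits = hidden @ weight       # stand-in for the LM head output (not a leaf tensor)
    logits.div_(cap)               # in-place, but harmless: matmul saves its inputs, not its output
    logits = torch.tanh(logits)    # TanhBackward0 saves this output (version 0)
    logits.mul_(cap)               # in-place write to the saved output -> version 1
    return logits


def softcap_logits(hidden: torch.Tensor, weight: torch.Tensor, cap: float) -> torch.Tensor:
    logits = hidden @ weight
    logits = logits / cap          # out-of-place: allocates a new tensor
    logits = torch.tanh(logits)
    return logits * cap            # out-of-place: the saved tanh output stays at version 0


def backward_ok(fn, cap: float = 30.0) -> bool:
    """Run forward and backward; return False if autograd rejects an in-place edit."""
    torch.manual_seed(0)
    hidden = torch.randn(2, 4, requires_grad=True)
    weight = torch.randn(4, 8, requires_grad=True)
    try:
        fn(hidden, weight, cap).sum().backward()
    except RuntimeError as err:
        if "modified by an inplace operation" in str(err):
            return False
        raise
    return True


print(backward_ok(softcap_logits_inplace))  # False
print(backward_ok(softcap_logits))          # True
```

The in-place version fails with the same "output 0 of TanhBackward0, is at version 1; expected version 0" message reported above. The in-place div_ before tanh is not the culprit here, because nothing saved for backward depends on the tensor it modifies.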
Possible solution
The following lines should be changed in diff_gemma2.py (and, by extension, modeling_gemma2.py):
Lines 163-165:
```python
attention_mask *= torch.tril(
    torch.ones_like(attention_mask),
    diagonal=(self.sliding_window - cache_position[-1]),
)
```
Replacement:
```python
attention_mask = torch.mul(
    attention_mask,
    torch.tril(
        torch.ones_like(attention_mask),
        diagonal=(self.sliding_window - cache_position[-1]),
    ),
)
```
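Independently of autograd, an in-place multiply (*=) on a tensor writes into the very object the caller passed in. Every other reference to that mask therefore sees the change. torch.mul, by contrast, returns a new tensor. The sketch below uses a toy 4x4 mask of ones with diagonal=0, both arbitrary, and hypothetical function names to show the difference.

```python
import torch


def window_mask_inplace(mask: torch.Tensor, diagonal: int) -> torch.Tensor:
    # Original pattern: the in-place multiply mutates `mask` itself, i.e. the caller's tensor.
    mask *= torch.tril(torch.ones_like(mask), diagonal=diagonal)
    return mask


def window_mask(mask: torch.Tensor, diagonal: int) -> torch.Tensor:
    # Proposed pattern: torch.mul allocates a fresh tensor and leaves `mask` untouched.
    return torch.mul(mask, torch.tril(torch.ones_like(mask), diagonal=diagonal))


shared = torch.ones(4, 4)
result = window_mask(shared, diagonal=0)
print(torch.equal(shared, torch.ones(4, 4)))  # True: caller's tensor unchanged

window_mask_inplace(shared, diagonal=0)
print(torch.equal(shared, torch.ones(4, 4)))  # False: caller's tensor was overwritten
```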
Lines 119-121:
```python
attn_weights.div_(self.config.attn_logit_softcapping)
attn_weights = torch.tanh(attn_weights)
attn_weights.mul_(self.config.attn_logit_softcapping)
```
Replacement:
```python
attn_weights = torch.div(attn_weights, self.config.attn_logit_softcapping)
attn_weights = self.attn_weights_tanh(attn_weights)
attn_weights = torch.mul(attn_weights, self.config.attn_logit_softcapping)
```
Add this to the constructor of Gemma2Attention:
```python
self.attn_weights_tanh = nn.Tanh()
```
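Why does the tanh output matter to autograd? For the softcapping function y = c * tanh(x / c), the derivative can be written purely in terms of the forward output: dy/dx = 1 - tanh(x / c)^2 = 1 - (y / c)^2. PyTorch's tanh backward likewise reuses the saved output rather than the input. Overwriting that output in place, as the mul_ after torch.tanh does, therefore invalidates the backward pass. The pure-Python check below compares the identity against a finite difference. The cap of 50.0 and the function names are illustrative.

```python
import math


def softcap(x: float, cap: float) -> float:
    """cap * tanh(x / cap): close to x near 0, never larger than cap in magnitude."""
    return cap * math.tanh(x / cap)


def softcap_grad_from_output(y: float, cap: float) -> float:
    """dy/dx expressed through the forward output y only: 1 - (y / cap)**2."""
    return 1.0 - (y / cap) ** 2


def central_difference(x: float, cap: float, h: float = 1e-6) -> float:
    """Numerical derivative, used only to check the closed form."""
    return (softcap(x + h, cap) - softcap(x - h, cap)) / (2.0 * h)


cap = 50.0
for x in (-80.0, 0.0, 3.0, 120.0):
    y = softcap(x, cap)
    print(x, softcap_grad_from_output(y, cap), central_difference(x, cap))
```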
Lines 202-204:
```python
logits.div_(self.config.final_logit_softcapping)
logits = torch.tanh(logits)
logits.mul_(self.config.final_logit_softcapping)
```
Replacement:
```python
logits = torch.div(logits, self.config.final_logit_softcapping)
logits = self.final_logit_tanh(logits)
logits = torch.mul(logits, self.config.final_logit_softcapping)
```
Add this to the constructor of Gemma2ForCausalLM:
```python
self.final_logit_tanh = nn.Tanh()
```
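The part of this change that removes the error appears to be replacing the in-place div_ and mul_ with out-of-place operations. Applying tanh through an nn.Tanh module rather than calling torch.tanh directly should not matter, since nn.Tanh applies torch.tanh in its forward. The sketch below checks that both out-of-place variants backpropagate and give the same gradients. It uses hypothetical names, random tensors, and an arbitrary cap of 30.0.

```python
import torch
import torch.nn as nn


class SoftcapWithModule(nn.Module):
    """Out-of-place softcapping that routes tanh through nn.Tanh, as in the proposal."""

    def __init__(self, cap: float) -> None:
        super().__init__()
        self.cap = cap
        self.final_logit_tanh = nn.Tanh()

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        logits = torch.div(logits, self.cap)
        logits = self.final_logit_tanh(logits)
        return torch.mul(logits, self.cap)


def softcap_functional(logits: torch.Tensor, cap: float) -> torch.Tensor:
    """Same computation using the functional torch.tanh."""
    return torch.tanh(logits / cap) * cap


def input_grad(fn) -> torch.Tensor:
    """Backpropagate a sum through fn and return the gradient w.r.t. the input."""
    torch.manual_seed(0)
    hidden = torch.randn(3, 5, requires_grad=True)
    weight = torch.randn(5, 7)
    fn(hidden @ weight).sum().backward()
    return hidden.grad


cap = 30.0
grad_module = input_grad(SoftcapWithModule(cap))
grad_functional = input_grad(lambda t: softcap_functional(t, cap))
print(torch.allclose(grad_module, grad_functional))  # True
```

Keeping the module follows the proposal as written. A plain torch.tanh call with out-of-place arithmetic is an equally valid alternative under these assumptions.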
Yes, will fix this in a bit!
xianbin changed discussion status to closed