Gemma3n Slicing
Hi,
I tried the official Gemma-3n slicing methods to create a text-only, smaller model, but the outputs became gibberish with mixed languages.
Has anyone faced this issue or knows what might be going wrong in the slicing process? Any guidance would be appreciated.
Hi @varshu23 ,
Thanks for reaching out to us! Could you please share all the reproducible steps along with a minimal code snippet so we can understand the issue better?
Hi Sonali,
Thanks for getting back to me.
Below are the reproducible steps along with a minimal explanation of what I did and the issue I’m facing.
Base Model : google/gemma-3n-E4B-it
Step 1: Model Surgery (Text-only + Size Reduction)
- Removed multimodal components (vision_tower, audio_tower, multi_modal_projector, etc.).
- Skipped decoder layers: layers_to_skip = list(range(20, 30)) (i.e., layers 20–29 inclusive)
- Reduced FFN hidden dimension for each kept layer to: target_ffn_dim = 8192
Sliced code:
import gc
import json
import os
import re

import torch
from safetensors import safe_open
from safetensors.torch import save_file
from tqdm.auto import tqdm

# Map old layer indices -> new contiguous indices for the kept layers.
kept_layers_indices = [i for i in range(num_layers) if i not in layers_to_skip]
layer_rename_map = {old_idx: new_idx for new_idx, old_idx in enumerate(kept_layers_indices)}

# Maps each (renamed) tensor name to the shard file it is stored in; this is
# what goes into model.safetensors.index.json so transformers can find weights.
weight_map = {}
# Tensors accumulated for the shard currently being built.
new_shard_state_dict = {}
shard_counter = 1
total_size = 0  # Running byte count of all saved tensors (index metadata).

pbar = tqdm(total=len(safetensor_files), desc="Processing shards")
for shard_path in safetensor_files:
    # Stream one input shard at a time so peak memory stays bounded.
    with safe_open(shard_path, framework="pt", device="cpu") as f:
        for tensor_name in f.keys():
            # Drop all multimodal weights: the target model is text-only.
            if any(x in tensor_name for x in ['vision_tower', 'multi_modal_projector', 'audio_tower', 'visual_adapter']):
                continue

            new_tensor_name = tensor_name
            tensor = f.get_tensor(tensor_name)

            # Case 1: layer-specific parameters.
            match = re.search(r'\.layers\.(\d+)\.', tensor_name)
            if match:
                old_layer_idx = int(match.group(1))
                # Tensors of removed layers are dropped entirely.
                if old_layer_idx in layers_to_skip:
                    continue
                # Renumber so the kept layers are contiguous (0..N-1).
                new_layer_idx = layer_rename_map[old_layer_idx]
                new_tensor_name = tensor_name.replace(
                    f'.layers.{old_layer_idx}.',
                    f'.layers.{new_layer_idx}.'
                )
                # Target FFN width for this (renumbered) layer.
                target_ffn_dim = ffn_hidden_dims[new_layer_idx]
                if 'mlp.gate_proj.weight' in new_tensor_name or 'mlp.up_proj.weight' in new_tensor_name:
                    # model_dim -> ffn_hidden_dim: slice the output dim (0).
                    tensor = tensor[:target_ffn_dim, :].contiguous()
                elif 'mlp.down_proj.weight' in new_tensor_name:
                    # ffn_hidden_dim -> model_dim: slice the input dim (1).
                    tensor = tensor[:, :target_ffn_dim].contiguous()
            # Case 2: non-layer parameters laid out per-layer; keep only the
            # slices belonging to kept layers.
            elif 'per_layer_model_projection' in tensor_name:
                reshaped_params = tensor.reshape((num_layers, tensor.shape[0] // num_layers, tensor.shape[1]))
                tensor = reshaped_params[kept_layers_indices, :, :]
                tensor = tensor.reshape(-1, tensor.shape[-1]).contiguous()
            elif 'embed_tokens_per_layer' in tensor_name:
                reshaped_params = tensor.reshape((tensor.shape[0], num_layers, tensor.shape[1] // num_layers))
                tensor = reshaped_params[:, kept_layers_indices, :]
                tensor = tensor.reshape(tensor.shape[0], -1).contiguous()

            new_shard_state_dict[new_tensor_name] = tensor

            # Flush the shard once it exceeds ~4 GB.
            current_shard_size = sum(t.numel() * t.element_size() for t in new_shard_state_dict.values())
            if current_shard_size > 4_000_000_000:
                shard_filename = f"model-{shard_counter:05d}-of-XXXXX.safetensors"
                print(f"Saving shard {shard_filename} (size: {current_shard_size / 1e9:.2f} GB)")
                save_file(new_shard_state_dict, os.path.join(local_output_path, shard_filename), metadata={'format': 'pt'})
                # Record which tensors live in this shard.
                for k in new_shard_state_dict.keys():
                    weight_map[k] = os.path.basename(shard_filename)
                total_size += current_shard_size
                # Reset for the next shard.
                shard_counter += 1
                new_shard_state_dict = {}
                gc.collect()  # Free up memory
    pbar.update(1)
pbar.close()

# BUG FIX: the original never flushed the final (under-4GB) shard, silently
# dropping the last tensors from the checkpoint -- a likely cause of the
# garbled generations.
if new_shard_state_dict:
    current_shard_size = sum(t.numel() * t.element_size() for t in new_shard_state_dict.values())
    shard_filename = f"model-{shard_counter:05d}-of-XXXXX.safetensors"
    print(f"Saving final shard {shard_filename} (size: {current_shard_size / 1e9:.2f} GB)")
    save_file(new_shard_state_dict, os.path.join(local_output_path, shard_filename), metadata={'format': 'pt'})
    for k in new_shard_state_dict.keys():
        weight_map[k] = os.path.basename(shard_filename)
    total_size += current_shard_size
    shard_counter += 1
    new_shard_state_dict = {}

# Now that the total shard count is known, resolve the XXXXX placeholder in
# the file names and write model.safetensors.index.json; without this index
# transformers cannot map tensor names to shard files.
total_shards = shard_counter - 1
for i in range(1, total_shards + 1):
    old_name = f"model-{i:05d}-of-XXXXX.safetensors"
    new_name = f"model-{i:05d}-of-{total_shards:05d}.safetensors"
    os.rename(os.path.join(local_output_path, old_name), os.path.join(local_output_path, new_name))
weight_map = {k: v.replace("XXXXX", f"{total_shards:05d}") for k, v in weight_map.items()}

index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
with open(os.path.join(local_output_path, "model.safetensors.index.json"), "w") as jf:
    json.dump(index, jf, indent=2)
Step 2: Upload to Hugging Face
- Uploaded the processed model to: varshu23/gemma_3n_v8
Step 3: Inference Test (Hugging Face Space)
import os

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "varshu23/gemma_3n_v8"

# Hugging Face token is optional (public repo works without it); strip
# accidental whitespace from the Space secret.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    HF_TOKEN = HF_TOKEN.strip()

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device found: {device}")

print(f"Loading text-only model from: {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
    trust_remote_code=True
).to(device)  # Move entire model to device explicitly
model.eval()


def generate_response(prompt):
    """Generate a chat completion for *prompt* and return the new text only."""
    # Hardcoded generation parameters.
    max_tokens = 256
    temperature = 1.0

    messages = [{"role": "user", "content": prompt}]
    # apply_chat_template may return a plain tensor or a BatchEncoding
    # depending on the tokenizer; handle both shapes below.
    inputs_data = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    )
    if isinstance(inputs_data, dict) or hasattr(inputs_data, "input_ids"):
        input_ids = inputs_data["input_ids"].to(device)
        attention_mask = inputs_data.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
    else:
        input_ids = inputs_data.to(device)
        attention_mask = None

    with torch.no_grad():
        out = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=True if temperature > 0 else False,
            top_p=0.95,
            top_k=64,
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens (skip the prompt).
    generated_text = tokenizer.decode(out[0][input_ids.shape[-1]:], skip_special_tokens=True)
    return generated_text


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# Gemma 3n Text-Only\nModel: {MODEL_ID}")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Input Prompt", placeholder="Ask me anything...", lines=5)
            submit_btn = gr.Button("Generate", variant="primary")
    output_text = gr.Textbox(label="Model Output", lines=10)
    submit_btn.click(
        fn=generate_response,
        inputs=[input_text],
        outputs=output_text
    )

# BUG FIX: the original had `if name == "main":` -- the dunder underscores
# were lost in the paste; as written it raises NameError and the app never
# launches when run as a script.
if __name__ == "__main__":
    demo.launch()
Observed Issue
The model runs without crashing, but the generated output is completely garbled / unreadable (random multilingual symbols and corrupted text).
Example input: explain gemma
Example output: ti/,,,:0...u 茷I.aOkk"" میکنندสงค์.鿓"..等b1...
Expectation
- Either meaningful English output or a clear error indicating incompatibility.
- This looks like a silent corruption or config/weight mismatch rather than a tokenizer issue.
Please let me know if you’d like me to share any additional details, logs, or the full config files.
Thanks for your time and help!
You can test the model inference live at the link below.
https://huggingface.co/spaces/varshu23/gemma-3n-text-only-space