Gemma3n Slicing

#22
by varshu23 - opened

Hi,
I tried the official Gemma-3n slicing methods to create a text-only, smaller model, but the outputs became gibberish with mixed languages.
Has anyone faced this issue or knows what might be going wrong in the slicing process? Any guidance would be appreciated.

Hi @varshu23 ,

Thanks for reaching out to us! Could you please share all the reproducible steps along with a minimal code snippet so we can understand the issue better?

Hi Sonali,

Thanks for getting back to me.
Below are the reproducible steps along with a minimal explanation of what I did and the issue I’m facing.
Base Model : google/gemma-3n-E4B-it

Step 1: Model Surgery (Text-only + Size Reduction)

  • Removed multimodal components (vision_tower, audio_tower, multi_modal_projector, etc.).
  • Skipped decoder layers 20–29 (inclusive):
layers_to_skip = list(range(20, 30))
  • Reduced FFN hidden dimension for each kept layer to:
target_ffn_dim = 8192

Sliced code:
from safetensors import safe_open
from tqdm.auto import tqdm
import re
import os
import gc
import json

import torch

from safetensors.torch import save_file

# NOTE(review): this script expects `num_layers`, `layers_to_skip`,
# `ffn_hidden_dims`, `safetensor_files` and `local_output_path` to be
# defined earlier in the notebook/session.
kept_layers_indices = [i for i in range(num_layers) if i not in layers_to_skip]
layer_rename_map = {old_idx: new_idx for new_idx, old_idx in enumerate(kept_layers_indices)}

# Maps each kept tensor name -> shard file it is stored in (for the index json).
weight_map = {}

# Tensors accumulated for the shard currently being built.
new_shard_state_dict = {}
shard_counter = 1
total_size = 0  # running byte count of everything written, for the index metadata

MAX_SHARD_BYTES = 4_000_000_000  # flush the current shard once it exceeds ~4 GB


def _save_current_shard():
    """Write the in-progress shard to disk, record its tensors, and reset state."""
    global new_shard_state_dict, shard_counter, total_size
    shard_size = sum(t.numel() * t.element_size() for t in new_shard_state_dict.values())
    # The total shard count is unknown until the end, so write a placeholder
    # name now and rename after the loop.
    shard_filename = f"model-{shard_counter:05d}-of-XXXXX.safetensors"
    print(f"Saving shard {shard_filename} (size: {shard_size / 1e9:.2f} GB)")
    save_file(new_shard_state_dict, os.path.join(local_output_path, shard_filename), metadata={'format': 'pt'})
    for k in new_shard_state_dict:
        weight_map[k] = shard_filename
    total_size += shard_size
    shard_counter += 1
    new_shard_state_dict = {}
    gc.collect()  # free the tensors we just wrote


pbar = tqdm(total=len(safetensor_files), desc="Processing shards")

for shard_path in safetensor_files:
    # Stream tensors one at a time instead of loading the whole shard into RAM.
    with safe_open(shard_path, framework="pt", device="cpu") as f:
        for tensor_name in f.keys():
            # The sliced model is text-only: drop every multimodal tensor.
            if any(x in tensor_name for x in ('vision_tower', 'multi_modal_projector', 'audio_tower', 'visual_adapter')):
                continue

            new_tensor_name = tensor_name
            tensor = f.get_tensor(tensor_name)

            # Case 1: per-layer parameters.
            match = re.search(r'\.layers\.(\d+)\.', tensor_name)
            if match:
                old_layer_idx = int(match.group(1))

                # Skip every tensor belonging to a removed layer.
                if old_layer_idx in layers_to_skip:
                    continue

                # Renumber the kept layer so indices stay contiguous.
                new_layer_idx = layer_rename_map[old_layer_idx]
                new_tensor_name = tensor_name.replace(
                    f'.layers.{old_layer_idx}.',
                    f'.layers.{new_layer_idx}.'
                )

                target_ffn_dim = ffn_hidden_dims[new_layer_idx]

                # WARNING(review): truncating FFN weight matrices discards
                # learned features; without distillation or fine-tuning the
                # sliced model is expected to emit garbled output. Also make
                # sure config.json is updated to the new `num_hidden_layers`
                # and `intermediate_size`, or the checkpoint will load against
                # a mismatched architecture.
                if 'mlp.gate_proj.weight' in new_tensor_name or 'mlp.up_proj.weight' in new_tensor_name:
                    # model_dim -> ffn_hidden_dim: slice the output dim (dim 0).
                    tensor = tensor[:target_ffn_dim, :].contiguous()
                elif 'mlp.down_proj.weight' in new_tensor_name:
                    # ffn_hidden_dim -> model_dim: slice the input dim (dim 1).
                    tensor = tensor[:, :target_ffn_dim].contiguous()

            # Case 2: non-layer parameters whose leading/trailing dim is laid
            # out per original layer; keep only the rows of kept layers.
            elif 'per_layer_model_projection' in tensor_name:
                reshaped_params = tensor.reshape((num_layers, tensor.shape[0] // num_layers, tensor.shape[1]))
                tensor = reshaped_params[kept_layers_indices, :, :]
                tensor = tensor.reshape(-1, tensor.shape[-1]).contiguous()

            elif 'embed_tokens_per_layer' in tensor_name:
                reshaped_params = tensor.reshape((tensor.shape[0], num_layers, tensor.shape[1] // num_layers))
                tensor = reshaped_params[:, kept_layers_indices, :]
                tensor = tensor.reshape(tensor.shape[0], -1).contiguous()

            # Add the (potentially modified) tensor to the shard being built.
            new_shard_state_dict[new_tensor_name] = tensor

            # Flush once the shard grows past the size limit.
            current_shard_size = sum(t.numel() * t.element_size() for t in new_shard_state_dict.values())
            if current_shard_size > MAX_SHARD_BYTES:
                _save_current_shard()

    # BUGFIX: advance the progress bar once per input shard (the original
    # called pbar.update after the loop, so it ticked only once in total).
    pbar.update(1)

pbar.close()

# BUGFIX: the original never saved the final partial shard, silently dropping
# every tensor accumulated after the last 4 GB flush — a likely source of the
# corrupted checkpoint.
if new_shard_state_dict:
    _save_current_shard()

# Replace the XXXXX placeholder now that the shard count is known, and write
# the index file so `from_pretrained` can locate every tensor.
num_shards = shard_counter - 1
for i in range(1, num_shards + 1):
    placeholder = f"model-{i:05d}-of-XXXXX.safetensors"
    final_name = f"model-{i:05d}-of-{num_shards:05d}.safetensors"
    os.rename(os.path.join(local_output_path, placeholder),
              os.path.join(local_output_path, final_name))
weight_map = {k: v.replace('XXXXX', f'{num_shards:05d}') for k, v in weight_map.items()}

with open(os.path.join(local_output_path, "model.safetensors.index.json"), "w") as fp:
    json.dump({"metadata": {"total_size": total_size}, "weight_map": weight_map}, fp, indent=2)

Step 2: Upload to Hugging Face

  • Uploaded the processed model to:
varshu23/gemma_3n_v8

Step 3: Inference Test (Hugging Face Space)
import os

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "varshu23/gemma_3n_v8"

# Access token comes from the Space secrets; trim stray whitespace if set.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    HF_TOKEN = HF_TOKEN.strip()

# Prefer a GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device found: {device}")
print(f"Loading text-only model from: {MODEL_ID}...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
    trust_remote_code=True,
).to(device)  # move the whole model onto the chosen device explicitly

model.eval()  # inference only — disables dropout and similar training modes

def generate_response(prompt):
    """Generate a chat completion for *prompt* and return the decoded text.

    Uses the module-level ``tokenizer``, ``model`` and ``device``. Sampling
    parameters are fixed: 256 new tokens, temperature 1.0, top-p 0.95,
    top-k 64.
    """
    # Hardcoded generation parameters.
    max_tokens = 256
    temperature = 1.0

    messages = [{"role": "user", "content": prompt}]

    # FIX(review): request a dict so an attention_mask is returned as well.
    # With return_tensors="pt" alone, apply_chat_template returns only the
    # input_ids tensor, the dict branch below is dead code, and generate()
    # has to infer the mask.
    inputs_data = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )

    if isinstance(inputs_data, dict) or hasattr(inputs_data, "input_ids"):
        input_ids = inputs_data["input_ids"].to(device)
        attention_mask = inputs_data.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
    else:
        # Fallback for tokenizer versions that still return a bare tensor.
        input_ids = inputs_data.to(device)
        attention_mask = None

    with torch.no_grad():
        out = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=temperature > 0,  # sample unless temperature is zero
            top_p=0.95,
            top_k=64,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    generated_text = tokenizer.decode(out[0][input_ids.shape[-1]:], skip_special_tokens=True)
    return generated_text

# Minimal single-prompt UI: one input box, one button, one output box.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# Gemma 3n Text-Only\nModel: {MODEL_ID}")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Input Prompt", placeholder="Ask me anything...", lines=5)
            submit_btn = gr.Button("Generate", variant="primary")

        output_text = gr.Textbox(label="Model Output", lines=10)

    submit_btn.click(
        fn=generate_response,
        inputs=[input_text],
        outputs=output_text,
    )

# BUGFIX: the entry guard must compare the __name__ dunder to "__main__";
# `if name == "main":` raises NameError at import time (the underscores were
# most likely stripped by markdown when the snippet was pasted).
if __name__ == "__main__":
    demo.launch()

Observed Issue
The model runs without crashing, but the generated output is completely garbled / unreadable (random multilingual symbols and corrupted text).
Example input: explain gemma

Example output:
 ti/,,,:0...u 茷I.aOkk"" می‌کنندสงค์.鿓"..等b1...

Expectation

  • Either meaningful English output or a clear error indicating incompatibility.
  • This looks like a silent corruption or config/weight mismatch rather than a tokenizer issue.

Please let me know if you’d like me to share any additional details or files.

Thanks for your time and help!

You can test the model inference live at the link below.
https://huggingface.co/spaces/varshu23/gemma-3n-text-only-space

Sign up or log in to comment