# gemma4-e4b / app.py
# rahul7star's picture
# Update app.py
# 8819a05 verified
import gradio as gr
import torch
import time
import traceback
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Hugging Face Hub repository id; the same id is passed to both the tokenizer
# loader and the model loader in this file.
model_id = "rahul7star/gemma-4-finetune"
def log(msg):
    """Print *msg* on stdout as one ``[DEBUG]``-prefixed line.

    The stream is flushed on every call so the line is emitted immediately
    rather than sitting in the output buffer.
    """
    line = f"[DEBUG] {msg}"
    print(line, flush=True)
# Startup banner: record which model is served and the torch / CUDA
# environment the process is running in.
for _banner_line in (
    "Starting Gemma 4 debug app",
    f"Model ID: {model_id}",
    f"Torch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
):
    log(_banner_line)

if torch.cuda.is_available():
    # Device details are only queried when CUDA reports itself as available.
    _gpu_count = torch.cuda.device_count()
    log(f"CUDA device count: {_gpu_count}")
    _gpu_name = torch.cuda.get_device_name(0)
    log(f"CUDA device name: {_gpu_name}")
# ============================================================
# Load Tokenizer
# ============================================================
log("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
log("Tokenizer loaded")

# Summarise the tokenizer *before* the PAD fallback below is applied, so the
# log shows the PAD values exactly as they were loaded.
log(f"Tokenizer class: {type(tokenizer).__name__}")
log(f"Vocab size: {len(tokenizer)}")
log(f"EOS token: {tokenizer.eos_token} / {tokenizer.eos_token_id}")
log(f"PAD token: {tokenizer.pad_token} / {tokenizer.pad_token_id}")
_has_chat_template = tokenizer.chat_template is not None
log(f"Chat template exists: {_has_chat_template}")

_pad_missing = tokenizer.pad_token_id is None
if _pad_missing:
    # No PAD token configured: reuse the EOS token as PAD.
    tokenizer.pad_token = tokenizer.eos_token
    log("PAD token was missing, set PAD token = EOS token")
# ============================================================
# Load Model
# ============================================================
log("Loading model...")
# Loader options gathered in one place: CPU placement, bfloat16 weights,
# low-memory loading, and custom model code from the repository allowed.
_load_options = {
    "device_map": "cpu",
    "low_cpu_mem_usage": True,
    "torch_dtype": torch.bfloat16,
    "trust_remote_code": True,
}
model = AutoModelForCausalLM.from_pretrained(model_id, **_load_options)
model.eval()  # switch the module to evaluation mode

log("Model loaded")
log(f"Model class: {type(model).__name__}")
log(f"Model device: {model.device}")
# The dtype is read from the first parameter the model yields.
_first_param = next(model.parameters())
log(f"Model dtype: {_first_param.dtype}")
# ============================================================
# Config Logs
# ============================================================
cfg = model.config
# Either sub-config may be absent; a missing one is represented as None.
text_cfg = getattr(cfg, "text_config", None)
vision_cfg = getattr(cfg, "vision_config", None)

# Every field is read with a None default, so attributes a given config does
# not define are logged as "None" instead of raising.
log("========== MAIN MODEL CONFIG ==========")
for _field in ("model_type", "architectures", "is_encoder_decoder"):
    log(f"{_field}: {getattr(cfg, _field, None)}")
log(f"text_config exists: {text_cfg is not None}")
log(f"vision_config exists: {vision_cfg is not None}")
log("=======================================")

if text_cfg is not None:
    log("========== TEXT CONFIG ==========")
    for _field in (
        "model_type",
        "hidden_size",
        "intermediate_size",
        "num_hidden_layers",
        "num_attention_heads",
        "num_key_value_heads",
        "head_dim",
        "vocab_size",
        "max_position_embeddings",
        "rope_theta",
        "rms_norm_eps",
        "attention_bias",
        "use_cache",
        "sliding_window",
        "query_pre_attn_scalar",
        "final_logit_softcapping",
        "attn_logit_softcapping",
    ):
        log(f"{_field}: {getattr(text_cfg, _field, None)}")
    log("=================================")

if vision_cfg is not None:
    log("========== VISION CONFIG ==========")
    for _field in (
        "model_type",
        "hidden_size",
        "intermediate_size",
        "num_hidden_layers",
        "num_attention_heads",
        "image_size",
        "patch_size",
    ):
        log(f"{_field}: {getattr(vision_cfg, _field, None)}")
    log("===================================")
# ============================================================
# Parameter Logs
# ============================================================
# Count all parameter elements, and separately those that require gradients,
# in a single walk over the model's parameters.
total_params = 0
trainable_params = 0
for _param in model.parameters():
    _size = _param.numel()
    total_params += _size
    if _param.requires_grad:
        trainable_params += _size

log(f"Total parameters: {total_params:,}")
log(f"Trainable parameters: {trainable_params:,}")
# ============================================================
# Module Listing
# ============================================================
log("========== TEXT MODEL MODULES ==========")

# Substrings that mark a module name as worth listing; they are matched
# against the lower-cased qualified module name.
text_keywords = [
    "language_model",
    "text_model",
    "model.layers",
    "self_attn",
    "mlp",
    "input_layernorm",
    "post_attention_layernorm",
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
    "rotary",
    "embed_tokens",
    "lm_head",
]

count = 0
for qualified_name, submodule in model.named_modules():
    lowered = qualified_name.lower()
    # Skip anything under a vision tower, and anything matching no keyword.
    is_text_module = "vision_tower" not in lowered and any(
        keyword in lowered for keyword in text_keywords
    )
    if not is_text_module:
        continue
    log(f"{qualified_name} => {type(submodule).__name__}")
    count += 1
    if count >= 200:
        # Cap the listing so the startup log stays bounded.
        log("Stopped text module logging after 200 entries")
        break

log(f"Text modules logged: {count}")
log("========================================")
# ============================================================
# Deep Forward Hooks
# ============================================================
# Master switch read by every debug hook on each call; when False the hooks
# return immediately without counting or logging.
DEBUG_HOOKS = True
# Sampling period for hook logging: a hook logs only on some of its calls,
# selected by a modulo test against this value inside make_hook.
HOOK_EVERY_N_CALLS = 20
# Per-hook call counters, keyed by the module name the hook was created for.
_hook_calls = {}
def tensor_stats(x):
    """Return a one-line, human-readable summary of *x* for debug logs.

    Args:
        x: Any object. Non-tensors are summarised by their type only.

    Returns:
        str: For a non-tensor, ``str(type(x))``. For a tensor, its shape,
        dtype and device followed by mean / std / min / max computed on a
        detached float32 copy. A tensor with zero elements is reported as
        ``"..., empty"`` because ``min()`` / ``max()`` raise on it.
    """
    if not torch.is_tensor(x):
        return str(type(x))

    header = f"shape={tuple(x.shape)}, dtype={x.dtype}, device={x.device}"
    if x.numel() == 0:
        # Reductions such as min()/max() are undefined for zero elements and
        # would raise from inside a forward hook; report the shape instead.
        return f"{header}, empty"

    with torch.no_grad():
        # Work on a detached float copy so integer / low-precision tensors
        # can be reduced and no autograd graph is touched.
        xf = x.detach().float()
        return (
            f"{header}, "
            f"mean={xf.mean().item():.5f}, "
            f"std={xf.std().item():.5f}, "
            f"min={xf.min().item():.5f}, "
            f"max={xf.max().item():.5f}"
        )
def get_first_tensor(obj):
    """Find the first tensor inside *obj*, searching depth-first.

    *obj* may be a tensor itself, or a list / tuple / dict nested to any
    depth. Containers are walked in iteration order (dicts by their values).

    Returns:
        The first tensor encountered, or None when there is none or *obj*
        is of an unsupported type.
    """
    if torch.is_tensor(obj):
        return obj

    if isinstance(obj, dict):
        children = obj.values()
    elif isinstance(obj, (list, tuple)):
        children = obj
    else:
        return None

    # Lazily recurse into each child and stop at the first hit.
    hits = (get_first_tensor(child) for child in children)
    return next((hit for hit in hits if hit is not None), None)
def make_hook(name):
    """Build a forward hook that periodically logs tensor statistics.

    Args:
        name: Label used as the key in ``_hook_calls`` and in the log output.

    Returns:
        A ``hook(module, inputs, output)`` callable. It returns None, counts
        its own invocations under *name*, and logs the first tensor found in
        the inputs and in the output on calls 1, N+1, 2N+1, ... where N is
        ``HOOK_EVERY_N_CALLS``. Nothing is counted or logged while
        ``DEBUG_HOOKS`` is false.
    """

    def hook(module, inputs, output):
        if not DEBUG_HOOKS:
            return

        calls = _hook_calls.get(name, 0) + 1
        _hook_calls[name] = calls

        # (calls - 1) % N == 0 selects the same calls as the former
        # "calls % N == 1" for N >= 2, but also works for N == 1 (log every
        # call), where "calls % 1" can never equal 1. Values below 1 are
        # treated as 1 so a zero period cannot raise ZeroDivisionError.
        period = max(1, HOOK_EVERY_N_CALLS)
        if (calls - 1) % period != 0:
            return

        inp = get_first_tensor(inputs)
        out = get_first_tensor(output)

        log(f"HOOK: {name}")
        if inp is not None:
            log(f" input -> {tensor_stats(inp)}")
        if out is not None:
            log(f" output -> {tensor_stats(out)}")

    return hook
def attach_debug_hooks():
    """Register debug forward hooks on the model's text-side modules.

    A module is selected when its lower-cased qualified name contains one of
    the wanted substrings and does not contain ``"vision_tower"``. Each
    selected module gets a hook built by ``make_hook`` for its name.

    Returns:
        list: The handles returned by ``register_forward_hook``, one per
        attached hook, in registration order. Keeping them lets a caller
        detach the hooks again (``handle.remove()``); previously they were
        discarded and the function returned None.
    """
    wanted = (
        "self_attn",
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "mlp",
        "gate_proj",
        "up_proj",
        "down_proj",
        "input_layernorm",
        "post_attention_layernorm",
        "rotary_emb",
        "lm_head",
    )

    handles = []
    for name, module in model.named_modules():
        lower = name.lower()
        if "vision_tower" in lower:
            continue
        if any(w in lower for w in wanted):
            handles.append(module.register_forward_hook(make_hook(name)))

    log(f"Attached debug hooks: {len(handles)}")
    return handles
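# Debug hooks are not attached by default; uncomment the call below to
# enable per-module forward-hook logging at startup.
# attach_debug_hooks()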
# ============================================================
# Generation Function
# ============================================================
def _content_as_text(content):
    """Return the plain-text form of one chat ``content`` value, or None.

    A string is returned unchanged. A list is treated as a sequence of
    content parts: bare strings and dicts carrying a string ``"text"`` entry
    are concatenated in order. Anything else (including None) yields None.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                pieces.append(part["text"])
        return "".join(pieces) if pieces else None
    return None


def _history_item_to_messages(item):
    """Convert one history entry to role/content dicts; [] if unusable.

    Two shapes are understood: a ``{"role": ..., "content": ...}`` dict, and
    a two-element ``(user, assistant)`` pair. A pair is only used when both
    sides have text, so the conversation never gets a half-finished turn.
    """
    if isinstance(item, dict):
        role = item.get("role")
        text = _content_as_text(item.get("content"))
        if isinstance(role, str) and text is not None:
            return [{"role": role, "content": text}]
        return []
    if isinstance(item, (list, tuple)) and len(item) == 2:
        user_text = _content_as_text(item[0])
        bot_text = _content_as_text(item[1])
        if user_text is not None and bot_text is not None:
            return [
                {"role": "user", "content": user_text},
                {"role": "assistant", "content": bot_text},
            ]
    return []


def generate_response(message, history):
    """Stream a model reply for *message* given the prior chat *history*.

    Args:
        message: The new user message.
        history: Previous conversation, either as role/content dicts or as
            ``(user, assistant)`` pairs; None is treated as empty. Entries
            in neither shape are logged and skipped.

    Yields:
        str: The reply accumulated so far, once per chunk received from the
        streamer. If building the prompt fails, a single
        ``"Tokenization error: ..."`` string is yielded instead. If streaming
        fails, the partial text plus a ``[Streaming error: ...]`` note is
        yielded.

    Generation runs on a background thread and feeds a
    ``TextIteratorStreamer`` that this generator consumes.
    """
    start_time = time.time()
    history = history or []

    log("========== NEW GENERATION ==========")
    log(f"User message: {message}")
    log(f"History turns: {len(history)}")

    messages = []
    for item in history:
        converted = _history_item_to_messages(item)
        if not converted:
            log("History parse warning: unsupported history item")
            log(f"Bad history item: {item}")
        messages.extend(converted)
    messages.append({"role": "user", "content": message})
    log(f"Total chat messages: {len(messages)}")

    try:
        inputs = tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            return_dict=True,
            add_generation_prompt=True,
        ).to(model.device)

        input_token_count = inputs["input_ids"].shape[-1]
        log(f"Input tensor shape: {inputs['input_ids'].shape}")
        log(f"Input tokens: {input_token_count}")
        log(f"Input device: {inputs['input_ids'].device}")

        log("========== TOKEN DEBUG ==========")
        ids = inputs["input_ids"][0].tolist()
        log(f"First 20 token ids: {ids[:20]}")
        log(f"Last 20 token ids: {ids[-20:]}")
        log(f"Decoded prompt preview: {tokenizer.decode(ids[-200:], skip_special_tokens=False)}")
        log("=================================")
    except Exception as e:
        log("Chat template/tokenization failed")
        log(traceback.format_exc())
        yield f"Tokenization error: {e}"
        return

    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=420.0,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    # Decoding settings live in one place so the values that are logged are
    # the values that are used. Decoding is greedy (do_sample=False), exactly
    # as before; temperature/top_p only apply to sampling, so they are passed
    # to generate() only when sampling is switched on.
    max_new_tokens = 1024
    do_sample = False
    temperature = 0.7
    top_p = 0.9

    # An eos_token_id passed to generate() takes precedence over the one in
    # the model's generation config. Merge both sources so a stop id declared
    # only in the generation config is not lost.
    # NOTE(review): whether this checkpoint declares extra stop ids (e.g. an
    # end-of-turn token) in its generation config is not visible here; confirm.
    eos_ids = []
    gen_cfg = getattr(model, "generation_config", None)
    for candidate in (tokenizer.eos_token_id, getattr(gen_cfg, "eos_token_id", None)):
        if candidate is None:
            continue
        if not isinstance(candidate, (list, tuple)):
            candidate = [candidate]
        for token_id in candidate:
            if token_id not in eos_ids:
                eos_ids.append(token_id)

    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=eos_ids or None,
    )
    if do_sample:
        generate_kwargs.update(temperature=temperature, top_p=top_p)

    log("Generation kwargs:")
    log(f"max_new_tokens={max_new_tokens}")
    log(f"do_sample={do_sample}")
    if do_sample:
        log(f"temperature={temperature}")
        log(f"top_p={top_p}")
    log(f"eos_token_id={eos_ids or None}")

    def run_generation():
        try:
            log("Generation thread started")
            gen_start = time.time()
            with torch.no_grad():
                model.generate(**generate_kwargs)
            gen_time = time.time() - gen_start
            log(f"Generation thread finished in {gen_time:.2f}s")
        except Exception as e:
            log("Generation Error")
            log(traceback.format_exc())
            # Surface the failure in the chat, then close the stream so the
            # consumer loop below terminates instead of waiting for the
            # streamer timeout.
            streamer.text_queue.put(
                f"\n[Generation thread crashed. Reason: {e}]"
            )
            streamer.end()

    t = Thread(target=run_generation)
    t.start()

    partial_text = ""
    token_chunks = 0
    try:
        for new_text in streamer:
            token_chunks += 1
            partial_text += new_text
            if token_chunks % 20 == 0:
                elapsed = time.time() - start_time
                log(
                    f"Streaming chunks: {token_chunks}, "
                    f"chars: {len(partial_text)}, "
                    f"elapsed: {elapsed:.2f}s"
                )
            yield partial_text
    except Exception as e:
        log("Streaming Error")
        log(traceback.format_exc())
        yield partial_text + f"\n\n[Streaming error: {e}]"
    finally:
        elapsed = time.time() - start_time
        log("========== GENERATION DONE ==========")
        log(f"Output chars: {len(partial_text)}")
        log(f"Streaming chunks: {token_chunks}")
        log(f"Elapsed seconds: {elapsed:.2f}")
        log("=====================================")
# ============================================================
# Gradio UI
# ============================================================
# Chat UI wired to the streaming generator; the example prompts are offered
# to the user but their outputs are not cached.
_chat_ui_options = {
    "fn": generate_response,
    "title": "Gemma 4 E4B - Debug",
    "examples": [
        "Explain quantum entanglement simply.",
        "Write a Python function to add two numbers.",
        "Explain how RoPE works in transformer attention.",
    ],
    "cache_examples": False,
}
demo = gr.ChatInterface(**_chat_ui_options)

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    log("Launching Gradio app...")
    demo.launch()