"""Gradio debug chat app for a fine-tuned Gemma checkpoint.

On import this script loads the tokenizer and model for ``model_id`` on CPU,
prints tokenizer/model/config/parameter diagnostics with a ``[DEBUG]``
prefix, and builds a ``gr.ChatInterface`` that streams generated text.
Optional forward hooks (see ``attach_debug_hooks``) print tensor statistics
for attention/MLP/norm sub-modules.
"""

import time
import traceback
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "rahul7star/gemma-4-finetune"


def log(msg):
    """Print ``msg`` with a ``[DEBUG]`` prefix, flushing immediately."""
    print(f"[DEBUG] {msg}", flush=True)


def _log_config_fields(header, footer, config, fields):
    """Log ``field: value`` for each name in ``fields`` between two banners.

    Attributes missing from ``config`` are reported as ``None``.
    """
    log(header)
    for field in fields:
        log(f"{field}: {getattr(config, field, None)}")
    log(footer)


log("Starting Gemma 4 debug app")
log(f"Model ID: {model_id}")
log(f"Torch version: {torch.__version__}")
log(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    log(f"CUDA device count: {torch.cuda.device_count()}")
    log(f"CUDA device name: {torch.cuda.get_device_name(0)}")

# ============================================================
# Load Tokenizer
# ============================================================

log("Loading tokenizer...")

tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True,
)

log("Tokenizer loaded")
log(f"Tokenizer class: {tokenizer.__class__.__name__}")
log(f"Vocab size: {len(tokenizer)}")
log(f"EOS token: {tokenizer.eos_token} / {tokenizer.eos_token_id}")
log(f"PAD token: {tokenizer.pad_token} / {tokenizer.pad_token_id}")
log(f"Chat template exists: {tokenizer.chat_template is not None}")

# generate() is given pad_token_id below, so make sure one is defined.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
    log("PAD token was missing, set PAD token = EOS token")

# ============================================================
# Load Model
# ============================================================

log("Loading model...")

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cpu",
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

model.eval()

log("Model loaded")
log(f"Model class: {model.__class__.__name__}")
log(f"Model device: {model.device}")
log(f"Model dtype: {next(model.parameters()).dtype}")

# ============================================================
# Config Logs
# ============================================================

cfg = model.config
text_cfg = getattr(cfg, "text_config", None)
vision_cfg = getattr(cfg, "vision_config", None)

log("========== MAIN MODEL CONFIG ==========")
for _field in ("model_type", "architectures", "is_encoder_decoder"):
    log(f"{_field}: {getattr(cfg, _field, None)}")
log(f"text_config exists: {text_cfg is not None}")
log(f"vision_config exists: {vision_cfg is not None}")
log("=======================================")

if text_cfg is not None:
    _log_config_fields(
        "========== TEXT CONFIG ==========",
        "=================================",
        text_cfg,
        (
            "model_type",
            "hidden_size",
            "intermediate_size",
            "num_hidden_layers",
            "num_attention_heads",
            "num_key_value_heads",
            "head_dim",
            "vocab_size",
            "max_position_embeddings",
            "rope_theta",
            "rms_norm_eps",
            "attention_bias",
            "use_cache",
            "sliding_window",
            "query_pre_attn_scalar",
            "final_logit_softcapping",
            "attn_logit_softcapping",
        ),
    )

if vision_cfg is not None:
    _log_config_fields(
        "========== VISION CONFIG ==========",
        "===================================",
        vision_cfg,
        (
            "model_type",
            "hidden_size",
            "intermediate_size",
            "num_hidden_layers",
            "num_attention_heads",
            "image_size",
            "patch_size",
        ),
    )

# ============================================================
# Parameter Logs
# ============================================================

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

log(f"Total parameters: {total_params:,}")
log(f"Trainable parameters: {trainable_params:,}")

# ============================================================
# Module Listing
# ============================================================

log("========== TEXT MODEL MODULES ==========")

text_keywords = [
    "language_model",
    "text_model",
    "model.layers",
    "self_attn",
    "mlp",
    "input_layernorm",
    "post_attention_layernorm",
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
    "rotary",
    "embed_tokens",
    "lm_head",
]

count = 0

for name, module in model.named_modules():
    lower = name.lower()

    # Only text-side modules are listed here.
    if "vision_tower" in lower:
        continue

    if any(k in lower for k in text_keywords):
        log(f"{name} => {module.__class__.__name__}")
        count += 1

    # Cap the listing so startup logs stay readable.
    if count >= 200:
        log("Stopped text module logging after 200 entries")
        break

log(f"Text modules logged: {count}")
log("========================================")

# ============================================================
# Deep Forward Hooks
# ============================================================

DEBUG_HOOKS = True
HOOK_EVERY_N_CALLS = 20
_hook_calls = {}


def tensor_stats(x):
    """Return a one-line summary of tensor ``x``.

    The summary contains shape, dtype, device and, for non-empty tensors,
    mean/std/min/max computed on a float32 copy. Non-tensor input yields
    ``str(type(x))``.
    """
    if not torch.is_tensor(x):
        return str(type(x))

    summary = f"shape={tuple(x.shape)}, dtype={x.dtype}, device={x.device}"

    # min()/max() raise on empty tensors, so report them without statistics.
    if x.numel() == 0:
        return f"{summary}, empty"

    with torch.no_grad():
        xf = x.detach().float()
        return (
            f"{summary}, "
            f"mean={xf.mean().item():.5f}, "
            f"std={xf.std().item():.5f}, "
            f"min={xf.min().item():.5f}, "
            f"max={xf.max().item():.5f}"
        )


def get_first_tensor(obj):
    """Return the first tensor found in ``obj`` by depth-first search.

    ``obj`` may be a tensor or an arbitrarily nested list/tuple/dict (dict
    values are searched). Returns ``None`` when no tensor is found.
    """
    if torch.is_tensor(obj):
        return obj

    if isinstance(obj, dict):
        children = obj.values()
    elif isinstance(obj, (list, tuple)):
        children = obj
    else:
        return None

    for child in children:
        found = get_first_tensor(child)
        if found is not None:
            return found

    return None


def make_hook(name):
    """Build a forward hook that logs input/output stats for module ``name``.

    The hook is a no-op while ``DEBUG_HOOKS`` is false. Otherwise it logs on
    the first call and then on every ``HOOK_EVERY_N_CALLS``-th call after it.
    """

    def hook(module, inputs, output):
        if not DEBUG_HOOKS:
            return

        calls = _hook_calls.get(name, 0) + 1
        _hook_calls[name] = calls

        # (calls - 1) % every logs on calls 1, every+1, 2*every+1, ... and,
        # unlike `calls % every == 1`, still works when every == 1.
        every = max(1, HOOK_EVERY_N_CALLS)
        if (calls - 1) % every != 0:
            return

        inp = get_first_tensor(inputs)
        out = get_first_tensor(output)

        log(f"HOOK: {name}")

        if inp is not None:
            log(f"  input  -> {tensor_stats(inp)}")

        if out is not None:
            log(f"  output -> {tensor_stats(out)}")

    return hook


def attach_debug_hooks():
    """Register logging forward hooks on text-side sub-modules of ``model``.

    A module is hooked when its lower-cased qualified name contains one of
    the ``wanted`` substrings and does not contain ``vision_tower``.

    Returns:
        The list of handles returned by ``register_forward_hook``; call
        ``.remove()`` on each to detach the hooks again.
    """
    wanted = [
        "self_attn",
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "mlp",
        "gate_proj",
        "up_proj",
        "down_proj",
        "input_layernorm",
        "post_attention_layernorm",
        "rotary_emb",
        "lm_head",
    ]

    handles = []

    for name, module in model.named_modules():
        lower = name.lower()

        if "vision_tower" in lower:
            continue

        if any(w in lower for w in wanted):
            handles.append(module.register_forward_hook(make_hook(name)))

    log(f"Attached debug hooks: {len(handles)}")
    return handles


# Hooks are disabled by default; uncomment to enable the per-module logging.
#attach_debug_hooks()

# ============================================================
# Generation Function
# ============================================================

# Single source of truth for the generation settings: the same dict is passed
# to model.generate() and written to the log, so the two cannot disagree.
# do_sample must be True for temperature/top_p to have any effect.
GENERATION_SETTINGS = {
    "max_new_tokens": 1024,
    "temperature": 0.7,
    "do_sample": True,
    "top_p": 0.9,
}


def _history_to_messages(history):
    """Convert chat history into a list of ``{"role", "content"}`` dicts.

    Two entry shapes are accepted:

    * ``(user_msg, bot_msg)`` pairs, where a ``None`` side is skipped;
    * dicts with ``"role"``/``"content"`` keys, kept when the role is
      ``"user"`` or ``"assistant"`` and the content is a string.

    Anything else is logged and skipped rather than aborting the request.
    """
    messages = []

    for item in history or []:
        if isinstance(item, dict):
            # Unpacking a two-key dict would yield its key names, so
            # dict-style entries need their own branch.
            role = item.get("role")
            content = item.get("content")
            if role in ("user", "assistant") and isinstance(content, str):
                messages.append({"role": role, "content": content})
            else:
                log("History parse warning: unsupported message entry")
                log(f"Bad history item: {item}")
            continue

        try:
            user_msg, bot_msg = item
        except (TypeError, ValueError) as e:
            log(f"History parse warning: {e}")
            log(f"Bad history item: {item}")
            continue

        if user_msg is not None:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg is not None:
            messages.append({"role": "assistant", "content": bot_msg})

    return messages


def generate_response(message, history):
    """Stream a model reply for ``message`` given the prior ``history``.

    Args:
        message: The new user message text.
        history: Previous turns, as pairs or role/content dicts (see
            ``_history_to_messages``). ``None`` is treated as empty.

    Yields:
        The accumulated reply text after each streamed chunk. If building
        the prompt fails, a single ``"Tokenization error: ..."`` string is
        yielded; if streaming fails, the partial text plus an error note.
    """
    start_time = time.time()
    history = history or []

    log("========== NEW GENERATION ==========")
    log(f"User message: {message}")
    log(f"History turns: {len(history)}")

    messages = _history_to_messages(history)
    messages.append({"role": "user", "content": message})

    log(f"Total chat messages: {len(messages)}")

    try:
        inputs = tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            return_dict=True,
            add_generation_prompt=True,
        ).to(model.device)

        input_token_count = inputs["input_ids"].shape[-1]

        log(f"Input tensor shape: {inputs['input_ids'].shape}")
        log(f"Input tokens: {input_token_count}")
        log(f"Input device: {inputs['input_ids'].device}")

        log("========== TOKEN DEBUG ==========")
        ids = inputs["input_ids"][0].tolist()
        log(f"First 20 token ids: {ids[:20]}")
        log(f"Last 20 token ids: {ids[-20:]}")
        log(f"Decoded prompt preview: {tokenizer.decode(ids[-200:], skip_special_tokens=False)}")
        log("=================================")

    except Exception as e:
        # Request boundary: report the failure in the chat instead of crashing.
        log("Chat template/tokenization failed")
        log(traceback.format_exc())
        yield f"Tokenization error: {e}"
        return

    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=420.0,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        **GENERATION_SETTINGS,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    log("Generation kwargs:")
    for key, value in GENERATION_SETTINGS.items():
        log(f"{key}={value}")

    def run_generation():
        """Run ``model.generate`` and surface any crash through the streamer."""
        try:
            log("Generation thread started")
            gen_start = time.time()

            with torch.no_grad():
                model.generate(**generate_kwargs)

            gen_time = time.time() - gen_start
            log(f"Generation thread finished in {gen_time:.2f}s")

        except Exception as e:
            log("Generation Error")
            log(traceback.format_exc())

            # Push the error into the stream and close it so the consumer
            # loop below terminates instead of waiting for the timeout.
            streamer.text_queue.put(
                f"\n[Generation thread crashed. Reason: {e}]"
            )
            streamer.end()

    # daemon=True so a stuck generation cannot block interpreter shutdown.
    t = Thread(target=run_generation, daemon=True)
    t.start()

    partial_text = ""
    token_chunks = 0

    try:
        for new_text in streamer:
            token_chunks += 1
            partial_text += new_text

            if token_chunks % 20 == 0:
                elapsed = time.time() - start_time
                log(
                    f"Streaming chunks: {token_chunks}, "
                    f"chars: {len(partial_text)}, "
                    f"elapsed: {elapsed:.2f}s"
                )

            yield partial_text

    except Exception as e:
        log("Streaming Error")
        log(traceback.format_exc())
        yield partial_text + f"\n\n[Streaming error: {e}]"

    finally:
        elapsed = time.time() - start_time
        log("========== GENERATION DONE ==========")
        log(f"Output chars: {len(partial_text)}")
        log(f"Streaming chunks: {token_chunks}")
        log(f"Elapsed seconds: {elapsed:.2f}")
        log("=====================================")


# ============================================================
# Gradio UI
# ============================================================

demo = gr.ChatInterface(
    fn=generate_response,
    title="Gemma 4 E4B - Debug",
    examples=[
        "Explain quantum entanglement simply.",
        "Write a Python function to add two numbers.",
        "Explain how RoPE works in transformer attention.",
    ],
    cache_examples=False,
)

if __name__ == "__main__":
    log("Launching Gradio app...")
    demo.launch()