Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import time | |
| import traceback | |
| from threading import Thread | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
# Hugging Face Hub repository of the fine-tuned checkpoint served by this app.
model_id = "rahul7star/gemma-4-finetune"


def log(msg):
    """Print *msg* to stdout behind a ``[DEBUG]`` prefix.

    The stream is flushed on every call so lines appear promptly even when
    stdout is block-buffered.
    """
    line = "[DEBUG] {}".format(msg)
    print(line, flush=True)
# Startup banner: which checkpoint is served and which torch/CUDA runtime is present.
log("Starting Gemma 4 debug app")
log(f"Model ID: {model_id}")
log(f"Torch version: {torch.__version__}")
log(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only the first visible GPU is described by name.
    log(f"CUDA device count: {torch.cuda.device_count()}")
    log(f"CUDA device name: {torch.cuda.get_device_name(0)}")
# ============================================================
# Load Tokenizer
# ============================================================
log("Loading tokenizer...")
# NOTE: trust_remote_code=True lets the Hub repository execute its own Python
# code while loading; only keep it for repositories you trust.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
log("Tokenizer loaded")

# Describe the tokenizer as loaded (before any pad-token fallback below).
log(f"Tokenizer class: {type(tokenizer).__name__}")
log(f"Vocab size: {len(tokenizer)}")
log(f"EOS token: {tokenizer.eos_token} / {tokenizer.eos_token_id}")
log(f"PAD token: {tokenizer.pad_token} / {tokenizer.pad_token_id}")
log(f"Chat template exists: {tokenizer.chat_template is not None}")

# Reuse EOS as the padding token when the checkpoint defines none.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
    log("PAD token was missing, set PAD token = EOS token")
# ============================================================
# Load Model
# ============================================================
log("Loading model...")
# Weights are loaded in bfloat16 and placed on the CPU. low_cpu_mem_usage asks
# transformers to keep peak loading memory near one copy of the weights.
# NOTE: trust_remote_code=True runs code shipped with the Hub repository.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="cpu",
)
model.eval()  # inference only: disables dropout-style training behaviour
log("Model loaded")

log(f"Model class: {type(model).__name__}")
log(f"Model device: {model.device}")
log(f"Model dtype: {next(iter(model.parameters())).dtype}")
# ============================================================
# Config Logs
# ============================================================
def _log_config_fields(section, field_names):
    """Log ``<field>: <value>`` for each name, using None for absent attributes."""
    for field in field_names:
        log(f"{field}: {getattr(section, field, None)}")


cfg = model.config
# Multimodal configs nest per-modality sub-configs; plain text configs lack them.
text_cfg = getattr(cfg, "text_config", None)
vision_cfg = getattr(cfg, "vision_config", None)

log("========== MAIN MODEL CONFIG ==========")
_log_config_fields(cfg, ("model_type", "architectures", "is_encoder_decoder"))
log(f"text_config exists: {text_cfg is not None}")
log(f"vision_config exists: {vision_cfg is not None}")
log("=======================================")

if text_cfg is not None:
    log("========== TEXT CONFIG ==========")
    _log_config_fields(
        text_cfg,
        (
            "model_type",
            "hidden_size",
            "intermediate_size",
            "num_hidden_layers",
            "num_attention_heads",
            "num_key_value_heads",
            "head_dim",
            "vocab_size",
            "max_position_embeddings",
            "rope_theta",
            "rms_norm_eps",
            "attention_bias",
            "use_cache",
            "sliding_window",
            "query_pre_attn_scalar",
            "final_logit_softcapping",
            "attn_logit_softcapping",
        ),
    )
    log("=================================")

if vision_cfg is not None:
    log("========== VISION CONFIG ==========")
    _log_config_fields(
        vision_cfg,
        (
            "model_type",
            "hidden_size",
            "intermediate_size",
            "num_hidden_layers",
            "num_attention_heads",
            "image_size",
            "patch_size",
        ),
    )
    log("===================================")
# ============================================================
# Parameter Logs
# ============================================================
total_params = sum(map(torch.numel, model.parameters()))
trainable_params = sum(
    map(torch.numel, filter(lambda p: p.requires_grad, model.parameters()))
)
log(f"Total parameters: {total_params:,}")
log(f"Trainable parameters: {trainable_params:,}")
# ============================================================
# Module Listing
# ============================================================
log("========== TEXT MODEL MODULES ==========")
# A module is listed when its lower-cased qualified name contains any of these
# substrings; anything under a vision tower is always skipped.
text_keywords = [
    "language_model",
    "text_model",
    "model.layers",
    "self_attn",
    "mlp",
    "input_layernorm",
    "post_attention_layernorm",
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
    "rotary",
    "embed_tokens",
    "lm_head",
]
count = 0
for name, module in model.named_modules():
    lower = name.lower()
    if "vision_tower" in lower or not any(k in lower for k in text_keywords):
        continue
    log(f"{name} => {type(module).__name__}")
    count += 1
    if count >= 200:
        # Cap the listing so deep models do not flood the log.
        log("Stopped text module logging after 200 entries")
        break
log(f"Text modules logged: {count}")
log("========================================")
# ============================================================
# Deep Forward Hooks
# ============================================================
# Master switch consulted by the debug forward hooks each time they fire.
DEBUG_HOOKS: bool = True
# Hooks only log a sample of their calls (per module) to keep output readable.
HOOK_EVERY_N_CALLS: int = 20
# Per-module call counters: module name -> number of hook invocations so far.
_hook_calls = dict()
def tensor_stats(x):
    """Return a one-line summary of *x* for debug logging.

    Non-tensors are summarised as ``str(type(x))``. For tensors the summary
    always starts with shape, dtype and device. Mean, std, min and max are then
    computed in float32 on a detached copy, so the caller's autograd graph is
    untouched.

    Edge cases that would otherwise break a forward hook:
    * empty tensors report ``empty`` instead of statistics, because ``min``
      and ``max`` raise on zero elements;
    * single-element tensors report ``std=0.00000`` instead of the NaN (and
      warning) produced by the unbiased estimator.
    """
    if not torch.is_tensor(x):
        return str(type(x))
    header = f"shape={tuple(x.shape)}, dtype={x.dtype}, device={x.device}"
    if x.numel() == 0:
        return f"{header}, empty"
    with torch.no_grad():
        xf = x.detach().float()
        # Unbiased std needs at least two elements.
        std = xf.std().item() if xf.numel() > 1 else 0.0
        return (
            f"{header}, "
            f"mean={xf.mean().item():.5f}, "
            f"std={std:.5f}, "
            f"min={xf.min().item():.5f}, "
            f"max={xf.max().item():.5f}"
        )
def get_first_tensor(obj):
    """Return the first tensor found in *obj*, or None if there is none.

    *obj* may be a tensor or an arbitrarily nested list/tuple/dict. Containers
    are searched depth-first in iteration order (dict values only). Any other
    type yields None.
    """
    if torch.is_tensor(obj):
        return obj
    if isinstance(obj, (list, tuple)):
        children = obj
    elif isinstance(obj, dict):
        children = obj.values()
    else:
        return None
    # Lazily recurse so the search stops at the first hit.
    return next(
        (found for found in map(get_first_tensor, children) if found is not None),
        None,
    )
def make_hook(name):
    """Return a forward hook that logs tensor statistics for module *name*.

    When the hook fires:
    * it does nothing while ``DEBUG_HOOKS`` is false;
    * otherwise it increments ``_hook_calls[name]``;
    * it logs only on calls 1, N+1, 2N+1, ..., where N is
      ``HOOK_EVERY_N_CALLS`` (values below 1 are treated as 1, i.e. every call);
    * it reports the first tensor found in the module's inputs and in its output.

    Errors raised while summarising tensors are logged rather than propagated,
    so a debug hook can never abort the model's forward pass. The hook returns
    None, which leaves the module's output unchanged.
    """

    def hook(module, inputs, output):
        if not DEBUG_HOOKS:
            return
        calls = _hook_calls.get(name, 0) + 1
        _hook_calls[name] = calls
        # Counting from zero keeps N == 1 meaning "every call" (a
        # `calls % N == 1` test is never true when N == 1) and guards N == 0.
        every = max(HOOK_EVERY_N_CALLS, 1)
        if (calls - 1) % every != 0:
            return
        try:
            inp = get_first_tensor(inputs)
            out = get_first_tensor(output)
            log(f"HOOK: {name}")
            if inp is not None:
                log(f" input -> {tensor_stats(inp)}")
            if out is not None:
                log(f" output -> {tensor_stats(out)}")
        except Exception as e:
            log(f"HOOK: {name} failed to summarise tensors: {e}")

    return hook
# Set to True to register the debug forward hooks when the app starts.
ATTACH_DEBUG_HOOKS = False

# RemovableHandle objects for the debug hooks currently registered.
_debug_hook_handles = []


def attach_debug_hooks():
    """Register a logging forward hook (from ``make_hook``) on text-model modules.

    A module is hooked when its lower-cased qualified name contains any of the
    ``wanted`` substrings and does not contain ``"vision_tower"``. Because
    matching is by substring, submodules of a matched module (for example
    everything under ``...self_attn``) are hooked as well.

    Calling this again while hooks are attached registers nothing new.

    Returns:
        A list of the ``RemovableHandle`` objects for the attached hooks;
        ``remove_debug_hooks()`` detaches them all.
    """
    if _debug_hook_handles:
        log(f"Debug hooks already attached: {len(_debug_hook_handles)}")
        return list(_debug_hook_handles)
    wanted = [
        "self_attn",
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "mlp",
        "gate_proj",
        "up_proj",
        "down_proj",
        "input_layernorm",
        "post_attention_layernorm",
        "rotary_emb",
        "lm_head",
    ]
    for name, module in model.named_modules():
        lower = name.lower()
        if "vision_tower" in lower:
            continue
        if any(w in lower for w in wanted):
            _debug_hook_handles.append(module.register_forward_hook(make_hook(name)))
    log(f"Attached debug hooks: {len(_debug_hook_handles)}")
    return list(_debug_hook_handles)


def remove_debug_hooks():
    """Detach every hook registered by ``attach_debug_hooks``; return how many."""
    removed = len(_debug_hook_handles)
    for handle in _debug_hook_handles:
        handle.remove()
    _debug_hook_handles.clear()
    log(f"Removed debug hooks: {removed}")
    return removed


if ATTACH_DEBUG_HOOKS:
    attach_debug_hooks()
# ============================================================
# Generation Function
# ============================================================
def _history_to_messages(history):
    """Convert Gradio chat *history* into chat-template message dicts.

    Both Gradio history formats are accepted:

    * "tuples": ``[[user_text, bot_text], ...]``. A ``None`` side is skipped
      rather than being sent to the chat template as content.
    * "messages": ``[{"role": ..., "content": ...}, ...]``. Only ``role`` and
      ``content`` are kept; other keys are dropped.

    Items matching neither shape are logged and skipped.
    """
    messages = []
    for item in history:
        # Dicts must be handled before tuple-unpacking: unpacking a two-key
        # dict yields its *keys*, producing bogus "role"/"content" turns.
        if isinstance(item, dict):
            role = item.get("role")
            content = item.get("content")
            if role in ("system", "user", "assistant") and content is not None:
                messages.append({"role": role, "content": content})
            else:
                log("History parse warning: unsupported message dict")
                log(f"Bad history item: {item}")
            continue
        try:
            user_msg, bot_msg = item
        except (TypeError, ValueError) as e:
            log(f"History parse warning: {e}")
            log(f"Bad history item: {item}")
            continue
        if user_msg is not None:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg is not None:
            messages.append({"role": "assistant", "content": bot_msg})
    return messages


def _make_cancellation():
    """Return ``(stop_event, stopping_criteria)`` for cooperative cancellation.

    Pass ``stopping_criteria`` to ``model.generate``. Once ``stop_event`` is
    set, generation ends at its next decoding step.
    """
    from threading import Event

    from transformers import StoppingCriteria, StoppingCriteriaList

    stop_event = Event()

    class _StopOnEvent(StoppingCriteria):
        def __call__(self, input_ids, scores, **kwargs):
            # One "done" flag per batch row; every row shares the same event.
            return torch.full(
                (input_ids.shape[0],),
                stop_event.is_set(),
                dtype=torch.bool,
                device=input_ids.device,
            )

    return stop_event, StoppingCriteriaList([_StopOnEvent()])


def generate_response(message, history):
    """Stream the model's reply to *message*, given the Gradio chat *history*.

    This is the ``gr.ChatInterface`` callback; each yielded value is the full
    reply text so far.

    How it works:
    * ``model.generate`` runs on a background thread and feeds a
      ``TextIteratorStreamer`` that this generator drains.
    * When this generator finishes, fails, or is closed early, a stop event is
      set so the background generation halts at its next decoding step.

    Errors:
    * tokenization failures yield a single ``"Tokenization error: ..."`` string;
    * generation and streaming failures are appended to the partial reply.
    """
    start_time = time.time()
    history = history or []
    log("========== NEW GENERATION ==========")
    log(f"User message: {message}")
    log(f"History turns: {len(history)}")

    messages = _history_to_messages(history)
    messages.append({"role": "user", "content": message})
    log(f"Total chat messages: {len(messages)}")

    try:
        inputs = tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            return_dict=True,
            add_generation_prompt=True,
        ).to(model.device)
        input_token_count = inputs["input_ids"].shape[-1]
        log(f"Input tensor shape: {inputs['input_ids'].shape}")
        log(f"Input tokens: {input_token_count}")
        log(f"Input device: {inputs['input_ids'].device}")
        log("========== TOKEN DEBUG ==========")
        ids = inputs["input_ids"][0].tolist()
        log(f"First 20 token ids: {ids[:20]}")
        log(f"Last 20 token ids: {ids[-20:]}")
        log(f"Decoded prompt preview: {tokenizer.decode(ids[-200:], skip_special_tokens=False)}")
        log("=================================")
    except Exception as e:
        log("Chat template/tokenization failed")
        log(traceback.format_exc())
        yield f"Tokenization error: {e}"
        return

    # Decoding settings live in one dict, so the log below always matches what
    # generate() receives.
    # Greedy decoding is used (do_sample=False). temperature/top_p only affect
    # sampling, so they are passed only when do_sample is True; under greedy
    # decoding transformers ignores them and warns.
    do_sample = False
    gen_settings = {"max_new_tokens": 1024, "do_sample": do_sample}
    if do_sample:
        gen_settings.update(temperature=0.7, top_p=0.9)

    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=420.0,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    stop_event, stopping_criteria = _make_cancellation()
    generate_kwargs = dict(
        **inputs,
        **gen_settings,
        streamer=streamer,
        stopping_criteria=stopping_criteria,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    log("Generation kwargs:")
    for key, value in gen_settings.items():
        log(f"{key}={value}")

    def run_generation():
        try:
            log("Generation thread started")
            gen_start = time.time()
            with torch.no_grad():
                model.generate(**generate_kwargs)
            gen_time = time.time() - gen_start
            log(f"Generation thread finished in {gen_time:.2f}s")
        except Exception as e:
            log("Generation Error")
            log(traceback.format_exc())
            # If generate() raises, it does not signal end-of-stream. Push the
            # error text and close the streamer here so the consumer loop ends.
            streamer.text_queue.put(
                f"\n[Generation thread crashed. Reason: {e}]"
            )
            streamer.end()

    Thread(target=run_generation).start()

    partial_text = ""
    token_chunks = 0
    try:
        for new_text in streamer:
            token_chunks += 1
            partial_text += new_text
            if token_chunks % 20 == 0:
                elapsed = time.time() - start_time
                log(
                    f"Streaming chunks: {token_chunks}, "
                    f"chars: {len(partial_text)}, "
                    f"elapsed: {elapsed:.2f}s"
                )
            yield partial_text
    except Exception as e:
        log("Streaming Error")
        log(traceback.format_exc())
        yield partial_text + f"\n\n[Streaming error: {e}]"
    finally:
        # This runs on normal completion, on errors, and when the consumer
        # closes this generator early. In the last two cases generate() may
        # still be running, so ask it to stop instead of leaving it on the CPU.
        stop_event.set()
        elapsed = time.time() - start_time
        log("========== GENERATION DONE ==========")
        log(f"Output chars: {len(partial_text)}")
        log(f"Streaming chunks: {token_chunks}")
        log(f"Elapsed seconds: {elapsed:.2f}")
        log("=====================================")
# ============================================================
# Gradio UI
# ============================================================
demo = gr.ChatInterface(
    generate_response,
    cache_examples=False,  # run examples on demand instead of precomputing them
    title="Gemma 4 E4B - Debug",
    examples=[
        "Explain quantum entanglement simply.",
        "Write a Python function to add two numbers.",
        "Explain how RoPE works in transformer attention.",
    ],
)

# Launch the server only when executed as a script, not when imported.
if __name__ == "__main__":
    log("Launching Gradio app...")
    demo.launch()