import gradio as gr
import torch
from transformers import AutoTokenizer
from safetensors.torch import load_model
import os
import re
from huggingface_hub import hf_hub_download, HfApi
import gc
import time
import torch.nn.functional as F

# Import your custom model
from model_architecture import GPT2Config, GPT2Model
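# NOTE (assumption): model_architecture.py is expected to sit alongside this script and
# to define GPT2Config / GPT2Model with a configuration matching the ~450M-parameter
# checkpoints published in REPO_ID below; it is not shown here.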
# ===========================
# GLOBAL SETTINGS
# ===========================
torch.set_grad_enabled(False)

# Constants
REPO_ID = "nnsohamnn/gpt2-450M-fineweb"
BASE_MODEL = "gpt2"
CACHE_DIR = "./model_cache"

# Global variables
current_model = None
current_tokenizer = None
current_device = "cpu"
stop_generation = False

os.makedirs(CACHE_DIR, exist_ok=True)
# =========================================
# HELPER FUNCTIONS
# =========================================
def get_available_models():
    """Fetch available safetensors from HF repo"""
    try:
        api = HfApi()
        repo_files = api.list_repo_files(repo_id=REPO_ID, repo_type="model")
        model_files = [f for f in repo_files if f.endswith(".safetensors")]

        def _key(x):
            m = re.search(r'step_(\d+)', x)
            return int(m.group(1)) if m else 0

        model_files.sort(key=_key)
        return model_files
    except Exception as e:
        print(f"Error fetching models: {e}")
        return []
def get_model_cache_path(model_name):
    """Get cache path for model"""
    safe_name = re.sub(r'[^\w\-_\.]', '_', model_name)
    return os.path.join(CACHE_DIR, safe_name)

def is_model_cached(model_name):
    """Check if model is cached"""
    return os.path.exists(get_model_cache_path(model_name))
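# Note: hf_hub_download below manages its own cache layout inside CACHE_DIR, so the
# sanitized path from get_model_cache_path() is only a coarse "already fetched" marker;
# repeated calls to download_model() stay cheap because the hub client dedupes downloads.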
def download_model(model_name):
    """Download model if not cached"""
    cache_path = get_model_cache_path(model_name)
    if not is_model_cached(model_name):
        print(f"Downloading {model_name}...")
        try:
            downloaded_path = hf_hub_download(
                repo_id=REPO_ID,
                filename=model_name,
                cache_dir=CACHE_DIR
            )
            return downloaded_path
        except Exception as e:
            print(f"Error downloading: {e}")
            return None
    return cache_path
def load_model_checkpoint(model_name):
    """Load a checkpoint into the custom GPT-2 model (torch.compile currently disabled)"""
    global current_model, current_tokenizer, current_device
    try:
        checkpoint_path = download_model(model_name)
        if checkpoint_path is None:
            return "Failed to download model"

        # Build model instance
        config = GPT2Config()
        model = GPT2Model(config)

        # Load weights
        load_model(model, checkpoint_path, device="cpu")

        # Ensure tied weights
        try:
            model.lm_head.weight = model.embed_tokens.weight
        except Exception:
            pass

        model = model.to("cpu")
        model.eval()

        # Apply torch.compile for faster inference
        # if hasattr(torch, 'compile'):
        #     try:
        #         model = torch.compile(model, mode="reduce-overhead")
        #         print("✅ Model compiled with torch.compile")
        #     except Exception as e:
        #         print(f"⚠️ torch.compile failed: {e}, using eager mode")

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        tokenizer.pad_token = tokenizer.eos_token

        # Update globals
        current_model = model
        current_tokenizer = tokenizer
        current_device = "cpu"

        return f"✅ Model loaded: {model_name}"
    except Exception as e:
        return f"❌ Error: {str(e)}"
# =========================================
# GENERATION (FIXED: Streaming + No Format Labels)
# =========================================
def generate_text_streaming(prompt, max_tokens=100, temperature=0.7, top_p=0.9):
    """Fixed generation with token-by-token streaming and format stripping"""
    global stop_generation, current_model, current_tokenizer, current_device

    if current_model is None or current_tokenizer is None:
        yield "⚠️ No model loaded"
        return

    stop_generation = False
    repetition_penalty = 1.1
    frequency_penalty = 0.1

    try:
        # Encode prompt (includes "User: ... Assistant:")
        input_ids = current_tokenizer.encode(prompt, return_tensors="pt").to(current_device)
        generated = input_ids.clone()
        generated_tokens = generated[0].tolist()

        start_time = time.time()
        token_count = 0
        with torch.inference_mode():
            for _ in range(max_tokens):
                if stop_generation:
                    break

                logits = current_model(generated)
                next_token_logits = logits[:, -1, :].clone()

                # 1. Repetition penalty
                for token_id in set(generated_tokens):
                    if 0 <= token_id < next_token_logits.shape[-1]:
                        if next_token_logits[0, token_id] < 0:
                            next_token_logits[0, token_id] *= repetition_penalty
                        else:
                            next_token_logits[0, token_id] /= repetition_penalty

                # 2. Frequency penalty
                token_counts = {}
                for token_id in generated_tokens:
                    token_counts[token_id] = token_counts.get(token_id, 0) + 1
                for token_id, count in token_counts.items():
                    if 0 <= token_id < next_token_logits.shape[-1]:
                        next_token_logits[0, token_id] -= frequency_penalty * count

                # 3. Temperature
                next_token_logits = next_token_logits / max(temperature, 0.1)

                # 4. Top-k
                top_k = 50
                if next_token_logits.shape[-1] > top_k:
                    top_k_values, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
                    full_logits = torch.full_like(next_token_logits, -float('Inf'))
                    full_logits.scatter_(-1, top_k_indices, top_k_values)
                    next_token_logits = full_logits

                # 5. Top-p
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True, dim=-1)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                indices_to_remove = cumulative_probs > top_p
                indices_to_remove[..., 0] = False
                if indices_to_remove.any():
                    mask_indices = sorted_indices[0, indices_to_remove[0]]
                    next_token_logits[0, mask_indices] = -float('Inf')

                # 6. Sample
                probs = F.softmax(next_token_logits, dim=-1)
                if torch.isnan(probs).any() or probs.sum() <= 0:
                    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
                else:
                    next_token = torch.multinomial(probs, num_samples=1)

                generated = torch.cat([generated, next_token], dim=-1)
                generated_tokens.append(next_token.item())
                token_count += 1
                # FIXED: Decode EVERY token for true streaming
                full_decoded = current_tokenizer.decode(generated[0], skip_special_tokens=True)

                # FIXED: Strip format labels (User:/Assistant:)
                if "Assistant:" in full_decoded:
                    response_text = full_decoded.split("Assistant:")[-1].strip()
                else:
                    response_text = full_decoded

                # FIXED: Stop if model generates "User:" (multi-turn hallucination)
                if "User:" in response_text:
                    response_text = response_text.split("User:")[0].strip()
                    elapsed = time.time() - start_time
                    speed = token_count / elapsed if elapsed > 0 else 0.0
                    yield f"{response_text}\n\n---\n✅ {token_count} tokens in {elapsed:.1f}s ({speed:.1f} tok/s)"
                    break

                # Yield clean response (streaming every token)
                elapsed = time.time() - start_time
                speed = token_count / elapsed if elapsed > 0 else 0.0
                yield f"{response_text}\n\n---\n⚡ {speed:.1f} tok/s | {token_count}/{max_tokens}"

                # Stop at EOS
                if next_token.item() == current_tokenizer.eos_token_id:
                    break

        # Final output
        final_decoded = current_tokenizer.decode(generated[0], skip_special_tokens=True)
        if "Assistant:" in final_decoded:
            final_text = final_decoded.split("Assistant:")[-1].strip()
        else:
            final_text = final_decoded
        if "User:" in final_text:
            final_text = final_text.split("User:")[0].strip()

        elapsed = time.time() - start_time
        yield f"{final_text}\n\n---\n✅ {token_count} tokens in {elapsed:.1f}s ({(token_count/elapsed) if elapsed>0 else 0:.1f} tok/s)"
    except Exception as e:
        yield f"❌ Error: {str(e)}"
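# generate_text_streaming() expects prompts in the "User: ...\nAssistant:" format that
# bot_response() below constructs; the label stripping above relies on that convention.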
def stop_generation_func():
    """Stop generation"""
    global stop_generation
    stop_generation = True
    return "🛑 Stopped"
# =========================================
# GRADIO INTERFACE
# =========================================
def create_interface():
    initial_models = get_available_models()

    with gr.Blocks(title="GPT-2 FineWeb Chat 💬", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🤖 GPT-2 450M Chat")
        gr.Markdown("Custom GPT-2 trained on FineWebEdu (50k steps) + SmolTalk (2k steps)")
        # Added Hugging Face repo link
        gr.Markdown("Repo with training details and weights [here](https://huggingface.co/nnsohamnn/gpt2-450M-fineweb)")
        with gr.Row():
            # Sidebar - Model Selection
            with gr.Column(scale=1):
                gr.Markdown("### Model Settings")

                model_dropdown = gr.Dropdown(
                    choices=initial_models,
                    value=initial_models[-1] if initial_models else None,
                    label="📁 Select Checkpoint",
                    interactive=True
                )
                load_btn = gr.Button("Load Model", variant="primary")
                model_status = gr.Textbox(label="Status", interactive=False, lines=2)

                gr.Markdown("### Generation Parameters")
                max_tokens = gr.Slider(10, 300, value=100, step=10, label="Max Tokens")
                temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

                stop_btn = gr.Button("🛑 Stop Generation", variant="stop")
                refresh_btn = gr.Button("🔄 Refresh Model List")
            # Main Chat Area
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    label="Chat",
                    height=500,
                    type="messages"
                )

                with gr.Row():
                    msg = gr.Textbox(
                        label="Message",
                        placeholder="Type your message here...",
                        lines=2,
                        scale=4
                    )
                    send_btn = gr.Button("Send", variant="primary", scale=1)

                clear_btn = gr.Button("Clear Chat")

                gr.Examples(
                    examples=[
                        "Hello!",
                        "What is artificial intelligence?",
                        "Explain quantum computing in simple terms",
                    ],
                    inputs=[msg]
                )
        # Event handlers
        def user_submit(user_message, chat_history):
            """Handle user message submission"""
            if not user_message:
                return "", chat_history
            chat_history.append({"role": "user", "content": user_message})
            return "", chat_history

        def bot_response(chat_history, max_tok, temp, top):
            """Generate bot response with conversational format (User:/Assistant:)"""
            if not chat_history or chat_history[-1]["role"] != "user":
                return chat_history

            user_msg = chat_history[-1]["content"]

            # Build conversation history in format: User: ... \nAssistant: ... \n
            conversation_context = ""
            for turn in chat_history[:-1]:  # All previous messages
                if turn["role"] == "user":
                    conversation_context += f"User: {turn['content']}\n"
                elif turn["role"] == "assistant":
                    # Extract just the text (remove stats footer)
                    assistant_text = turn["content"].split("\n\n---\n")[0] if "\n\n---\n" in turn["content"] else turn["content"]
                    conversation_context += f"Assistant: {assistant_text}\n"

            # Add current user message with format
            prompt = f"{conversation_context}User: {user_msg}\nAssistant:"

            # Initialize assistant response
            chat_history.append({"role": "assistant", "content": ""})

            # Stream generation
            for response in generate_text_streaming(prompt, max_tok, temp, top):
                chat_history[-1]["content"] = response
                yield chat_history
        # Submit message
        msg.submit(
            fn=user_submit,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot]
        ).then(
            fn=bot_response,
            inputs=[chatbot, max_tokens, temperature, top_p],
            outputs=chatbot
        )

        send_btn.click(
            fn=user_submit,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot]
        ).then(
            fn=bot_response,
            inputs=[chatbot, max_tokens, temperature, top_p],
            outputs=chatbot
        )

        # Clear chat
        clear_btn.click(fn=lambda: [], outputs=chatbot)

        # Load model
        load_btn.click(
            fn=load_model_checkpoint,
            inputs=[model_dropdown],
            outputs=model_status
        )

        # Stop generation
        stop_btn.click(fn=stop_generation_func, outputs=model_status)
        # Refresh models
        def refresh_models():
            models = get_available_models()
            return gr.Dropdown(choices=models, value=models[-1] if models else None), f"Found {len(models)} models"

        refresh_btn.click(
            fn=refresh_models,
            outputs=[model_dropdown, model_status]
        )

        # Auto-load the latest checkpoint on startup
        demo.load(
            fn=lambda: load_model_checkpoint(initial_models[-1]) if initial_models else "No models found",
            outputs=model_status
        )

    return demo
# Launch
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(share=True)
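# Note: when deployed on Hugging Face Spaces, share=True is typically ignored (the Space
# already serves the app); the flag only matters for local runs that want a public link.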