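"""Gradio chat demo for a custom 450M-parameter GPT-2 trained on FineWebEdu + SmolTalk.

Loads safetensors checkpoints from the Hugging Face Hub, samples on CPU with
token-by-token streaming, and serves a simple chat UI with sampling controls.
"""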
import gradio as gr
import torch
from transformers import AutoTokenizer
from safetensors.torch import load_model
import os
import re
from huggingface_hub import hf_hub_download, HfApi
import gc
import time
import torch.nn.functional as F
# Import your custom model
from model_architecture import GPT2Config, GPT2Model
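# Interface assumed from the bundled model_architecture module:
#   - GPT2Config() carries the checkpoint's hyperparameters
#   - model(input_ids) returns raw logits of shape [batch, seq_len, vocab_size]
#   - embed_tokens and lm_head exist so their weights can be tied after loading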
# ===========================
# GLOBAL SETTINGS
# ===========================
torch.set_grad_enabled(False)
# Constants
REPO_ID = "nnsohamnn/gpt2-450M-fineweb"
BASE_MODEL = "gpt2"
CACHE_DIR = "./model_cache"
# Global variables
current_model = None
current_tokenizer = None
current_device = "cpu"
stop_generation = False
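# stop_generation is polled once per generated token by the streaming loop below;
# the Stop button flips it from a separate Gradio request to interrupt generation.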
os.makedirs(CACHE_DIR, exist_ok=True)
# =========================================
# HELPER FUNCTIONS
# =========================================
def get_available_models():
"""Fetch available safetensors from HF repo"""
try:
api = HfApi()
repo_files = api.list_repo_files(repo_id=REPO_ID, repo_type="model")
model_files = [f for f in repo_files if f.endswith(".safetensors")]
def _key(x):
m = re.search(r'step_(\d+)', x)
return int(m.group(1)) if m else 0
model_files.sort(key=_key)
return model_files
except Exception as e:
print(f"Error fetching models: {e}")
return []
def get_model_cache_path(model_name):
"""Get cache path for model"""
safe_name = re.sub(r'[^\w\-_\.]', '_', model_name)
return os.path.join(CACHE_DIR, safe_name)
def is_model_cached(model_name):
"""Check if model is cached"""
return os.path.exists(get_model_cache_path(model_name))
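# Note: hf_hub_download keeps its own cache layout under CACHE_DIR, so the check above
# is only used for log messages; the Hub download below is a no-op on a cache hit.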
def download_model(model_name):
    """Resolve the checkpoint path, downloading it from the Hub if needed"""
    if not is_model_cached(model_name):
        print(f"Downloading {model_name}...")
    try:
        # hf_hub_download returns the local file path and skips the download on a cache hit
        downloaded_path = hf_hub_download(
            repo_id=REPO_ID,
            filename=model_name,
            cache_dir=CACHE_DIR
        )
        return downloaded_path
    except Exception as e:
        print(f"Error downloading: {e}")
        return None
def load_model_checkpoint(model_name):
"""Load model with torch.compile optimization"""
global current_model, current_tokenizer, current_device
try:
checkpoint_path = download_model(model_name)
if checkpoint_path is None:
return "Failed to download model"
# Build model instance
config = GPT2Config()
model = GPT2Model(config)
# Load weights
load_model(model, checkpoint_path, device="cpu")
# Ensure tied weights
try:
model.lm_head.weight = model.embed_tokens.weight
except Exception:
pass
model = model.to("cpu")
model.eval()
# Apply torch.compile for faster inference
# if hasattr(torch, 'compile'):
# try:
# model = torch.compile(model, mode="reduce-overhead")
# print("βœ… Model compiled with torch.compile")
# except Exception as e:
# print(f"⚠️ torch.compile failed: {e}, using eager mode")
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
tokenizer.pad_token = tokenizer.eos_token
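        # The GPT-2 tokenizer has no dedicated pad token, so EOS is reused for padding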
# Update globals
current_model = model
current_tokenizer = tokenizer
current_device = "cpu"
return f"βœ… Model loaded: {model_name}"
except Exception as e:
return f"❌ Error: {str(e)}"
# =========================================
# GENERATION (streaming sampling with chat-format stripping)
# =========================================
def generate_text_streaming(prompt, max_tokens=100, temperature=0.7, top_p=0.9):
"""Fixed generation with token-by-token streaming and format stripping"""
global stop_generation, current_model, current_tokenizer, current_device
if current_model is None or current_tokenizer is None:
yield "⚠️ No model loaded"
return
stop_generation = False
repetition_penalty = 1.1
frequency_penalty = 0.1
try:
# Encode prompt (includes "User: ... Assistant:")
input_ids = current_tokenizer.encode(prompt, return_tensors="pt").to(current_device)
generated = input_ids.clone()
generated_tokens = generated[0].tolist()
start_time = time.time()
token_count = 0
with torch.inference_mode():
for _ in range(max_tokens):
if stop_generation:
break
logits = current_model(generated)
next_token_logits = logits[:, -1, :].clone()
# 1. Repetition penalty
for token_id in set(generated_tokens):
if 0 <= token_id < next_token_logits.shape[-1]:
if next_token_logits[0, token_id] < 0:
next_token_logits[0, token_id] *= repetition_penalty
else:
next_token_logits[0, token_id] /= repetition_penalty
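                # e.g. with repetition_penalty = 1.1, a seen token's logit of 2.0 becomes
                # 2.0 / 1.1 ≈ 1.82 and a logit of -2.0 becomes -2.2, so every token that
                # has already appeared is made less likely to be picked again.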
# 2. Frequency penalty
token_counts = {}
for token_id in generated_tokens:
token_counts[token_id] = token_counts.get(token_id, 0) + 1
for token_id, count in token_counts.items():
if 0 <= token_id < next_token_logits.shape[-1]:
next_token_logits[0, token_id] -= frequency_penalty * count
# 3. Temperature
next_token_logits = next_token_logits / max(temperature, 0.1)
# 4. Top-k
top_k = 50
if next_token_logits.shape[-1] > top_k:
top_k_values, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
full_logits = torch.full_like(next_token_logits, -float('Inf'))
full_logits.scatter_(-1, top_k_indices, top_k_values)
next_token_logits = full_logits
# 5. Top-p
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
indices_to_remove = cumulative_probs > top_p
indices_to_remove[..., 0] = False
if indices_to_remove.any():
mask_indices = sorted_indices[0, indices_to_remove[0]]
next_token_logits[0, mask_indices] = -float('Inf')
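                # Nucleus (top-p) filtering: tokens whose sorted cumulative probability
                # exceeds top_p are masked to -inf; index 0 is always kept so at least
                # the single most probable token remains sampleable.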
# 6. Sample
probs = F.softmax(next_token_logits, dim=-1)
if torch.isnan(probs).any() or probs.sum() <= 0:
                    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
else:
next_token = torch.multinomial(probs, num_samples=1)
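                # multinomial draws a single token id from the filtered distribution;
                # the greedy argmax branch above is only a fallback for degenerate probs.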
generated = torch.cat([generated, next_token], dim=-1)
generated_tokens.append(next_token.item())
token_count += 1
                # Decode the full sequence each step so the UI streams token by token
                full_decoded = current_tokenizer.decode(generated[0], skip_special_tokens=True)
                # Strip chat-format labels (User:/Assistant:) from the visible reply
                if "Assistant:" in full_decoded:
                    response_text = full_decoded.split("Assistant:")[-1].strip()
                else:
                    response_text = full_decoded
                # Stop early if the model starts hallucinating the next "User:" turn
                if "User:" in response_text:
                    response_text = response_text.split("User:")[0].strip()
elapsed = time.time() - start_time
speed = token_count / elapsed if elapsed > 0 else 0.0
yield f"{response_text}\n\n---\nβœ… {token_count} tokens in {elapsed:.1f}s ({speed:.1f} tok/s)"
break
# Yield clean response (streaming every token)
elapsed = time.time() - start_time
speed = token_count / elapsed if elapsed > 0 else 0.0
yield f"{response_text}\n\n---\n⚑ {speed:.1f} tok/s | {token_count}/{max_tokens}"
# Stop at EOS
if next_token.item() == current_tokenizer.eos_token_id:
break
# Final output
final_decoded = current_tokenizer.decode(generated[0], skip_special_tokens=True)
if "Assistant:" in final_decoded:
final_text = final_decoded.split("Assistant:")[-1].strip()
else:
final_text = final_decoded
if "User:" in final_text:
final_text = final_text.split("User:")[0].strip()
elapsed = time.time() - start_time
yield f"{final_text}\n\n---\nβœ… {token_count} tokens in {elapsed:.1f}s ({(token_count/elapsed) if elapsed>0 else 0:.1f} tok/s)"
except Exception as e:
yield f"❌ Error: {str(e)}"
def stop_generation_func():
"""Stop generation"""
global stop_generation
stop_generation = True
return "πŸ›‘ Stopped"
# =========================================
# GRADIO INTERFACE
# =========================================
def create_interface():
initial_models = get_available_models()
    with gr.Blocks(title="GPT-2 FineWeb Chat 💬", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🤖 GPT-2 450M Chat")
        gr.Markdown("Custom GPT-2 trained on FineWebEdu (50k steps) + SmolTalk (2k steps)")
        # Hugging Face repo link
        gr.Markdown("Training details and weights are available in the [model repo](https://huggingface.co/nnsohamnn/gpt2-450M-fineweb)")
with gr.Row():
# Sidebar - Model Selection
with gr.Column(scale=1):
gr.Markdown("### Model Settings")
model_dropdown = gr.Dropdown(
choices=initial_models,
value=initial_models[-1] if initial_models else None,
label="πŸ“ Select Checkpoint",
interactive=True
)
load_btn = gr.Button("Load Model", variant="primary")
model_status = gr.Textbox(label="Status", interactive=False, lines=2)
gr.Markdown("### Generation Parameters")
max_tokens = gr.Slider(10, 300, value=100, step=10, label="Max Tokens")
temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
stop_btn = gr.Button("πŸ›‘ Stop Generation", variant="stop")
refresh_btn = gr.Button("πŸ”„ Refresh Model List")
# Main Chat Area
with gr.Column(scale=3):
chatbot = gr.Chatbot(
label="Chat",
height=500,
type="messages"
)
with gr.Row():
msg = gr.Textbox(
label="Message",
placeholder="Type your message here...",
lines=2,
scale=4
)
send_btn = gr.Button("Send", variant="primary", scale=1)
clear_btn = gr.Button("Clear Chat")
gr.Examples(
examples=[
"Hello!",
"What is artificial intelligence?",
"Explain quantum computing in simple terms",
],
inputs=[msg]
)
# Event handlers
def user_submit(user_message, chat_history):
"""Handle user message submission"""
if not user_message:
return "", chat_history
chat_history.append({"role": "user", "content": user_message})
return "", chat_history
def bot_response(chat_history, max_tok, temp, top):
"""Generate bot response with conversational format (User:/Assistant:)"""
if not chat_history or chat_history[-1]["role"] != "user":
return chat_history
user_msg = chat_history[-1]["content"]
# Build conversation history in format: User: ... \nAssistant: ... \n
conversation_context = ""
            for turn in chat_history[:-1]:  # all previous turns
                if turn["role"] == "user":
                    conversation_context += f"User: {turn['content']}\n"
                elif turn["role"] == "assistant":
                    # Keep only the reply text; drop the generation-stats footer
                    assistant_text = turn["content"].split("\n\n---\n")[0]
                    conversation_context += f"Assistant: {assistant_text}\n"
# Add current user message with format
prompt = f"{conversation_context}User: {user_msg}\nAssistant:"
# Initialize assistant response
chat_history.append({"role": "assistant", "content": ""})
# Stream generation
for response in generate_text_streaming(prompt, max_tok, temp, top):
chat_history[-1]["content"] = response
yield chat_history
# Submit message
msg.submit(
fn=user_submit,
inputs=[msg, chatbot],
outputs=[msg, chatbot]
).then(
fn=bot_response,
inputs=[chatbot, max_tokens, temperature, top_p],
outputs=chatbot
)
send_btn.click(
fn=user_submit,
inputs=[msg, chatbot],
outputs=[msg, chatbot]
).then(
fn=bot_response,
inputs=[chatbot, max_tokens, temperature, top_p],
outputs=chatbot
)
# Clear chat
clear_btn.click(fn=lambda: [], outputs=chatbot)
# Load model
load_btn.click(
fn=load_model_checkpoint,
inputs=[model_dropdown],
outputs=model_status
)
# Stop generation
stop_btn.click(fn=stop_generation_func, outputs=model_status)
# Refresh models
def refresh_models():
models = get_available_models()
return gr.Dropdown(choices=models, value=models[-1] if models else None), f"Found {len(models)} models"
refresh_btn.click(
fn=refresh_models,
outputs=[model_dropdown, model_status]
)
# Auto-load first model on startup
demo.load(
fn=lambda: load_model_checkpoint(initial_models[-1]) if initial_models else "No models found",
outputs=model_status
)
return demo
# Launch
if __name__ == "__main__":
demo = create_interface()
demo.launch(share=True)