Spaces:
Sleeping
Sleeping
File size: 2,599 Bytes
657dabc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 | import gradio as gr
import torch
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
from model import LlamaForCausalLM # Import your custom model class
# Load the tokenizer; fall back to EOS (or a literal "[PAD]") when it defines no pad token.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token else "[PAD]"

# Build the model with the reduced (135M) configuration. These hyper-parameters
# must match the checkpoint loaded below, otherwise load_state_dict will fail.
model = LlamaForCausalLM(
    vocab_size=tokenizer.vocab_size,
    dim=576,
    num_layers=30,
    hidden_dim=1536,
    num_heads=9,
)
device = "cpu"

# Fetch the trained weights from the Hugging Face Hub.
model_id = "Rajendro/smallmv2135"
checkpoint_path = hf_hub_download(repo_id=model_id, filename="model-dict-step-5500.pt")
# NOTE(review): torch.load unpickles the file; only load checkpoints from trusted repos.
checkpoint = torch.load(checkpoint_path, map_location=device)
# The training run saved a dict holding "model_state_dict"; also accept a bare state dict
# so re-exported checkpoints keep working.
state_dict = checkpoint.get("model_state_dict", checkpoint)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # switch to inference mode before serving
def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
    """Continue ``prompt`` with the loaded model using temperature + top-k sampling.

    Args:
        prompt: Text to continue.
        max_length: Maximum number of new tokens to sample. Gradio sliders
            deliver floats, so the value is coerced to ``int``.
        temperature: Positive divisor applied to the logits before softmax.
        top_k: Number of highest-scoring tokens to sample from; coerced to
            ``int`` and clamped to ``[1, vocabulary size]``.

    Returns:
        The prompt followed by the generated tokens, decoded with special
        tokens stripped. Generation stops early when the EOS token is sampled.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    # Slider values arrive as floats; range() and torch.topk require ints.
    max_new_tokens = int(max_length)
    top_k = int(top_k)
    temperature = float(temperature)
    if temperature <= 0:
        # Dividing by zero/negative would yield inf/NaN probabilities.
        raise ValueError(f"temperature must be > 0, got {temperature}")

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    if input_ids.size(1) == 0:
        # Nothing to condition on; taking the last position would raise IndexError.
        return prompt

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Assumes the custom model returns raw logits shaped (batch, seq, vocab) — TODO confirm in model.py
            logits = model(input_ids)
            next_token_logits = logits[:, -1, :] / temperature

            # Never ask topk for more candidates than the vocabulary holds.
            k = max(1, min(top_k, next_token_logits.size(-1)))
            top_k_logits, top_k_indices = torch.topk(next_token_logits, k, dim=-1)
            probs = torch.softmax(top_k_logits, dim=-1)

            # multinomial picks a position within the top-k list; gather maps it
            # back to a vocabulary id, keeping a (batch, 1) shape for torch.cat.
            sampled_pos = torch.multinomial(probs, num_samples=1)
            next_token = torch.gather(top_k_indices, -1, sampled_pos)

            if next_token.item() == tokenizer.eos_token_id:
                break
            input_ids = torch.cat([input_ids, next_token], dim=1)
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
# Gradio interface: prompt plus sampling controls in, generated text out.
# Max Length and Top-k are token counts, so their sliders move in whole steps
# (without step=1 Gradio auto-picks a fractional step for the 10–100 range).
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Input Prompt", lines=3),
        gr.Slider(50, 200, value=100, step=1, label="Max Length"),
        gr.Slider(0.1, 2.0, value=0.7, label="Temperature"),
        gr.Slider(10, 100, value=50, step=1, label="Top-k"),
    ],
    outputs=gr.Textbox(label="Generated Text", lines=5),
    title="🦙 Sample SmolLLM Demo",
    description="A 135M parameter language model trained on smollm-corpus",
)
# Launch the Gradio server only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()