jonmabe commited on
Commit
3aa6f7b
·
verified ·
1 Parent(s): 80765dc

Initial Gradio demo upload

Browse files
Files changed (4) hide show
  1. README.md +45 -5
  2. app.py +254 -0
  3. model.py +268 -0
  4. requirements.txt +4 -0
README.md CHANGED
@@ -1,12 +1,52 @@
1
  ---
2
- title: Tiny Llm Demo
3
- emoji: 💻
4
  colorFrom: blue
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 6.5.1
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Tiny-LLM Text Generator
3
+ emoji: 🤖
4
  colorFrom: blue
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
10
+ license: apache-2.0
11
  ---
12
 
13
+ # Tiny-LLM Text Generator
14
+
15
+ A **54 million parameter** language model trained **from scratch** on Wikipedia.
16
+
17
+ ## About
18
+
19
+ This demonstrates that meaningful language models can be trained on consumer hardware with modest compute budgets!
20
+
21
+ ## Architecture
22
+
23
+ | Component | Value |
24
+ |-----------|-------|
25
+ | Parameters | 54.93M |
26
+ | Layers | 12 |
27
+ | Hidden Size | 512 |
28
+ | Attention Heads | 8 |
29
+ | Intermediate (FFN) | 1408 |
30
+ | Vocab Size | 32,000 |
31
+ | Max Sequence Length | 512 |
32
+ | Position Encoding | RoPE |
33
+ | Normalization | RMSNorm |
34
+ | Activation | SwiGLU |
35
+
36
+ ## Training
37
+
38
+ - **Training Steps**: 50,000
39
+ - **Tokens**: ~100M
40
+ - **Hardware**: NVIDIA RTX 5090 (32GB)
41
+ - **Training Time**: ~3 hours
42
+
43
+ ## Model
44
+
45
+ [jonmabe/tiny-llm-54m](https://huggingface.co/jonmabe/tiny-llm-54m)
46
+
47
+ ## Limitations
48
+
49
+ - Small model size limits knowledge and capabilities
50
+ - Trained only on Wikipedia - limited domain coverage
51
+ - May generate factually incorrect information
52
+ - Not instruction-tuned
app.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
Tiny-LLM Demo - Text Generation with a 54M Parameter Model

This model was trained from scratch on Wikipedia data.
"""

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from model import TinyLLM, MODEL_CONFIG

# Hugging Face Hub location of the pretrained checkpoint.
MODEL_ID = "jonmabe/tiny-llm-54m"
MODEL_FILENAME = "final_model.pt"

# Try to use a transformers tokenizer; generation is disabled without one.
try:
    from transformers import AutoTokenizer
    # Prefer the tokenizer shipped with the model repo, fall back to GPT-2.
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    except Exception:  # fix: bare `except:` also swallowed KeyboardInterrupt/SystemExit
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
    USE_HF_TOKENIZER = True
except Exception as e:
    print(f"Could not load HuggingFace tokenizer: {e}")
    USE_HF_TOKENIZER = False
    tokenizer = None

# Download the checkpoint from the Hub (cached between runs).
print("Downloading model...")
model_path = hf_hub_download(repo_id=MODEL_ID, filename=MODEL_FILENAME)
print(f"Model downloaded to {model_path}")

print("Loading model...")
# weights_only=False is needed because the checkpoint stores a config dict
# alongside the weights; acceptable only because the file comes from our
# own trusted repo (torch.load can execute arbitrary pickled code).
checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)

# Prefer the config embedded in the checkpoint; fall back to the hard-coded one.
if "config" in checkpoint and isinstance(checkpoint["config"], dict):
    config = checkpoint["config"]
    # Training checkpoints may nest the model config under a "model" key.
    if "model" in config:
        config = config["model"]
else:
    config = MODEL_CONFIG

# Initialize model
model = TinyLLM(config)

# Checkpoints may wrap the weights inside a training-state dict.
if "model_state_dict" in checkpoint:
    state_dict = checkpoint["model_state_dict"]
else:
    state_dict = checkpoint

# strict=False tolerates harmless key drift (e.g. registered buffers);
# mismatches are surfaced as warnings rather than hard failures.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing:
    print(f"Warning: Missing keys: {missing[:5]}...")
if unexpected:
    print(f"Warning: Unexpected keys: {unexpected[:5]}...")

# Move to GPU when available and switch to inference mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

total_params = sum(p.numel() for p in model.parameters())
print(f"Model loaded on {device} with {total_params:,} parameters")
def generate_text(
    prompt: str,
    max_tokens: int = 100,
    temperature: float = 0.8,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
) -> str:
    """Generate a continuation of *prompt* and return prompt + continuation.

    Relies on the module-level ``model``, ``tokenizer`` and ``device``
    created at import time. Returns a human-readable message instead of
    raising when the prompt is empty or no tokenizer is available.
    """
    # Guard clause: nothing to do for a blank prompt.
    if not prompt.strip():
        return "Please enter a prompt to generate text."

    # Guard clause: without a tokenizer we cannot encode the prompt at all.
    if not (USE_HF_TOKENIZER and tokenizer is not None):
        return "Tokenizer not available. Please ensure transformers is installed."

    encoded = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Autoregressive sampling; no_grad keeps memory usage flat.
    with torch.no_grad():
        generated = model.generate(
            encoded,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            eos_token_id=tokenizer.eos_token_id,
        )

    if USE_HF_TOKENIZER and tokenizer is not None:
        return tokenizer.decode(generated[0], skip_special_tokens=True)
    return "Decoding not available."
# Example prompts offered below the input box (one-element lists: one input).
EXAMPLES = [
    ["The history of artificial intelligence began"],
    ["In the year 2050, humanity"],
    ["The most important scientific discovery was"],
    ["Once upon a time, in a kingdom far away"],
    ["The universe is vast and"],
    ["Climate change affects"],
    ["The theory of relativity states that"],
    ["In ancient Rome,"],
]


# Create Gradio interface
with gr.Blocks(title="Tiny-LLM Text Generator") as demo:
    # Page header: model summary and architecture highlights.
    gr.Markdown("""
    # 🤖 Tiny-LLM Text Generator

    A **54 million parameter** language model trained **from scratch** on Wikipedia.

    This demonstrates that meaningful language models can be trained on consumer hardware!

    ### Architecture
    - **Parameters**: 54.93M
    - **Layers**: 12
    - **Hidden Size**: 512
    - **Attention Heads**: 8
    - **Position Encoding**: RoPE
    - **Normalization**: RMSNorm
    - **Activation**: SwiGLU
    """)

    with gr.Row():
        # Left column: prompt input and sampling controls.
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here...",
                lines=3,
                value="The history of artificial intelligence began"
            )

            with gr.Row():
                with gr.Column():
                    max_tokens = gr.Slider(
                        minimum=10,
                        maximum=256,
                        value=100,
                        step=10,
                        label="Max New Tokens",
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.8,
                        step=0.1,
                        label="Temperature",
                        info="Higher = more random"
                    )

                with gr.Column():
                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.9,
                        step=0.05,
                        label="Top-p (Nucleus Sampling)",
                    )
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=5,
                        label="Top-k",
                    )

            repetition_penalty = gr.Slider(
                minimum=1.0,
                maximum=2.0,
                value=1.1,
                step=0.05,
                label="Repetition Penalty",
                info="Higher = less repetition"
            )

            generate_btn = gr.Button("✨ Generate", variant="primary", size="lg")

        # Right column: read-only output box for the generated text.
        with gr.Column(scale=2):
            output_text = gr.Textbox(
                label="Generated Text",
                lines=15,
                interactive=False,
            )

    gr.Markdown("### 📝 Example Prompts")
    gr.Examples(
        examples=EXAMPLES,
        inputs=prompt_input,
    )

    # Event handlers: the Generate button and Enter in the prompt box both
    # invoke generate_text with the full set of sampling parameters.
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_input, max_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=output_text,
    )

    prompt_input.submit(
        fn=generate_text,
        inputs=[prompt_input, max_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=output_text,
    )

    # Page footer: model card details (training setup, limitations, use).
    gr.Markdown("""
    ---
    ### About This Model

    **Model**: [jonmabe/tiny-llm-54m](https://huggingface.co/jonmabe/tiny-llm-54m)

    This is a decoder-only transformer trained from scratch on Wikipedia text.
    It demonstrates that meaningful language models can be trained on consumer hardware
    with modest compute budgets (~3 hours on an RTX 5090).

    #### Training Details
    - **Training Steps**: 50,000
    - **Tokens**: ~100M
    - **Hardware**: NVIDIA RTX 5090 (32GB)
    - **Training Time**: ~3 hours

    #### Limitations
    - Small model size limits knowledge and capabilities
    - Trained only on Wikipedia - limited domain coverage
    - May generate factually incorrect information
    - Not instruction-tuned

    #### Intended Use
    - Educational: Understanding transformer training
    - Experimental: Testing fine-tuning approaches
    - Research: Lightweight model for NLP experiments
    """)


if __name__ == "__main__":
    demo.launch()
model.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
TinyLLM Model Architecture

A small transformer language model (~54.93M parameters) trained from scratch.
Architecture:
- 12 layers
- 512 hidden size
- 8 attention heads
- 1408 intermediate (FFN)
- 32000 vocab size
- 512 max sequence length
- RoPE position encoding
- RMSNorm
- SwiGLU activation
- Weight tying
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Dict, Any, Optional


# Default model configuration, used when a checkpoint carries no config dict.
MODEL_CONFIG = {
    "vocab_size": 32000,             # tokenizer vocabulary size
    "hidden_size": 512,              # transformer width
    "num_layers": 12,                # number of TransformerBlocks
    "num_heads": 8,                  # attention heads (head_dim = 512 // 8 = 64)
    "intermediate_size": 1408,       # SwiGLU FFN inner width
    "max_position_embeddings": 512,  # maximum sequence length
    "dropout": 0.0,                  # attention dropout probability
    "tie_weights": True,             # share embedding and lm_head weights
}
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Rescales each vector by the inverse RMS of its last dimension and a
    learned per-channel gain; unlike LayerNorm there is no mean-centering
    and no bias.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # 1 / sqrt(mean(x^2) + eps), computed over the feature dimension.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        normalized = x * inv_rms
        return self.weight * normalized
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE).

    Precomputes per-channel inverse frequencies; forward() returns the
    cos/sin rotation tables of shape (seq_len, dim) for the given length.
    """

    def __init__(self, dim: int, max_seq_len: int = 512, base: int = 10000):
        super().__init__()
        # Geometric frequency schedule over pairs of channels.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len = max_seq_len

    def forward(self, seq_len: int, device):
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        # Angle for every (position, frequency) pair: outer product.
        angles = positions[:, None] * self.inv_freq[None, :]
        # Duplicate so the table covers both halves of the head dimension.
        table = torch.cat([angles, angles], dim=-1)
        return table.cos(), table.sin()
def rotate_half(x):
    """Split the last dimension into halves (a, b) and return (-b, a)."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate query/key tensors by the positional cos/sin tables.

    q, k are (batch, heads, seq, head_dim); cos/sin are (seq, head_dim)
    and are broadcast over the batch and head dimensions.
    """
    cos = cos[None, None, :, :]
    sin = sin[None, None, :, :]
    rotated_q = q * cos + rotate_half(q) * sin
    rotated_k = k * cos + rotate_half(k) * sin
    return rotated_q, rotated_k
class Attention(nn.Module):
    """Multi-head causal self-attention with RoPE.

    Queries and keys are rotated by position before the dot product; a
    causal mask built each forward call blocks attention to future tokens.
    """
    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.hidden_size = config["hidden_size"]
        self.num_heads = config["num_heads"]
        # Per-head width; assumes hidden_size is divisible by num_heads.
        self.head_dim = self.hidden_size // self.num_heads

        # Bias-free projections for query, key, value and output.
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        self.rotary = RotaryEmbedding(self.head_dim, config["max_position_embeddings"])
        self.dropout = nn.Dropout(config.get("dropout", 0.0))

    def forward(self, x, attention_mask=None):
        # x: (batch B, sequence length T, hidden size C)
        B, T, C = x.shape

        # Project and reshape to (B, num_heads, T, head_dim).
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # Rotate q/k by position (RoPE) before computing attention scores.
        cos, sin = self.rotary(T, x.device)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Scaled dot-product attention
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Causal mask: strictly upper-triangular (future) positions get -inf
        # so they receive zero weight after the softmax.
        causal_mask = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1)
        attn_weights.masked_fill_(causal_mask, float('-inf'))

        # NOTE(review): assumes attention_mask is additive (0 / -inf values),
        # not a boolean padding mask — confirm against callers.
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Softmax in float32 for numerical stability, then cast back.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(x.dtype)
        attn_weights = self.dropout(attn_weights)

        # Weighted sum of values, then merge heads back to (B, T, C).
        out = torch.matmul(attn_weights, v)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(out)
class FFN(nn.Module):
    """Feed-forward block with SwiGLU activation: down(silu(gate(x)) * up(x))."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        width = config["hidden_size"]
        inner = config["intermediate_size"]
        # Naming follows the Llama convention: w1=gate, w2=down, w3=up.
        self.w1 = nn.Linear(width, inner, bias=False)
        self.w2 = nn.Linear(inner, width, bias=False)
        self.w3 = nn.Linear(width, inner, bias=False)

    def forward(self, x):
        gated = F.silu(self.w1(x))
        return self.w2(gated * self.w3(x))
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: residual attention, then residual FFN."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.norm1 = RMSNorm(config["hidden_size"])
        self.attn = Attention(config)
        self.norm2 = RMSNorm(config["hidden_size"])
        self.ffn = FFN(config)

    def forward(self, x, attention_mask=None):
        # Each sub-layer is applied to a normalized copy and added back
        # to the residual stream (pre-norm architecture).
        attn_out = self.attn(self.norm1(x), attention_mask)
        x = x + attn_out
        ffn_out = self.ffn(self.norm2(x))
        return x + ffn_out
class TinyLLM(nn.Module):
    """
    TinyLLM: A small decoder-only transformer language model.

    Args:
        config: Dictionary containing model configuration (see MODEL_CONFIG).
    """
    def __init__(self, config: Dict[str, Any] = None):
        super().__init__()
        self.config = config or MODEL_CONFIG

        self.embed_tokens = nn.Embedding(self.config["vocab_size"], self.config["hidden_size"])
        self.layers = nn.ModuleList([
            TransformerBlock(self.config)
            for _ in range(self.config["num_layers"])
        ])
        self.norm = RMSNorm(self.config["hidden_size"])
        self.lm_head = nn.Linear(self.config["hidden_size"], self.config["vocab_size"], bias=False)

        # Tie input embedding and output projection weights if configured.
        if self.config.get("tie_weights", True):
            self.lm_head.weight = self.embed_tokens.weight

        # Kept for checkpoint key compatibility; Attention rebuilds its own
        # causal mask each forward, so this buffer is not read here.
        max_len = self.config["max_position_embeddings"]
        self.register_buffer("causal_mask",
            torch.triu(torch.ones(max_len, max_len, dtype=torch.bool), diagonal=1))

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Return {"logits", "loss"}; loss is next-token CE when labels given.

        Args:
            input_ids: Token IDs [batch_size, seq_len].
            attention_mask: Optional additive mask passed to each layer.
            labels: Optional target IDs; positions equal to -100 are ignored.
        """
        x = self.embed_tokens(input_ids)

        for layer in self.layers:
            x = layer(x, attention_mask)

        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Shift so each position predicts the *next* token.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.config["vocab_size"]),
                shift_labels.view(-1),
                ignore_index=-100
            )

        return {"logits": logits, "loss": loss}

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 0.8,
        top_p: float = 0.9,
        top_k: int = 50,
        eos_token_id: Optional[int] = None,
        repetition_penalty: float = 1.0,
    ) -> torch.Tensor:
        """
        Generate text autoregressively.

        Args:
            input_ids: Input token IDs [batch_size, seq_len]
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (higher = more random)
            top_p: Nucleus sampling threshold
            top_k: Top-k sampling threshold
            eos_token_id: Token ID that signals end of generation
            repetition_penalty: Penalty for repeating tokens (>1 discourages)

        Returns:
            Generated token IDs including the prompt

        Note:
            The repetition-penalty and EOS checks index batch row 0, so this
            method effectively supports batch size 1 (as used by the demo).
        """
        self.eval()
        max_len = self.config["max_position_embeddings"]
        # Guard against a non-positive temperature (division below).
        temperature = max(temperature, 1e-6)

        for _ in range(max_new_tokens):
            # Keep a sliding window that leaves room for this step's token.
            if input_ids.size(1) >= max_len:
                input_ids = input_ids[:, -(max_len - 1):]

            outputs = self(input_ids)
            logits = outputs["logits"][:, -1, :]

            # CTRL-style repetition penalty. Bug fix: the previous code always
            # divided, which *raised* the probability of already-seen tokens
            # whose logit is negative; penalize according to the logit's sign.
            if repetition_penalty != 1.0:
                for token_id in set(input_ids[0].tolist()):
                    score = logits[0, token_id]
                    if score < 0:
                        logits[0, token_id] = score * repetition_penalty
                    else:
                        logits[0, token_id] = score / repetition_penalty

            # Apply temperature
            logits = logits / temperature

            # Top-k filtering: drop everything below the k-th largest logit.
            if top_k > 0:
                k = min(top_k, logits.size(-1))  # avoid topk() error when k > vocab
                indices_to_remove = logits < torch.topk(logits, k)[0][..., -1, None]
                logits[indices_to_remove] = float('-inf')

            # Top-p (nucleus) filtering: keep the smallest prefix of sorted
            # tokens whose cumulative probability exceeds top_p.
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first token crossing the threshold is kept.
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = float('-inf')

            # Sample the next token and append it to the sequence.
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)

            # Stop early on end-of-sequence (batch size 1 assumption).
            if eos_token_id is not None and next_token.item() == eos_token_id:
                break

        return input_ids
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ torch>=2.0.0
3
+ transformers>=4.35.0
4
+ huggingface_hub>=0.20.0