jonmabe commited on
Commit
2c215c6
·
verified ·
1 Parent(s): dd79935

Upload folder using huggingface_hub

Browse files
Files changed (7) hide show
  1. .gitignore +3 -0
  2. README.md +24 -6
  3. app.py +271 -0
  4. best_model.pt +3 -0
  5. model.py +268 -0
  6. requirements.txt +4 -0
  7. tokenizer.json +0 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .venv/
2
+ __pycache__/
3
+ *.pyc
README.md CHANGED
@@ -1,12 +1,30 @@
1
  ---
2
- title: Tiny Llm Cli Sft Demo
3
- emoji: 💻
4
- colorFrom: yellow
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 6.5.1
 
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: CLI Command Generator
3
+ emoji: 🖥️
4
+ colorFrom: green
5
+ colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 4.44.0
8
+ python_version: "3.11"
9
  app_file: app.py
10
  pinned: false
11
+ license: apache-2.0
12
  ---
13
 
14
+ # CLI Command Generator
15
+
16
+ A **54 million parameter** language model fine-tuned to generate shell commands from natural language.
17
+
18
+ ## About
19
+
20
+ Translate instructions like "list all files" → `ls -la`
21
+
22
+ ## Model
23
+
24
+ [jonmabe/tiny-llm-cli-sft](https://huggingface.co/jonmabe/tiny-llm-cli-sft)
25
+
26
+ Fine-tuned on ~13,000 natural language → CLI command pairs.
27
+
28
+ ## Limitations
29
+
30
+ ⚠️ **Experimental** - outputs may be incomplete or incorrect.
app.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Tiny-LLM CLI SFT Demo - Generate Shell Commands from Natural Language

This model was fine-tuned to translate natural language instructions to CLI commands.
"""

import gradio as gr
import torch
from huggingface_hub import hf_hub_download

from model import TinyLLM

# Fallback hyperparameters, used only when the checkpoint carries no config.
MODEL_CONFIG = {
    "vocab_size": 32000,
    "hidden_size": 512,
    "num_layers": 12,
    "num_heads": 8,
    "intermediate_size": 1408,
    "max_position_embeddings": 512,
    "dropout": 0.0,
    "tie_weights": True,
}

# Hub location of the fine-tuned checkpoint.
MODEL_ID = "jonmabe/tiny-llm-cli-sft"
MODEL_FILENAME = "best_model.pt"

# --- Tokenizer ---------------------------------------------------------------
# Best-effort load: on failure the UI stays up and generate_command reports
# the missing tokenizer instead of crashing at import time.
try:
    from tokenizers import Tokenizer

    tokenizer_path = hf_hub_download(repo_id=MODEL_ID, filename="tokenizer.json")
    tokenizer = Tokenizer.from_file(tokenizer_path)
    print("Loaded tokenizer from model repo")
except Exception as e:
    print(f"Could not load tokenizer: {e}")
    tokenizer = None

# --- Model -------------------------------------------------------------------
print("Downloading model...")
model_path = hf_hub_download(repo_id=MODEL_ID, filename=MODEL_FILENAME)
print(f"Model downloaded to {model_path}")

print("Loading model...")
# NOTE(review): weights_only=False unpickles arbitrary objects. Acceptable for
# this trusted first-party checkpoint, but do not reuse for untrusted files.
checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)

# Prefer the config embedded in the checkpoint; fall back to the defaults above.
# Training configs sometimes nest the model hyperparameters under a "model" key.
if "config" in checkpoint and isinstance(checkpoint["config"], dict):
    config = checkpoint["config"]
    config = config.get("model", config)
else:
    config = MODEL_CONFIG

# Build the network, then pull weights either from the "model_state_dict"
# entry or, for bare state-dict checkpoints, from the checkpoint itself.
model = TinyLLM(config)
state_dict = checkpoint.get("model_state_dict", checkpoint)

# strict=False tolerates renamed/extra keys; surface a short sample of each.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing:
    print(f"Warning: Missing keys: {missing[:5]}...")
if unexpected:
    print(f"Warning: Unexpected keys: {unexpected[:5]}...")

# Run on GPU when one is available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

total_params = sum(p.numel() for p in model.parameters())
print(f"Model loaded on {device} with {total_params:,} parameters")
76
+
77
+
78
def clean_bpe_output(text: str) -> str:
    """Strip byte-level BPE markers and collapse whitespace.

    The tokenizer emits "Ġ" for a leading space and "Ċ" for a newline;
    both are mapped back, then all runs of whitespace are squeezed to a
    single space and the result is trimmed.
    """
    # One C-level pass replaces both single-character BPE markers.
    text = text.translate(str.maketrans({"Ġ": " ", "Ċ": "\n"}))
    # split()/join squeezes every whitespace run (including the newlines
    # just restored) down to single spaces.
    collapsed = " ".join(text.split())
    return collapsed.strip()
87
+
88
+
89
def generate_command(
    instruction: str,
    max_tokens: int = 50,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 50,
) -> str:
    """Generate a CLI command from a natural-language instruction.

    Args:
        instruction: What the user wants to do, in plain English.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature (higher = more random).
        top_p: Nucleus-sampling threshold.
        top_k: Top-k sampling threshold.

    Returns:
        The suggested shell command (first line of the cleaned model
        output), or a short message when input or tokenizer is missing.
    """
    if not instruction.strip():
        return "Please enter an instruction."

    if tokenizer is None:
        return "Tokenizer not available."

    # Prompt format must match the SFT training template.
    prompt = f"Instruction: {instruction}\nCommand:"

    # Tokenize and move onto the model's device.
    encoded = tokenizer.encode(prompt)
    input_ids = torch.tensor([encoded.ids], dtype=torch.long).to(device)
    input_len = input_ids.shape[1]

    # Generate. Gradio Slider values arrive as floats, so the integer-valued
    # knobs are cast explicitly: range(max_new_tokens) and torch.topk(top_k)
    # inside model.generate both require ints and raise TypeError on floats.
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
            eos_token_id=tokenizer.token_to_id("</s>"),
        )

    # Decode only the newly generated tokens, not the prompt.
    generated_ids = output_ids[0, input_len:].tolist()
    raw_output = tokenizer.decode(generated_ids)

    # Clean BPE artifacts, then keep just the first line of output.
    command = clean_bpe_output(raw_output)
    command = command.split("\n")[0].strip()

    return command
134
+
135
+
136
# Example instructions shown in the Gradio Examples widget; gr.Examples
# expects one row (list) per example, hence the single-item inner lists.
_EXAMPLE_INSTRUCTIONS = (
    "List all files in the current directory",
    "Find all Python files",
    "Show disk usage",
    "Create a new folder called test",
    "Search for 'error' in log files",
    "Show the last 10 lines of a file",
    "Count lines in a file",
    "Copy files to another directory",
    "Show running processes",
    "Check available disk space",
)
EXAMPLES = [[text] for text in _EXAMPLE_INSTRUCTIONS]
149
+
150
+
151
# --- Gradio interface --------------------------------------------------------
with gr.Blocks(title="CLI Command Generator") as demo:
    gr.Markdown("""
    # 🖥️ CLI Command Generator

    Translate natural language instructions to shell commands using a **54M parameter** language model.

    ⚠️ **Note**: This is an early-stage SFT model. Outputs may be incomplete or incorrect.

    ### How to Use
    1. Enter a natural language instruction
    2. Click "Generate" or press Enter
    3. The model will suggest a shell command
    """)

    with gr.Row():
        # Left column: instruction input plus sampling controls.
        with gr.Column(scale=2):
            instruction_box = gr.Textbox(
                label="Instruction",
                placeholder="Describe what you want to do...",
                lines=2,
                value="List all files in the current directory"
            )

            with gr.Row():
                with gr.Column():
                    max_tokens_slider = gr.Slider(
                        minimum=10,
                        maximum=100,
                        value=50,
                        step=5,
                        label="Max Tokens",
                    )
                    temperature_slider = gr.Slider(
                        minimum=0.1,
                        maximum=1.5,
                        value=0.7,
                        step=0.1,
                        label="Temperature",
                        info="Higher = more creative"
                    )

                with gr.Column():
                    top_p_slider = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.9,
                        step=0.05,
                        label="Top-p",
                    )
                    top_k_slider = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=5,
                        label="Top-k",
                    )

            generate_btn = gr.Button("⚡ Generate Command", variant="primary", size="lg")

        # Right column: generated command plus a quick reference card.
        with gr.Column(scale=2):
            command_box = gr.Textbox(
                label="Generated Command",
                lines=3,
                interactive=False,
            )

            gr.Markdown("""
            ### Common Commands Reference
            - `ls` - list files
            - `find` - search for files
            - `grep` - search in files
            - `df` - disk usage
            - `du` - directory size
            - `tar` - archive files
            - `scp` - copy over SSH
            """)

    gr.Markdown("### 📝 Example Instructions")
    gr.Examples(
        examples=EXAMPLES,
        inputs=instruction_box,
    )

    # Clicking the button and pressing Enter in the textbox trigger the same
    # generation call, so the input list is shared between the two handlers.
    generation_inputs = [
        instruction_box,
        max_tokens_slider,
        temperature_slider,
        top_p_slider,
        top_k_slider,
    ]
    generate_btn.click(
        fn=generate_command,
        inputs=generation_inputs,
        outputs=command_box,
    )
    instruction_box.submit(
        fn=generate_command,
        inputs=generation_inputs,
        outputs=command_box,
    )

    gr.Markdown("""
    ---
    ### About This Model

    **Model**: [jonmabe/tiny-llm-cli-sft](https://huggingface.co/jonmabe/tiny-llm-cli-sft)

    This is a Supervised Fine-Tuned (SFT) version of [tiny-llm-54m](https://huggingface.co/jonmabe/tiny-llm-54m),
    trained on ~13,000 natural language → CLI command pairs.

    #### Known Limitations
    - 🔬 **Experimental**: Outputs may be incomplete or incorrect
    - 📊 **Small model**: 54M parameters limits capability
    - 🔧 **Needs improvement**: More training data and steps needed

    #### Training Details
    - **Steps**: 2,000
    - **Best Val Loss**: 1.2456
    - **Data**: Geddy's NL2Bash + NL2Bash benchmark + synthetic
    - **Hardware**: RTX 5090, ~9 minutes
    """)


if __name__ == "__main__":
    demo.launch()
best_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b249acfcfa538c4e0d9b6aba62d8e41e5f8de481d94a0214450a004ab1cea540
3
+ size 220038796
model.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
TinyLLM Model Architecture

A small transformer language model (~54.93M parameters) trained from scratch.
Architecture:
- 12 layers
- 512 hidden size
- 8 attention heads
- 1408 intermediate (FFN)
- 32000 vocab size
- 512 max sequence length
- RoPE position encoding
- RMSNorm
- SwiGLU activation
- Weight tying
"""

import math
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

# Default hyperparameters, used when TinyLLM is constructed without a config.
MODEL_CONFIG = {
    "vocab_size": 32000,
    "hidden_size": 512,
    "num_layers": 12,
    "num_heads": 8,
    "intermediate_size": 1408,
    "max_position_embeddings": 512,
    "dropout": 0.0,
    "tie_weights": True,
}
36
+
37
+
38
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean subtraction, no bias)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps  # numerical floor inside the rsqrt
        self.weight = nn.Parameter(torch.ones(dim))  # learned per-channel gain

    def forward(self, x):
        # Scale each vector by the inverse RMS of its last dimension,
        # then apply the learned gain.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight
47
+
48
+
49
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) angle tables.

    The buffer name ``inv_freq`` is part of the checkpoint's state-dict
    layout and must not be renamed.
    """

    def __init__(self, dim: int, max_seq_len: int = 512, base: int = 10000):
        super().__init__()
        # Standard RoPE inverse frequencies: base^(-2i/dim) over even indices.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len = max_seq_len

    def forward(self, seq_len: int, device):
        # Return (cos, sin) tables of shape [seq_len, dim].
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = positions[:, None] * self.inv_freq[None, :]
        # Duplicate the half-width table so it spans the full head dimension.
        table = torch.cat((angles, angles), dim=-1)
        return table.cos(), table.sin()
62
+
63
+
64
def rotate_half(x):
    """Swap the two halves of the last dim and negate the (new) first half."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
68
+
69
+
70
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate query/key tensors by the RoPE angle tables.

    ``cos``/``sin`` arrive as [T, head_dim] and are broadcast over the batch
    and head dimensions of q/k (shaped [B, H, T, head_dim]).
    """
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)
    rotated_q = q * cos + rotate_half(q) * sin
    rotated_k = k * cos + rotate_half(k) * sin
    return rotated_q, rotated_k
77
+
78
+
79
class Attention(nn.Module):
    """Multi-head self-attention with rotary position embeddings.

    Submodule names (q_proj, k_proj, v_proj, o_proj, rotary) define the
    checkpoint's state-dict layout and must not be renamed.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.hidden_size = config["hidden_size"]
        self.num_heads = config["num_heads"]
        self.head_dim = self.hidden_size // self.num_heads

        # Bias-free projections, as is conventional for RoPE/RMSNorm stacks.
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        self.rotary = RotaryEmbedding(self.head_dim, config["max_position_embeddings"])
        self.dropout = nn.Dropout(config.get("dropout", 0.0))

    def forward(self, x, attention_mask=None):
        batch, seq_len, width = x.shape

        def split_heads(t):
            # [B, T, C] -> [B, H, T, head_dim]
            return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))

        cos, sin = self.rotary(seq_len, x.device)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Scaled dot-product attention scores.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Forbid attending to future positions (strict upper triangle).
        future = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1
        )
        scores = scores.masked_fill(future, float('-inf'))

        # Optional additive mask (e.g. padding), expected in logit space.
        if attention_mask is not None:
            scores = scores + attention_mask

        # Softmax in float32 for stability, then back to the input dtype.
        probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(x.dtype)
        probs = self.dropout(probs)

        context = torch.matmul(probs, v)
        context = context.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.o_proj(context)
121
+
122
+
123
class FFN(nn.Module):
    """Feed-forward network with SwiGLU activation.

    Layer names follow the Llama convention and define the checkpoint's
    state-dict layout: w1 = gate, w2 = down, w3 = up.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.w1 = nn.Linear(config["hidden_size"], config["intermediate_size"], bias=False)
        self.w2 = nn.Linear(config["intermediate_size"], config["hidden_size"], bias=False)
        self.w3 = nn.Linear(config["hidden_size"], config["intermediate_size"], bias=False)

    def forward(self, x):
        # SwiGLU: silu-gated elementwise product, then project back down.
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
134
+
135
+
136
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: RMSNorm → attention, RMSNorm → SwiGLU FFN.

    Attribute names (norm1, attn, norm2, ffn) match the checkpoint layout.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.norm1 = RMSNorm(config["hidden_size"])
        self.attn = Attention(config)
        self.norm2 = RMSNorm(config["hidden_size"])
        self.ffn = FFN(config)

    def forward(self, x, attention_mask=None):
        # Residual connections wrap each normalized sub-layer.
        attn_out = self.attn(self.norm1(x), attention_mask)
        hidden = x + attn_out
        return hidden + self.ffn(self.norm2(hidden))
149
+
150
+
151
class TinyLLM(nn.Module):
    """
    TinyLLM: a small decoder-only transformer language model.

    Submodule/buffer names (embed_tokens, layers, norm, lm_head, causal_mask)
    define the checkpoint's state-dict layout and must not be renamed.

    Args:
        config: Dictionary of model hyperparameters; defaults to MODEL_CONFIG.
    """
    def __init__(self, config: Dict[str, Any] = None):
        super().__init__()
        self.config = config or MODEL_CONFIG

        self.embed_tokens = nn.Embedding(self.config["vocab_size"], self.config["hidden_size"])
        self.layers = nn.ModuleList([
            TransformerBlock(self.config)
            for _ in range(self.config["num_layers"])
        ])
        self.norm = RMSNorm(self.config["hidden_size"])
        self.lm_head = nn.Linear(self.config["hidden_size"], self.config["vocab_size"], bias=False)

        # Share the input embedding matrix with the output projection.
        if self.config.get("tie_weights", True):
            self.lm_head.weight = self.embed_tokens.weight

        # Kept for state-dict compatibility with existing checkpoints;
        # Attention rebuilds its own causal mask each forward pass, so this
        # buffer is not read here.
        max_len = self.config["max_position_embeddings"]
        self.register_buffer("causal_mask",
            torch.triu(torch.ones(max_len, max_len, dtype=torch.bool), diagonal=1))

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the decoder stack.

        Args:
            input_ids: Token IDs [batch, seq].
            attention_mask: Optional additive mask in logit space, forwarded
                to every attention layer.
            labels: Optional [batch, seq] targets; positions set to -100 are
                ignored by the loss.

        Returns:
            Dict with "logits" [batch, seq, vocab] and "loss" (None when no
            labels are given).
        """
        x = self.embed_tokens(input_ids)

        for layer in self.layers:
            x = layer(x, attention_mask)

        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Next-token objective: position t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.config["vocab_size"]),
                shift_labels.view(-1),
                ignore_index=-100
            )

        return {"logits": logits, "loss": loss}

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 0.8,
        top_p: float = 0.9,
        top_k: int = 50,
        eos_token_id: Optional[int] = None,
        repetition_penalty: float = 1.0,
    ) -> torch.Tensor:
        """
        Generate text autoregressively.

        NOTE: per-step bookkeeping (repetition penalty, EOS check) inspects
        only batch row 0, so this method assumes batch size 1.

        Args:
            input_ids: Input token IDs [batch_size, seq_len]
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (higher = more random); must be > 0
            top_p: Nucleus sampling threshold
            top_k: Top-k sampling threshold (0 disables top-k filtering)
            eos_token_id: Token ID that signals end of generation
            repetition_penalty: >1.0 discourages tokens already in the context

        Returns:
            Generated token IDs including the prompt
        """
        self.eval()

        for _ in range(max_new_tokens):
            # Keep the context within the positional range, leaving one slot
            # free for the token appended below.
            if input_ids.size(1) >= self.config["max_position_embeddings"]:
                input_ids = input_ids[:, -self.config["max_position_embeddings"]+1:]

            outputs = self(input_ids)
            logits = outputs["logits"][:, -1, :]

            # Repetition penalty. BUGFIX: dividing a *negative* logit by a
            # penalty > 1 raises its probability — the opposite of the intent.
            # Negative logits are therefore multiplied instead, matching the
            # convention of HF transformers' RepetitionPenaltyLogitsProcessor.
            if repetition_penalty != 1.0:
                for token_id in set(input_ids[0].tolist()):
                    if logits[0, token_id] < 0:
                        logits[0, token_id] *= repetition_penalty
                    else:
                        logits[0, token_id] /= repetition_penalty

            # Apply temperature.
            logits = logits / temperature

            # Top-k filtering: keep only the k highest logits.
            if top_k > 0:
                indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                logits[indices_to_remove] = float('-inf')

            # Top-p (nucleus) filtering: drop the low-probability tail once
            # cumulative mass exceeds top_p; the best token is always kept.
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = float('-inf')

            # Sample the next token from the filtered distribution.
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)

            # Stop at end-of-sequence (batch size 1 assumed).
            if eos_token_id is not None and next_token.item() == eos_token_id:
                break

        return input_ids
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio==4.44.0
2
+ torch>=2.0.0
3
+ huggingface_hub>=0.24.0,<0.27.0
4
+ tokenizers>=0.15.0
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff