| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import json |
| | import os |
| | import math |
| | import gradio as gr |
| | from transformers import GPT2Tokenizer |
| | from huggingface_hub import hf_hub_download |
| | from typing import Optional, List, Tuple |
| |
|
| | |
| | |
| | |
| | |
| |
|
class BTreeNode:
    """Placeholder for a B+ tree node.

    Nothing in this inference script builds a tree, so the constructor accepts
    the same arguments as a real node would and keeps none of them.
    """

    def __init__(self, order: int, is_leaf: bool = False, device: str = 'cuda'):
        ...  # intentionally empty: kept only so the interface still exists
| |
|
class BTreeAttentionIndex:
    """No-op stand-in for the B+ tree attention index.

    Inserted values are discarded and every range query comes back empty;
    the constructor arguments are accepted but not stored.
    """

    def __init__(self, order: int = 5, device: str = 'cuda'):
        ...

    def insert(self, key: int, value: torch.Tensor):
        """Discard ``value``; the stub keeps no entries."""
        return None

    def range_query(self, start: int, end: int) -> List[torch.Tensor]:
        """Return an empty list for any ``[start, end]`` range."""
        return list()
| |
|
class StandardSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        d_model: model width, split into ``n_heads`` heads of ``d_model // n_heads``.
        n_heads: number of attention heads.
        use_position_bias: add a learned ``(1, n_heads, block_size, block_size)``
            table to the attention scores (only when the sequence fits in it).
        block_size: size of that table; required when ``use_position_bias`` is True.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``use_position_bias`` is True but ``block_size`` is None.
    """

    def __init__(self, d_model, n_heads, use_position_bias=False, block_size=None, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        # Created in the order q, k, v, out (then the bias table) so random
        # initialisation consumes the RNG in the same sequence.
        self.q_proj, self.k_proj, self.v_proj, self.out_proj = (
            nn.Linear(d_model, d_model) for _ in range(4)
        )
        self.dropout = nn.Dropout(dropout)
        self.use_position_bias = use_position_bias
        if not use_position_bias:
            return
        if block_size is None:
            raise ValueError("block_size must be provided if use_position_bias is True")
        self.block_size = block_size
        self.position_bias = nn.Parameter(0.1 * torch.randn(1, n_heads, block_size, block_size))

    def forward(self, x, mask=None, bias=None, **kwargs):
        """Attend over ``x`` of shape (batch, seq, d_model); output has the same shape.

        ``mask`` entries equal to 0 are excluded from attention (broadcast
        against the (batch, heads, seq, seq) score tensor). ``bias`` is added
        to the raw scores. Extra keyword arguments are accepted and ignored.
        """
        bsz, length, _ = x.shape

        def split_heads(t):
            # (batch, seq, d_model) -> (batch, heads, seq, d_head)
            return t.view(bsz, length, self.n_heads, self.d_head).transpose(1, 2)

        q, k, v = (split_heads(proj(x)) for proj in (self.q_proj, self.k_proj, self.v_proj))
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_head)
        # The learned table only covers block_size positions; longer inputs skip it.
        if self.use_position_bias and length <= self.block_size:
            scores = scores + self.position_bias[:, :, :length, :length]
        if bias is not None:
            scores = scores + bias
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -float('inf'))
        attn = self.dropout(F.softmax(scores, dim=-1))
        merged = (attn @ v).transpose(1, 2).contiguous().view(bsz, length, self.d_model)
        return self.out_proj(merged)
| |
|
class OptimizedParallelHierarchicalAttention(nn.Module):
    """Block-local attention plus block-summary attention (inference mode).

    The input is right-padded to a multiple of ``block_size`` and cut into
    blocks. Tokens first attend within their own block (``local_attention``).
    Each block is then pooled into one summary vector, the summaries attend to
    one another (``summary_attention``), and every token adds its block's
    summary context to its local context.

    ``attention_index`` is a stub B+ tree index; ``forward`` does not use it.

    Args:
        d_model: model width.
        n_heads: number of heads in both attention stages.
        block_size: tokens per block.
        btree_order: order passed to the (stub) B+ tree index.
        dropout: dropout probability.
        device: device the sub-modules are created on.
    """

    def __init__(self, d_model, n_heads, block_size, btree_order=5, dropout=0.1, device='cuda'):
        super().__init__()
        self.d_model, self.n_heads, self.block_size, self.device = d_model, n_heads, block_size, device
        self.local_attention = StandardSelfAttention(d_model, n_heads, use_position_bias=True, block_size=block_size, dropout=dropout).to(device)
        self.summary_attention = StandardSelfAttention(d_model, n_heads, dropout=dropout).to(device)
        self.summarizer = nn.Linear(d_model, d_model).to(device)
        self.dropout = nn.Dropout(dropout).to(device)
        self.attention_index = BTreeAttentionIndex(btree_order, device=device)

    def forward(self, x):
        """Apply hierarchical attention.

        Args:
            x: tensor of shape (batch, seq_len, d_model).

        Returns:
            Tensor of shape (batch, seq_len, d_model). The padding added
            internally is stripped before returning.
        """
        batch_size, seq_len, d_model = x.shape
        block_size = self.block_size
        pad_len = (block_size - seq_len % block_size) % block_size
        if pad_len > 0:
            x = F.pad(x, (0, 0, 0, pad_len))
        padded_len = x.shape[1]
        num_blocks = padded_len // block_size

        # 1.0 for real tokens, 0.0 for the padding added above, laid out as
        # (batch, num_blocks, block_size). This is built on x's device rather
        # than self.device, so it follows the module if it is moved later.
        valid = (torch.arange(padded_len, device=x.device) < seq_len).to(x.dtype)
        valid = valid.view(1, num_blocks, block_size).expand(batch_size, -1, -1)

        # Local stage: padded positions are masked out as keys so they cannot
        # leak into the real tokens of the last block. Because pad_len < block_size,
        # every block holds at least one real token, so no row is fully masked.
        key_mask = valid.reshape(batch_size * num_blocks, 1, 1, block_size)
        x_blocks = x.reshape(batch_size * num_blocks, block_size, d_model)
        local_context = self.local_attention(x_blocks, mask=key_mask).view(batch_size, padded_len, d_model)
        local_context = self.dropout(local_context)

        # Summary stage: mean-pool each block over its real tokens. Fixed uniform
        # weights replace per-call random softmax weights. The random weights
        # made inference non-deterministic, and a uniform mean is their
        # expected value.
        block_view = local_context.view(batch_size, num_blocks, block_size, d_model)
        token_counts = valid.sum(dim=2, keepdim=True)
        pooled = (block_view * valid.unsqueeze(-1)).sum(dim=2) / token_counts
        summary_tokens = self.summarizer(pooled)
        summary_context = self.dropout(self.summary_attention(summary_tokens))

        # Broadcast each block's summary onto every position inside that block.
        summary_context_distributed = (
            summary_context.unsqueeze(2)
            .expand(-1, -1, block_size, -1)
            .reshape(batch_size, padded_len, d_model)
        )
        final_context = local_context + summary_context_distributed
        return final_context[:, :seq_len, :]
| |
|
class EnhancedTransformerBlock(nn.Module):
    """Post-norm transformer layer: hierarchical attention, then a GELU MLP.

    Each sub-layer is wrapped as ``LayerNorm(sublayer(x) + x)``. The MLP
    widens to ``2 * d_model`` and projects back to ``d_model``.
    """

    def __init__(self, d_model, n_heads, block_size, btree_order, dropout, device='cuda'):
        super().__init__()
        hidden = 2 * d_model
        self.attention = OptimizedParallelHierarchicalAttention(
            d_model, n_heads, block_size, btree_order, dropout, device
        )
        self.norm1 = nn.LayerNorm(d_model).to(device)
        self.norm2 = nn.LayerNorm(d_model).to(device)
        feed_forward = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, d_model),
        )
        self.ffn = feed_forward.to(device)

    def forward(self, x):
        """Return ``norm2(ffn(h) + h)`` where ``h = norm1(attention(x) + x)``."""
        attended = self.norm1(self.attention(x) + x)
        return self.norm2(self.ffn(attended) + attended)
| |
|
class OptimizedSimpleTransformer(nn.Module):
    """Language model over token ids.

    Structure: token embedding plus a learned positional table, a stack of
    ``EnhancedTransformerBlock`` layers, and a linear vocabulary head.

    Args:
        vocab_size: number of token ids (embedding rows / output logits).
        num_layers: number of ``EnhancedTransformerBlock`` layers.
        d_model, n_heads, block_size, btree_order, dropout: forwarded to every layer.
        device: device the sub-modules and positional table are created on.
        max_seq_len: rows in the learned positional table. This is the longest
            context ``forward`` accepts and the window ``generate`` keeps
            (default 512).
    """

    def __init__(self, vocab_size, num_layers, d_model, n_heads, block_size, btree_order, dropout, device='cuda', max_seq_len=512):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model).to(device)
        # Move the tensor *before* wrapping it in nn.Parameter. ``.to(device)`` on
        # a Parameter returns a plain Tensor whenever a copy is made (e.g.
        # CPU -> CUDA), and nn.Module silently does not register that. The
        # table would then be missing from parameters() and state_dict().
        # NOTE(review): a checkpoint saved from a model that lost the table this
        # way will not contain ``pos_encoding``; such a checkpoint would need
        # ``strict=False`` to load.
        self.pos_encoding = nn.Parameter((torch.randn(1, max_seq_len, d_model) * 0.1).to(device))
        self.layers = nn.ModuleList([EnhancedTransformerBlock(d_model, n_heads, block_size, btree_order, dropout, device) for _ in range(num_layers)])
        self.output_head = nn.Linear(d_model, vocab_size).to(device)
        self.device, self.vocab_size = device, vocab_size

    def forward(self, idx):
        """Map token ids ``idx`` (batch, seq) to logits (batch, seq, vocab_size)."""
        x = self.embedding(idx) + self.pos_encoding[:, :idx.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        return self.output_head(x)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, top_p=None, repetition_penalty=1.0):
        """Sample ``max_new_tokens`` tokens autoregressively and append them to ``idx``.

        Args:
            idx: (batch, seq) int64 tensor of prompt token ids.
            max_new_tokens: number of tokens to sample.
            temperature: positive divisor applied to the logits.
            top_k: keep only the ``top_k`` highest logits; ``None`` or ``<= 0`` disables.
            top_p: nucleus-sampling threshold; ``None`` or ``>= 1.0`` disables.
            repetition_penalty: CTRL-style penalty applied to every token already
                present in that row of ``idx``; ``1.0`` disables.

        Returns:
            ``idx`` with ``max_new_tokens`` sampled ids appended along dim 1.

        Raises:
            ValueError: if ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        self.eval()
        context_len = self.pos_encoding.size(1)
        for _ in range(max_new_tokens):
            # The positional table bounds the context, so keep only the newest tokens.
            idx_cond = idx[:, -context_len:]
            logits = self(idx_cond)[:, -1, :] / temperature
            if repetition_penalty != 1.0:
                # Push already-seen tokens toward lower probability for every row.
                # Negative logits must be multiplied: dividing them would raise them.
                seen = logits.gather(1, idx)
                seen = torch.where(seen < 0, seen * repetition_penalty, seen / repetition_penalty)
                logits.scatter_(1, idx, seen)
            if top_k is not None and top_k > 0:
                kth_largest = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
                logits = logits.masked_fill(logits < kth_largest, -float('inf'))
            if top_p is not None and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the token that crosses the threshold is kept,
                # and always keep the most likely token.
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits = logits.masked_fill(indices_to_remove, -float('inf'))
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
| |
|
| |
|
| | |
| | |
| | |
| | |
| |
|
REPO_ID = "TheVixhal/crek"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Both stay None unless *every* loading step below succeeds. generate_response
# checks for None, so a half-loaded (randomly initialised) model is never served.
model = None
tokenizer = None

print(f"✅ Using device: {device}")
print(f"⏳ Loading model from Hugging Face Hub: '{REPO_ID}'...")

try:
    # Architecture hyper-parameters are read from config.json in the repo.
    config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    _tokenizer = GPT2Tokenizer.from_pretrained(REPO_ID)
    # Reuse the EOS token as the pad token.
    _tokenizer.pad_token = _tokenizer.eos_token

    _model = OptimizedSimpleTransformer(
        vocab_size=config['vocab_size'], num_layers=config['num_layers'],
        d_model=config['d_model'], n_heads=config['n_heads'],
        block_size=config['block_size'], btree_order=config['btree_order'],
        dropout=config['dropout'], device=device
    )

    # The checkpoint is a downloaded file, so refuse to unpickle arbitrary
    # objects: a state_dict only needs tensors (weights_only requires torch >= 1.13).
    model_weights_path = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model.bin")
    _model.load_state_dict(torch.load(model_weights_path, map_location=device, weights_only=True))
    _model.to(device)
    _model.eval()
except Exception as e:
    # Deliberate top-level boundary: keep the app running and let the UI
    # report that the model is unavailable.
    print(f"❌ An error occurred during model loading: {e}")
else:
    # Publish the pair only once it is fully initialised.
    model, tokenizer = _model, _tokenizer
    print("✅ Model loaded successfully!")
| |
|
| | |
| | |
| | |
| |
|
def generate_response(prompt, max_new_tokens=100, temperature=0.7, top_k=50, top_p=0.9, repetition_penalty=1.5):
    """Gradio callback: continue ``prompt`` with the globally loaded model.

    Args:
        prompt: user text to continue.
        max_new_tokens: number of tokens to sample (cast to ``int``, since
            slider values may arrive as floats).
        temperature, top_k, top_p, repetition_penalty: sampling controls
            forwarded to ``model.generate`` (``top_k`` is cast to ``int``).

    Returns:
        The decoded prompt plus continuation with special tokens removed, or a
        short message when the model is not loaded or the prompt is empty.
    """
    if model is None or tokenizer is None:
        return "Model not loaded. Please check the console logs for errors."

    empty_prompt_message = "Please enter a prompt to continue."
    if not prompt:
        return empty_prompt_message

    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    # The model predicts from the last position (logits[:, -1, :]), so it needs
    # at least one token; report that instead of raising an IndexError.
    if input_ids.shape[-1] == 0:
        return empty_prompt_message

    output_ids = model.generate(
        idx=input_ids,
        max_new_tokens=int(max_new_tokens),
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        repetition_penalty=repetition_penalty
    )

    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
| |
|
| |
|
| | |
| | |
| | |
| |
|
if __name__ == "__main__":
    # Every example row shares the same slider settings; only the prompt differs.
    example_settings = [10, 0.5, 50, 0.9, 1.5]
    example_prompts = ["hey what's up", "how are you?", "i am feeling low"]

    description = (
        "An interface for the custom-trained base model 'crek' by TheVixhal. "
        "This model predicts text based on patterns from its training data. "
        "Adjust the sliders to control the output's creativity and coherence."
    )

    # Values are handed to generate_response in this order:
    # prompt, max_new_tokens, temperature, top_k, top_p, repetition_penalty.
    input_components = [
        gr.Textbox(lines=3, label="Your Prompt", placeholder="Enter your text here..."),
        gr.Slider(minimum=10, maximum=500, value=10, step=10, label="Max New Tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.5, step=0.1, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-K"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-P (Nucleus Sampling)"),
        gr.Slider(minimum=1.0, maximum=3.0, value=1.5, step=0.1, label="Repetition Penalty"),
    ]

    iface = gr.Interface(
        fn=generate_response,
        inputs=input_components,
        outputs=gr.Textbox(lines=10, label="Generated Text"),
        title="🤖 Crek: A B-Tree Transformer Chatbot",
        description=description,
        allow_flagging="never",
        examples=[[text, *example_settings] for text in example_prompts],
    )

    print("\n🚀 Launching Gradio Interface...")
    print("Open the public URL in your browser to interact with the model.")
    iface.launch(share=True)