import torch
import tiktoken
import json
from typing import Any, Dict, Optional, Tuple


# Model Architecture Classes

class Config:
    def __init__(self):
        self.vocab_size = 100283  # cl100k_base (100277 tokens) plus 6 custom special tokens
        self.max_position_embeddings = 1024
        self.hidden_size = 768
        self.num_layers = 6
        self.num_heads = 12
        self.intermediate_size = 3072
        self.dropout = 0.1


class AttentionHead(torch.nn.Module):
    """A single head of scaled dot-product attention."""

    def __init__(self, config: Config):
        super().__init__()
        self.head_dim = config.hidden_size // config.num_heads
        self.query = torch.nn.Linear(config.hidden_size, self.head_dim)
        self.key = torch.nn.Linear(config.hidden_size, self.head_dim)
        self.value = torch.nn.Linear(config.hidden_size, self.head_dim)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)
        scale = Q.size(-1) ** 0.5
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
        if mask is not None:
            # Positions where mask == 0 (future tokens) get -inf before softmax.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention = torch.nn.functional.softmax(scores, dim=-1)
        return torch.matmul(attention, V)


class MultiHeadAttention(torch.nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.heads = torch.nn.ModuleList([AttentionHead(config) for _ in range(config.num_heads)])
        self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = torch.nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Run each head independently, then concatenate and project back.
        heads = [head(x, mask) for head in self.heads]
        multihead = torch.cat(heads, dim=-1)
        return self.dropout(self.linear(multihead))


class TransformerBlock(torch.nn.Module):
    """Post-norm transformer block: attention and MLP, each with a residual connection."""

    def __init__(self, config: Config):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.norm1 = torch.nn.LayerNorm(config.hidden_size)
        self.norm2 = torch.nn.LayerNorm(config.hidden_size)
        self.feed_forward = torch.nn.Sequential(
            torch.nn.Linear(config.hidden_size, config.intermediate_size),
            torch.nn.GELU(),
            torch.nn.Linear(config.intermediate_size, config.hidden_size),
            torch.nn.Dropout(config.dropout)
        )

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        attended = self.attention(x, mask)
        x = self.norm1(x + attended)
        fed_forward = self.feed_forward(x)
        return self.norm2(x + fed_forward)


class SmallLanguageModel(torch.nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.token_embedding = torch.nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.transformer_blocks = torch.nn.ModuleList([TransformerBlock(config) for _ in range(config.num_layers)])
        self.dropout = torch.nn.Dropout(config.dropout)
        self.ln_f = torch.nn.LayerNorm(config.hidden_size)
        self.head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, torch.nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, torch.nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def get_causal_mask(self, size: int) -> torch.Tensor:
        # Lower-triangular boolean mask: True where attention is allowed.
        mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
        return ~mask

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        b, t = input_ids.size()
        positions = torch.arange(0, t, dtype=torch.long, device=input_ids.device)
        mask = self.get_causal_mask(t).to(input_ids.device)
        token_embeddings = self.token_embedding(input_ids)
        position_embeddings = self.position_embedding(positions)
        x = self.dropout(token_embeddings + position_embeddings)
        for block in self.transformer_blocks:
            x = block(x, mask)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits
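
# A minimal shape check for the architecture above (a sketch; `_smoke_test_model`
# is a hypothetical helper, not part of the original inference path). With the
# default Config, a (batch, seq) tensor of token ids should produce logits of
# shape (batch, seq, vocab_size). Call it manually when modifying the model.
def _smoke_test_model() -> None:
    cfg = Config()
    model = SmallLanguageModel(cfg)
    dummy_ids = torch.randint(0, cfg.vocab_size, (2, 16))
    logits = model(dummy_ids)
    assert logits.shape == (2, 16, cfg.vocab_size)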
# Text Generator Class

class TextGenerator:
    def __init__(self, model: SmallLanguageModel, tokenizer: tiktoken.Encoding):
        self.model = model
        self.model.eval()
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        max_length: int = 100,
        temperature: float = 0.7,
        top_k: int = 50,
        top_p: float = 0.9
    ) -> Dict[str, Any]:
        try:
            # allowed_special="all" lets the custom special tokens registered in
            # load_model_and_tokenizer appear verbatim in prompts.
            input_ids = torch.tensor(self.tokenizer.encode(
                prompt, allowed_special="all"
            )).unsqueeze(0).to(self.device)

            max_positions = self.model.config.max_position_embeddings
            for _ in range(max_length):
                # Keep the context within the positional-embedding window.
                if input_ids.size(1) > max_positions:
                    input_ids = input_ids[:, -max_positions:]

                logits = self.model(input_ids)
                next_token_logits = logits[:, -1, :] / temperature

                # Top-k filtering: keep only the k highest-scoring tokens.
                if top_k > 0:
                    values, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))
                    min_value = values[:, -1].unsqueeze(-1)
                    next_token_logits = next_token_logits.masked_fill(
                        next_token_logits < min_value, float('-inf')
                    )

                # Top-p (nucleus) filtering: keep the smallest set of tokens whose
                # cumulative probability exceeds top_p.
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(
                        torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
                    )
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the first token past the threshold is kept.
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices_to_remove.scatter(
                        1, sorted_indices, sorted_indices_to_remove
                    )
                    next_token_logits = next_token_logits.masked_fill(indices_to_remove, float('-inf'))

                probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                input_ids = torch.cat((input_ids, next_token), dim=1)

            generated_text = self.tokenizer.decode(input_ids[0].tolist())
            return {
                "status": "success",
                "generated_text": generated_text,
                "prompt": prompt,
                "max_length": max_length,
                "temperature": temperature,
                "top_k": top_k,
                "top_p": top_p
            }
        except Exception as e:
            return {
                "status": "error",
                "error_message": str(e),
                "prompt": prompt
            }


# Helper Function to Load Model and Tokenizer

def load_model_and_tokenizer(checkpoint_path: str, device: torch.device) -> Tuple[SmallLanguageModel, tiktoken.Encoding]:
    config = Config()
    cl100k_base = tiktoken.get_encoding("cl100k_base")
    tokenizer = tiktoken.Encoding(
        name="cl100k_xml",
        pat_str=cl100k_base._pat_str,
        mergeable_ranks=cl100k_base._mergeable_ranks,
        special_tokens={
            **cl100k_base._special_tokens,
            # Placeholder names: substitute the six XML-style special-token
            # strings your checkpoint was actually trained with.
            "<|special_0|>": 100277,
            "<|special_1|>": 100278,
            "<|special_2|>": 100279,
            "<|special_3|>": 100280,
            "<|special_4|>": 100281,
            "<|special_5|>": 100282
        }
    )
    config.vocab_size = tokenizer.n_vocab
    model = SmallLanguageModel(config)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    return model, tokenizer


# Main Function for Inference

def generate(
    checkpoint_path: str,
    prompt: str,
    max_length: int = 100,
    temperature: float = 0.7,
    top_k: int = 50,
    top_p: float = 0.9
) -> Dict[str, Any]:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model and tokenizer
    model, tokenizer = load_model_and_tokenizer(checkpoint_path, device)

    # Generate text
    generator = TextGenerator(model, tokenizer)
    result = generator.generate(
        prompt=prompt,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p
    )
    return result
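
# Example usage (a sketch: "checkpoint.pt" is a hypothetical path; point it at a
# checkpoint saved by your training loop as {"model_state_dict": model.state_dict()}).
if __name__ == "__main__":
    result = generate(
        checkpoint_path="checkpoint.pt",
        prompt="Once upon a time",
        max_length=50,
        temperature=0.8,
    )
    print(json.dumps(result, indent=2))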