import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken
import gradio as gr
import asyncio


class GPTConfig:
    def __init__(self):
        self.block_size = 1024   # maximum context length
        self.vocab_size = 50304  # GPT-2's 50257 tokens, padded up for efficiency
        self.n_layer = 12
        self.n_head = 12
        self.n_embd = 768


class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # one projection produces queries, keys, and values in a single pass
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # causal mask buffer kept for checkpoint compatibility; at runtime,
        # scaled_dot_product_attention with is_causal=True handles masking
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble heads
        return self.c_proj(y)


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)

    def forward(self, x):
        return self.c_proj(self.gelu(self.c_fc(x)))


class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        # pre-norm residual connections
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # weight tying between the token embedding and the output projection
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)

        tok_emb = self.transformer.wte(idx)  # token embeddings (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # position embeddings (1, t, n_embd)
        x = tok_emb + pos_emb
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        return logits, loss


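# A minimal, hypothetical smoke test for the forward pass, kept as a comment
# so the app doesn't build a second model on import:
#
#   config = GPTConfig()
#   gpt = GPT(config)
#   logits, loss = gpt(torch.randint(0, config.vocab_size, (1, 8)))
#   print(logits.shape)  # torch.Size([1, 8, 50304])

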
def load_model(model_path):
    config = GPTConfig()
    model = GPT(config)
    try:
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    except FileNotFoundError:
        raise FileNotFoundError(f"Model file not found at: {model_path}")
    except Exception as e:
        raise RuntimeError(f"Error loading model: {e}") from e

    # accept either a full training checkpoint or a bare state dict
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)

    model.eval()
    return model


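# Module-level model setup used by generate_text below. The checkpoint
# filename is a placeholder assumption; point it at your actual weights file.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = load_model('model.pt')  # hypothetical path; adjust as needed
model.to(device)

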
def post_process_text(text):
    if not text:
        return text
    # capitalize the first character without lowercasing the rest
    text = text[0].upper() + text[1:]
    # drop a trailing sentence fragment, keeping only complete sentences
    sentences = text.split('.')
    complete_sentences = sentences[:-1] if len(sentences) > 1 else sentences
    processed_text = '.'.join(complete_sentences)
    if not processed_text.endswith('.'):
        processed_text += '.'
    return processed_text


async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
    # defensive casts: Gradio sliders may deliver floats
    max_length = int(max_length)
    top_k = int(top_k)
    enc = tiktoken.get_encoding('gpt2')
    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).to(device)
    generated = []

    with torch.no_grad():
        for _ in range(max_length):
            try:
                # crop the context so it never exceeds the model's block size
                idx_cond = input_ids if input_ids.size(1) <= model.config.block_size else input_ids[:, -model.config.block_size:]
                outputs, _ = model(idx_cond)
                next_token_logits = outputs[:, -1, :] / temperature
                # top-k sampling: restrict choices to the k most likely tokens
                top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
                next_token_probs = F.softmax(top_k_logits, dim=-1)
                next_token_index = torch.multinomial(next_token_probs, num_samples=1)
                next_token = top_k_indices.gather(-1, next_token_index)

                input_ids = torch.cat([input_ids, next_token], dim=-1)
                generated.append(next_token.item())

                next_token_str = enc.decode([next_token.item()])
                yield next_token_str

                # stop at a newline once the story is reasonably long
                if next_token.item() == enc.encode('\n')[0] and len(generated) > 100:
                    break

                await asyncio.sleep(0.02)  # brief pause so tokens stream smoothly

            except Exception as e:
                yield f"Error during generation: {e}"
                return


async def gradio_generate(prompt, max_length, temperature, top_k):
    output = ""
    async for token in generate_text(prompt, max_length, temperature, top_k):
        output += token
        yield output
    # final yield replaces the streamed text with its cleaned-up version
    output = post_process_text(output)
    yield output


# Gradio's `css` argument expects raw CSS, so no <style> wrapper
css = """
body {
    background-color: #0f1624;
    color: #e0e0e0;
    font-family: 'Courier New', monospace;
    background-image:
        radial-gradient(white, rgba(255,255,255,.2) 2px, transparent 40px),
        radial-gradient(white, rgba(255,255,255,.15) 1px, transparent 30px),
        radial-gradient(white, rgba(255,255,255,.1) 2px, transparent 40px),
        radial-gradient(rgba(255,255,255,.4), rgba(255,255,255,.1) 2px, transparent 30px);
    background-size: 550px 550px, 350px 350px, 250px 250px, 150px 150px;
    background-position: 0 0, 40px 60px, 130px 270px, 70px 100px;
    animation: backgroundScroll 60s linear infinite;
}
@keyframes backgroundScroll {
    0% { background-position: 0 0, 40px 60px, 130px 270px, 70px 100px; }
    100% { background-position: 550px 550px, 590px 610px, 680px 820px, 620px 650px; }
}
.container { max-width: 800px; margin: 0 auto; padding: 20px; }
.header {
    text-align: center;
    margin-bottom: 30px;
    font-family: 'Copperplate', fantasy;
    color: #ffd700;
    text-shadow: 0 0 10px #ffd700, 0 0 20px #ffd700, 0 0 30px #ffd700;
}
.chat-box {
    background-color: rgba(42, 42, 42, 0.7);
    border-radius: 15px;
    padding: 20px;
    margin-bottom: 20px;
    box-shadow: 0 0 20px rgba(255, 215, 0, 0.3);
}
.user-input {
    background-color: rgba(58, 58, 58, 0.8);
    border: 2px solid #ffd700;
    color: #ffffff;
    padding: 10px;
    border-radius: 5px;
    width: 100%;
    transition: all 0.3s ease;
}
.user-input:focus {
    box-shadow: 0 0 15px #ffd700;
}
.generate-btn {
    background-color: #ffd700;
    color: #0f1624;
    border: none;
    padding: 10px 20px;
    border-radius: 5px;
    cursor: pointer;
    font-weight: bold;
    transition: all 0.3s ease;
}
.generate-btn:hover {
    background-color: #ffec8b;
    transform: scale(1.05);
}
.output-box {
    background-color: rgba(42, 42, 42, 0.7);
    border-radius: 15px;
    padding: 20px;
    margin-top: 20px;
    min-height: 100px;
    border: 1px solid #ffd700;
    white-space: pre-wrap;
    font-family: 'Georgia', serif;
    line-height: 1.6;
    box-shadow: inset 0 0 10px rgba(255, 215, 0, 0.3);
}
.gr-slider {
    --slider-color: #ffd700;
}
.gr-box {
    border-color: #ffd700;
    background-color: rgba(42, 42, 42, 0.7);
}
"""


with gr.Blocks(css=css) as demo:
    gr.HTML("<div class='header'><h1>🌟 Enchanted Tales Generator 🌟</h1></div>")

    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(
                placeholder="Begin your magical journey here (e.g., 'In a realm beyond the mists of time...')",
                label="Story Incantation",
                elem_classes="user-input"
            )
        with gr.Column(scale=1):
            generate_btn = gr.Button("Weave the Tale", elem_classes="generate-btn")

    with gr.Row():
        max_length = gr.Slider(minimum=50, maximum=500, value=432, step=1, label="Scroll Length")
        temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Magical Intensity")
        top_k = gr.Slider(minimum=1, maximum=100, value=40, step=1, label="Arcane Diversity")

    output = gr.Markdown(elem_classes="output-box")

    generate_btn.click(
        gradio_generate,
        inputs=[prompt, max_length, temperature, top_k],
        outputs=output
    )

    gr.HTML("""
    <div style="text-align: center; margin-top: 20px; font-style: italic; color: #ffd700;">
        "In the realm of imagination, every word is a spell, every sentence a charm."
    </div>
    """)


if __name__ == "__main__":
    demo.launch()