| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import torch |
| import torch.nn.functional as F |
| import gradio as gr |
| from transformers import PreTrainedTokenizerFast |
| from safetensors.torch import load_file |
|
|
| |
| from ncn_architecture.config import NCNConfig |
| from ncn_architecture.model import ModulatedLLM |
|
|
| |
# Run on the GPU whenever PyTorch reports a usable CUDA device; otherwise
# fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Process-wide caches keyed by the model label passed to
# load_and_configure_model(), so each model/tokenizer pair is built only once.
MODEL_CACHE = {}
TOKENIZER_CACHE = {}
|
|
def load_and_configure_model(model_choice):
    """
    Build (or fetch from cache) the model and tokenizer for a catalogue entry.

    The NCNConfig hyperparameters are not read from a config file: they are
    reconstructed from tensor shapes found in the checkpoint so that the
    architecture matches the weights before the strict state-dict load.

    Args:
        model_choice: Catalogue label of the model to load. Only
            "NCN 2M (TinyStories)" is registered.

    Returns:
        A ``(model, tokenizer)`` tuple. The model has been moved to ``device``
        and switched to eval mode; both objects are cached under
        ``model_choice`` so later calls return immediately.

    Raises:
        ValueError: ``model_choice`` is not a registered profile.
        FileNotFoundError: The weights file or the tokenizer file is missing.
        KeyError: The checkpoint lacks one of the tensors whose shape is read.
    """
    if model_choice in MODEL_CACHE:
        return MODEL_CACHE[model_choice], TOKENIZER_CACHE[model_choice]

    if model_choice != "NCN 2M (TinyStories)":
        raise ValueError("Selected model profile is not registered.")

    weights_path = "models/ncn_2m_tinystories/model.safetensors"
    tokenizer_path = "models/ncn_2m_tinystories/tokenizer.json"
    if not (os.path.exists(weights_path) and os.path.exists(tokenizer_path)):
        raise FileNotFoundError("Model weight file or tokenizer file could not be located in the specified path.")

    state_dict = load_file(weights_path)

    # Sizes that can be read directly off individual weight tensors.
    vocab_size, d_model = state_dict["token_embeddings.weight"].shape
    max_position_embeddings = state_dict["position_embeddings.weight"].shape[0]
    dim_feedforward = state_dict["transformer_layers.0.feed_forward.linear1.weight"].shape[0]

    # Keys look like "transformer_layers.<index>.<...>"; the number of distinct
    # indices is the layer count.
    layer_indices = {
        int(key.split(".")[1])
        for key in state_dict
        if key.startswith("transformer_layers.")
    }
    num_layers = len(layer_indices) if layer_indices else 12

    # Invert bias_length = num_layers * (2 * nhead + 1).
    # NOTE(review): that layout of ncn.layer2 is implied by this formula only;
    # confirm against the NCN architecture. num_layers is never 0 here, so the
    # division cannot fail; a non-positive result means the layout assumption
    # does not hold and the default head count is kept.
    nhead = 12
    if "ncn.layer2.bias" in state_dict:
        bias_length = state_dict["ncn.layer2.bias"].shape[0]
        derived_nhead = int(((bias_length / num_layers) - 1) / 2)
        if derived_nhead >= 1:
            nhead = derived_nhead

    ncn_hidden_dim = 128
    if "ncn.layer1.weight" in state_dict:
        ncn_hidden_dim = state_dict["ncn.layer1.weight"].shape[0]

    config = NCNConfig(
        vocab_size=vocab_size,
        d_model=d_model,
        nhead=nhead,
        num_layers=num_layers,
        dim_feedforward=dim_feedforward,
        max_position_embeddings=max_position_embeddings,
        ncn_hidden_dim=ncn_hidden_dim,
    )

    model = ModulatedLLM(config)
    # strict=True: any key or shape mismatch between the reconstructed
    # architecture and the checkpoint surfaces here instead of at inference.
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    model.eval()

    tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
    if tokenizer.pad_token is None:
        # Reuse EOS as the padding token when the tokenizer file defines none.
        tokenizer.pad_token = tokenizer.eos_token

    MODEL_CACHE[model_choice] = model
    TOKENIZER_CACHE[model_choice] = tokenizer

    return model, tokenizer
|
|
|
|
def _apply_repetition_penalty(logits, token_ids, penalty):
    """Return a copy of ``logits`` in which every id in ``token_ids`` is penalised."""
    seen = logits.gather(1, token_ids)
    # Negative logits are multiplied and non-negative ones divided, so the
    # token becomes less likely in both cases. Duplicate ids all receive the
    # same value because each is computed from the unmodified logits.
    seen = torch.where(seen < 0, seen * penalty, seen / penalty)
    return logits.scatter(1, token_ids, seen)


def _sample_next_token(logits, temperature, top_p, top_k):
    """Pick one token id per row: argmax when greedy, else top-k/top-p sampling."""
    if temperature <= 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)

    logits = logits / temperature

    if top_k > 0:
        # torch.topk raises when k exceeds the vocabulary size, so clamp it.
        k = min(top_k, logits.size(-1))
        kth_best = torch.topk(logits, k)[0][..., -1:]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_remove = cumulative_probs > top_p
        # Shift right by one so the token that crosses the threshold is kept,
        # and always keep the single most likely token.
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
        sorted_remove[..., 0] = False
        # Map the mask from sorted order back to vocabulary order.
        remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
        logits = logits.masked_fill(remove, float("-inf"))

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


@torch.no_grad()
def generate_text(model_choice, prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
    """
    Autoregressively continue ``prompt`` with the selected model.

    The first forward pass consumes the whole prompt; later passes feed only
    the newest token together with the cached key/values and recurrent state
    returned by the previous pass.

    Args:
        model_choice: Catalogue label handed to ``load_and_configure_model``.
        prompt: Text to continue. A blank prompt returns a hint message.
        max_new_tokens: Upper bound on generated tokens (coerced to ``int``).
        temperature: Softmax temperature; any value <= 0 selects greedy decoding.
        top_p: Nucleus threshold; values >= 1.0 disable nucleus filtering.
        top_k: Keep only the ``top_k`` best logits (coerced to ``int`` and
            clamped to the vocabulary size); values <= 0 disable the filter.
        repetition_penalty: Penalty applied to every token already present in
            the sequence; 1.0 disables it.

    Returns:
        The decoded prompt plus continuation, or a human-readable message when
        the prompt is blank or the model cannot be loaded.
    """
    if not prompt.strip():
        return "Please enter a starting prompt to begin generating a story."

    try:
        model, tokenizer = load_and_configure_model(model_choice)
    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return f"Error loading model: {str(e)}"

    # UI sliders may deliver whole numbers as floats; range() and torch.topk
    # both require real integers.
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k)

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    generated_ids = input_ids.clone()

    past_key_values = None
    past_rnn_state = None

    # NOTE(review): nothing here limits prompt length + max_new_tokens to the
    # model's position-embedding table; confirm how ModulatedLLM handles
    # sequences longer than max_position_embeddings.
    for _ in range(max_new_tokens):
        if past_key_values is None:
            # No cache yet: run the full sequence.
            step_input, step_rnn_state = generated_ids, None
        else:
            # Cache available: only the newest token has to be processed.
            step_input, step_rnn_state = generated_ids[:, -1:], past_rnn_state

        outputs = model(
            input_ids=step_input,
            past_key_values=past_key_values,
            use_cache=True,
            past_rnn_state=step_rnn_state,
        )

        logits, past_key_values, _, past_rnn_state = outputs
        next_token_logits = logits[:, -1, :]

        if repetition_penalty != 1.0:
            next_token_logits = _apply_repetition_penalty(
                next_token_logits, generated_ids, repetition_penalty
            )

        next_token = _sample_next_token(next_token_logits, temperature, top_p, top_k)
        generated_ids = torch.cat([generated_ids, next_token], dim=-1)

        # Stop as soon as the model emits end-of-sequence.
        if next_token.item() == tokenizer.eos_token_id:
            break

    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
|
|
|
| |
|
|
# Custom styling: hide the default Gradio footer and colour the two action buttons.
css = """
footer {visibility: hidden}
.primary-btn {background-color: #5c4ff2 !important; color: white !important;}
.clear-btn {background-color: #374151 !important; color: white !important;}
"""


with gr.Blocks(title="Michael Morgan Model Catalogue", css=css) as demo:
    gr.Markdown("# Michael Morgan Model Catalogue")
    gr.Markdown("Select a model, enter a story starter, and adjust the generation settings.")

    with gr.Row():
        # Left column: model choice, prompt, sampling controls and action buttons.
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                choices=["NCN 2M (TinyStories)"],
                value="NCN 2M (TinyStories)",
                label="Model",
                interactive=True,
            )

            prompt_input = gr.Textbox(
                lines=5,
                placeholder="Type your story starter here...",
                label="Story starter",
            )

            # Sampling controls stay collapsed until the user opens them.
            with gr.Accordion("Generation settings", open=False):
                max_tokens = gr.Slider(
                    minimum=1,
                    maximum=512,
                    value=128,
                    step=1,
                    label="Max new tokens",
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.5,
                    value=0.7,
                    step=0.05,
                    label="Temperature (0 = greedy)",
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.8,
                    step=0.05,
                    label="Top-p",
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=25,
                    step=1,
                    label="Top-k",
                )
                rep_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=2.0,
                    value=1.1,
                    step=0.05,
                    label="Repetition penalty",
                )

            with gr.Row():
                clear_btn = gr.Button("Clear", elem_classes=["clear-btn"])
                generate_btn = gr.Button("Generate", variant="primary", elem_classes=["primary-btn"])

        # Right column: read-only box that receives the generated story.
        with gr.Column(scale=1):
            output_display = gr.Textbox(
                lines=12,
                placeholder="The generated story will appear here...",
                label="Generated story",
                interactive=False,
            )

    # Example prompts that populate the story-starter box.
    gr.Markdown("### Try these examples")
    gr.Examples(
        examples=[
            ["Once upon a time, there was a little dragon who"],
            ["Lily found a tiny wooden key buried in the sand box. She wondered what"],
            ["One sunny morning, a big friendly dog named Max decided to"],
            ["Tom had a bright yellow balloon. When he let go of the string, the balloon"],
        ],
        inputs=prompt_input,
    )

    gr.Markdown(
        "Model: SupraLabs/NCN-2M-TinyStories | License: Apache 2.0 | CPU-only | © 2026 Michael Morgan"
    )

    # The input order must match the positional parameters of generate_text.
    generate_btn.click(
        fn=generate_text,
        inputs=[model_dropdown, prompt_input, max_tokens, temperature, top_p, top_k, rep_penalty],
        outputs=output_display,
    )

    # Reset both the prompt box and the output box to empty strings.
    clear_btn.click(
        fn=lambda: ("", ""),
        inputs=None,
        outputs=[prompt_input, output_display],
    )


if __name__ == "__main__":
    demo.launch()