#!/usr/bin/env python3 """ AETHER-Net 0.8B — Inference Test Space Private 모델을 로드하여 텍스트 생성을 테스트합니다. HF Space: T4 GPU, HF_TOKEN secret 필요 Deploy: FINAL-Bench/aether-net-test """ import os import sys import time import json import torch import torch.nn.functional as F import gradio as gr from pathlib import Path from huggingface_hub import hf_hub_download, snapshot_download # ── Config ── MODEL_REPO = "FINAL-Bench/AETHER-Net-0.8B" DONOR_REPO = "Qwen/Qwen3.5-0.8B" # For tokenizer HF_TOKEN = os.getenv("HF_TOKEN") DEVICE = "cuda" if torch.cuda.is_available() else "cpu" print(f"Device: {DEVICE}") print(f"HF_TOKEN: {'set' if HF_TOKEN else 'NOT SET'}") # ── Download model weights from private repo ── print(f"Downloading AETHER-Net weights from {MODEL_REPO}...") model_dir = None try: model_dir = snapshot_download( MODEL_REPO, token=HF_TOKEN, allow_patterns=["model.safetensors", "config.json"], ) print(f" Model downloaded to: {model_dir}") except Exception as e: print(f" Download failed: {e}") # Source files are co-located in the same directory APP_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, APP_DIR) # ── Load model ── MODEL = None TOKENIZER = None def load_model(): global MODEL, TOKENIZER if MODEL is not None: return True # Load tokenizer from donor print("Loading tokenizer...") from transformers import AutoTokenizer try: TOKENIZER = AutoTokenizer.from_pretrained( DONOR_REPO, trust_remote_code=True, token=HF_TOKEN ) print(f" Tokenizer loaded: vocab_size={TOKENIZER.vocab_size}") except Exception as e: print(f" Tokenizer failed: {e}") return False # Load AETHER-Net print("Loading AETHER-Net model...") try: from config import AetherNetConfig from model import AetherNetModel # Load config config_path = Path(model_dir) / "config.json" if model_dir else None if config_path and config_path.exists(): with open(config_path) as f: cfg_dict = json.load(f) # Filter valid fields valid_fields = {k for k in AetherNetConfig.__dataclass_fields__} filtered = {k: v for k, v in cfg_dict.items() if k in valid_fields} config = AetherNetConfig(**filtered) print(f" Config loaded: hidden={config.hidden_size}, layers={config.num_layers}") else: print(" No config.json, using defaults") config = AetherNetConfig( hidden_size=1024, intermediate_size=3584, num_layers=25, num_attention_heads=16, num_kv_heads=2, head_dim=64, vocab_size=248320, max_position_embeddings=4096, expert_intermediate_size=716, overcome_gate_hidden=64, sliding_window_size=1024, gdn_state_size=64, mamba2_state_size=64, tie_word_embeddings=True, ) model = AetherNetModel(config) # Load weights weights_path = Path(model_dir) / "model.safetensors" if model_dir else None if weights_path and weights_path.exists(): from safetensors.torch import load_file state = load_file(str(weights_path), device="cpu") model.load_state_dict(state, strict=False) print(f" Weights loaded: {len(state)} tensors") else: print(" ⚠️ No weights found, using random init") model = model.to(DEVICE).eval() MODEL = model params = sum(p.numel() for p in model.parameters()) mem = params * 2 / 1e9 # BF16 estimate print(f" Model ready: {params:,} params (~{mem:.1f}GB)") return True except Exception as e: import traceback print(f" Model load failed: {e}") traceback.print_exc() return False # ── Generation ── @torch.no_grad() def generate(prompt, max_tokens=128, temperature=0.8, top_k=50, top_p=0.9): """Generate text from prompt.""" if MODEL is None: success = load_model() if not success: return "❌ Model failed to load. Check logs." # Tokenize input_ids = TOKENIZER.encode(prompt, return_tensors="pt").to(DEVICE) generated = input_ids.clone() t0 = time.time() for i in range(max_tokens): # Truncate to max position if generated.shape[1] > 4096: generated = generated[:, -4096:] outputs = MODEL(input_ids=generated) logits = outputs["logits"][:, -1, :] # Temperature if temperature > 0: logits = logits / temperature # Top-k if top_k > 0: values, _ = torch.topk(logits, top_k) min_val = values[:, -1].unsqueeze(-1) logits = torch.where(logits < min_val, torch.full_like(logits, -float('inf')), logits) # Top-p (nucleus) if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True) cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) mask = cum_probs - F.softmax(sorted_logits, dim=-1) > top_p sorted_logits[mask] = -float('inf') logits = sorted_logits.scatter(1, sorted_indices, sorted_logits) probs = F.softmax(logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) else: next_token = logits.argmax(dim=-1, keepdim=True) generated = torch.cat([generated, next_token], dim=-1) # EOS check if next_token.item() == TOKENIZER.eos_token_id: break elapsed = time.time() - t0 tokens_generated = generated.shape[1] - input_ids.shape[1] tps = tokens_generated / elapsed if elapsed > 0 else 0 output_text = TOKENIZER.decode(generated[0], skip_special_tokens=True) stats = f"\n\n---\n📊 {tokens_generated} tokens | {tps:.1f} tok/s | {elapsed:.2f}s" return output_text + stats def get_model_info(): """Return model architecture info.""" if MODEL is None: load_model() if MODEL is None: return "Model not loaded" info = "## AETHER-Net 0.8B — Architecture Info\n\n" info += f"| Item | Value |\n|---|---|\n" info += f"| Device | {DEVICE} |\n" info += f"| Parameters | {sum(p.numel() for p in MODEL.parameters()):,} |\n" info += f"| Layers | {len(MODEL.layers)} |\n" info += f"| Vocab | {MODEL.config.vocab_size:,} |\n" info += f"| Hidden | {MODEL.config.hidden_size} |\n" # Layer types from config import LAYER_TYPES, LAYER_TO_ELEMENT, ELEMENTS info += f"\n### Layer Map\n\n" info += "| Layer | Type | Element |\n|---|---|---|\n" for i in range(len(MODEL.layers)): lt = LAYER_TYPES[i] elem = LAYER_TO_ELEMENT[i] info += f"| {i} | {lt.upper()} | {elem} |\n" # Oheng status info += f"\n### Oheng Status\n\n" for elem in ELEMENTS: layers = [i for i in range(25) if LAYER_TO_ELEMENT[i] == elem] alphas = [] for li in layers: gb = MODEL.layers[li].moe.generate_boost if gb is not None: a = torch.sigmoid(gb.alpha).detach() eidx = ELEMENTS.index(elem) if eidx < a.shape[0]: alphas.append(a[eidx].item()) avg = sum(alphas) / len(alphas) if alphas else 0 info += f"- {elem}: α={avg:.4f}\n" return info # ── Gradio UI ── TITLE = """
Cross-Architecture Knowledge Distillation from Qwen3.5-0.8B
5×5 Magic Square | Oheng MoE | 5 Attention Types