import gc
import json
from collections import OrderedDict
from pathlib import Path

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel


class LRU:
    """Tiny LRU cache that frees GPU memory when evicting (model, tokenizer) pairs."""

    def __init__(self, capacity=2):
        self.cap = capacity
        self._data = OrderedDict()

    def get(self, k):
        if k in self._data:
            self._data.move_to_end(k)
            return self._data[k]
        return None

    def put(self, k, v):
        self._data[k] = v
        self._data.move_to_end(k)
        while len(self._data) > self.cap:
            # Evict the least-recently-used entry and release its GPU memory.
            _, (model, tok) = self._data.popitem(last=False)
            del model, tok
            torch.cuda.empty_cache()
            gc.collect()


class MikoEnsemble:
    def __init__(self, root_dir: str = "."):
        self.root = Path(root_dir)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Paths
        self.base_dir = self.root / "base_model"
        self.lora_dir = self.root / "lora_adapters"
        self.router_dir = self.root / "router"

        # Metadata
        with open(self.root / "README_METADATA.json", "r") as f:
            md = json.load(f)
        self.style_ids = md.get("style_ids", [])
        self.num_styles = md.get("num_styles", len(self.style_ids))

        # Router backbone (feature extractor)
        self.router_backbone_id = "Qwen/Qwen3-14B"
        self.router_model = AutoModelForCausalLM.from_pretrained(
            self.router_backbone_id,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.router_tokenizer = AutoTokenizer.from_pretrained(self.router_backbone_id)
        if self.router_tokenizer.pad_token is None:
            self.router_tokenizer.pad_token = self.router_tokenizer.eos_token
        self.hidden_size = self.router_model.config.hidden_size

        # Router head (projection + prototypes)
        state = torch.load(self.router_dir / "router_state.pt", map_location=self.device)
        self.style_projection = nn.Sequential(
            nn.Linear(self.hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, self.num_styles),
        ).to(self.device, dtype=torch.float32)
        self.style_projection.load_state_dict(state["style_projection_state"])
        self.style_projection.eval()  # disable dropout at inference time
        self.style_prototypes = state["style_prototypes"].to(self.device, dtype=torch.float32)

        # Generation cache (small)
        self.model_cache = LRU(capacity=2)

        # 4-bit quant config (standard QLoRA options from HF docs)
        self.bnb4 = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

    def _route(self, text: str, temperature: float = 0.7) -> int:
        """Pick a style id from prototype similarity plus MLP logits over a pooled embedding."""
        inp = self.router_tokenizer(
            text, return_tensors="pt", truncation=True, max_length=128, padding=True
        ).to(self.device)
        with torch.no_grad():
            outs = self.router_model(**inp, output_hidden_states=True)
            last = outs.hidden_states[-1]  # [B, T, H] (fp16)
            attn = inp["attention_mask"].unsqueeze(-1).to(last.dtype)
            # Mean-pool over non-padding tokens
            pooled = (last * attn).sum(1) / attn.sum(1)
            pooled = pooled.to(torch.float32)
            proto = torch.matmul(pooled, self.style_prototypes.T)  # [B, num_styles]
            logits = self.style_projection(pooled)                 # [B, num_styles]
            scores = proto + logits
        if temperature and temperature > 0:
            probs = torch.softmax(scores / temperature, dim=-1)
            sid = torch.multinomial(probs, 1).item()
        else:
            sid = torch.argmax(scores, dim=-1).item()
        return sid

    def _load_adapter_with_base(self, style_id: int):
        cache_key = f"style:{style_id}"
        cached = self.model_cache.get(cache_key)
        if cached:
            return cached

        adir = self.lora_dir / f"style_{style_id}_lora"
        if not adir.exists():
            # Fallback: use the base model only
            base = AutoModelForCausalLM.from_pretrained(
                str(self.base_dir),
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
            )
            tok = AutoTokenizer.from_pretrained(str(self.base_dir))
            if tok.pad_token is None:
                tok.pad_token = tok.eos_token
            self.model_cache.put(cache_key, (base, tok))
            return base, tok

        # Read adapter_config.json to get the true base model id
        with open(adir / "adapter_config.json", "r") as f:
            ac = json.load(f)
        base_id = ac.get("base_model_name_or_path", "").strip() or "Qwen/Qwen3-14B"
        # Hub ids contain "/"; anything else falls back to the local base_model dir
        base_src = base_id if "/" in base_id else str(self.base_dir)

        # Load the base in 4-bit when bitsandbytes is available; fall back to fp16
        try:
            base = AutoModelForCausalLM.from_pretrained(
                base_src,
                device_map="auto",
                quantization_config=self.bnb4,
                trust_remote_code=True,
            )
        except Exception:
            base = AutoModelForCausalLM.from_pretrained(
                base_src,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
            )
        tok = AutoTokenizer.from_pretrained(base_src)
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token

        # Attach the LoRA adapter
        model = PeftModel.from_pretrained(base, str(adir))
        model.eval()
        self.model_cache.put(cache_key, (model, tok))
        return model, tok

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 120,
        temperature: float = 0.8,
        router_temperature: float = 0.7,
    ):
        # Route the prompt to a style
        sid = self._route(prompt, router_temperature)

        # Style tag. NOTE: the concrete tag format below is a placeholder; the
        # marker the adapters were actually trained with is not recorded here,
        # so adjust it to match the training data.
        styled = f"[style:{sid}] {prompt}"

        model, tok = self._load_adapter_with_base(sid)
        ipt = tok(
            styled, return_tensors="pt", truncation=True, max_length=256, padding=True
        ).to(model.device)
        with torch.no_grad():
            out = model.generate(
                **ipt,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                top_p=0.95,
                pad_token_id=tok.pad_token_id,
                eos_token_id=tok.eos_token_id,
            )
        txt = tok.decode(out[0], skip_special_tokens=True)
        # Strip the echoed prompt (and any closing tag) from the decoded output
        txt = txt.replace(styled, "").replace(f"[/style:{sid}]", "").strip()
        return txt
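
# `__init__` above expects `router/router_state.pt` to contain exactly two keys:
# "style_projection_state" (the MLP head's state_dict) and "style_prototypes"
# (a [num_styles, hidden_size] tensor). Below is a minimal sketch of writing a
# compatible checkpoint; the head layout mirrors __init__, while the function
# name and the random prototypes are illustrative, not the actual training code.
def save_router_state_sketch(path, hidden_size, num_styles):
    head = nn.Sequential(
        nn.Linear(hidden_size, 512), nn.ReLU(), nn.Dropout(0.2),
        nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, num_styles),
    )
    torch.save(
        {
            "style_projection_state": head.state_dict(),
            "style_prototypes": torch.randn(num_styles, hidden_size),
        },
        path,
    )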

if __name__ == "__main__":
    ens = MikoEnsemble(".")
    for p in [
        "What’s your take on today’s price action?",
        "Another overhyped launch on CT. Thoughts?",
        "How would you reframe this narrative to be bullish?",
    ]:
        print("\nPrompt:", p)
        print("→", ens.generate(p))
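
# Routing can also be made deterministic: router_temperature=0.0 makes _route
# take the argmax over style scores instead of sampling, e.g.:
#   ens.generate("Another overhyped launch on CT. Thoughts?", router_temperature=0.0)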