import os, json, re, gc
from pathlib import Path
from collections import deque, defaultdict, OrderedDict

import torch
import torch.nn as nn
import numpy as np

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
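

# Small LRU cache holding (model, tokenizer) pairs so only a bounded number of
# loaded generators stay resident in GPU memory at any time.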
class LRU:
    def __init__(self, capacity=2):
        self.cap = capacity
        self._data = OrderedDict()

    def get(self, k):
        if k in self._data:
            self._data.move_to_end(k)
            return self._data[k]
        return None

    def put(self, k, v):
        self._data[k] = v
        self._data.move_to_end(k)
        while len(self._data) > self.cap:
            # Evict the least recently used entry and release its GPU memory.
            _, (m, t) = self._data.popitem(last=False)
            try:
                del m
                del t
            except Exception:
                pass
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
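

# Ensemble wrapper: a router picks a style id for each prompt, then the matching
# LoRA adapter (or the plain base model) generates the reply.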
class MikoEnsemble:
    def __init__(self, root_dir: str = "."):
        self.root = Path(root_dir)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.base_dir = self.root / "base_model"
        self.lora_dir = self.root / "lora_adapters"
        self.router_dir = self.root / "router"

        with open(self.root / "README_METADATA.json", "r") as f:
            md = json.load(f)
        self.style_ids = md.get("style_ids", [])
        self.num_styles = md.get("num_styles", len(self.style_ids))
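
        # Router backbone: a frozen causal LM whose pooled hidden states feed the style head.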
        self.router_backbone_id = "Qwen/Qwen3-14B"
        self.router_model = AutoModelForCausalLM.from_pretrained(
            self.router_backbone_id,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.router_tokenizer = AutoTokenizer.from_pretrained(self.router_backbone_id)
        if not self.router_tokenizer.pad_token:
            self.router_tokenizer.pad_token = self.router_tokenizer.eos_token
        self.hidden_size = self.router_model.config.hidden_size
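
        # Style head: prototype similarities plus an MLP projection, restored from the saved router state.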
        state = torch.load(self.router_dir / "router_state.pt", map_location=self.device)
        self.style_projection = nn.Sequential(
            nn.Linear(self.hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, self.num_styles),
        ).to(self.device, dtype=torch.float32)
        self.style_projection.load_state_dict(state["style_projection_state"])
        self.style_prototypes = state["style_prototypes"].to(self.device, dtype=torch.float32)
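
        # Keep at most two (model, tokenizer) pairs loaded to bound GPU memory.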
        self.model_cache = LRU(capacity=2)
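
        # 4-bit NF4 quantization used when loading adapter base models.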
        self.bnb4 = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
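
    # Routing: mean-pool the backbone's last hidden layer over non-padding tokens,
    # score styles as prototype similarity plus projection logits, then sample
    # (temperature > 0) or take the argmax.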
    def _route(self, text: str, temperature: float = 0.7) -> int:
        inp = self.router_tokenizer(
            text, return_tensors="pt", truncation=True, max_length=128, padding=True
        ).to(self.device)
        with torch.no_grad():
            outs = self.router_model(**inp, output_hidden_states=True)
        last = outs.hidden_states[-1]
        attn = inp["attention_mask"]
        pooled = (last * attn.unsqueeze(-1)).sum(1) / attn.sum(1, keepdim=True)
        pooled = pooled.to(torch.float32)
        proto = torch.matmul(pooled, self.style_prototypes.T)
        logits = self.style_projection(pooled)
        scores = proto + logits
        if temperature and temperature > 0:
            probs = torch.softmax(scores / temperature, dim=-1)
            sid = torch.multinomial(probs, 1).item()
        else:
            sid = torch.argmax(scores, dim=-1).item()
        return sid
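
    # Load the generator for a style: return a cached pair if present, otherwise
    # load the base model (4-bit when possible) and attach the style's LoRA adapter.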
    def _load_adapter_with_base(self, style_id: int):
        cache_key = f"style:{style_id}"
        cached = self.model_cache.get(cache_key)
        if cached:
            return cached

        adir = self.lora_dir / f"style_{style_id}_lora"
        if not adir.exists():
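            # No adapter for this style: fall back to the plain fp16 base model.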
            base = AutoModelForCausalLM.from_pretrained(
                str(self.base_dir),
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
            )
            tok = AutoTokenizer.from_pretrained(str(self.base_dir))
            self.model_cache.put(cache_key, (base, tok))
            return base, tok
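
        # Read the adapter config to find its base model: use the recorded hub id when
        # it looks like one ("org/name"), otherwise the local base_model directory.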
        with open(adir / "adapter_config.json", "r") as f:
            ac = json.load(f)
        base_id = ac.get("base_model_name_or_path", "").strip() or "Qwen/Qwen3-14B"
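
        # Prefer a 4-bit quantized load; fall back to fp16 if quantization is unavailable.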
        try:
            base = AutoModelForCausalLM.from_pretrained(
                base_id if "/" in base_id else str(self.base_dir),
                device_map="auto",
                quantization_config=self.bnb4,
                trust_remote_code=True,
            )
        except Exception:
            base = AutoModelForCausalLM.from_pretrained(
                base_id if "/" in base_id else str(self.base_dir),
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
            )

        tok = AutoTokenizer.from_pretrained(base_id if "/" in base_id else str(self.base_dir))
        if not tok.pad_token:
            tok.pad_token = tok.eos_token
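
        # Attach the style's LoRA adapter on top of the base model and cache the pair.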
        model = PeftModel.from_pretrained(base, str(adir))
        model.eval()

        self.model_cache.put(cache_key, (model, tok))
        return model, tok
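
    # Full pipeline for one prompt: route to a style, load that style's generator,
    # prefix the prompt with its style tag, and decode only the newly generated tokens.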
    def generate(self, prompt: str, max_new_tokens: int = 120, temperature: float = 0.8, router_temperature: float = 0.7):
        sid = self._route(prompt, router_temperature)

        # Style tag prefix, assumed to match the format the adapters were trained on.
        styled = f"<style_{sid}>{prompt}"
        model, tok = self._load_adapter_with_base(sid)

        ipt = tok(styled, return_tensors="pt", truncation=True, max_length=256, padding=True).to(model.device)
        with torch.no_grad():
            out = model.generate(
                **ipt,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                top_p=0.95,
                pad_token_id=tok.pad_token_id,
                eos_token_id=tok.eos_token_id,
            )
        # Decode only the tokens generated after the prompt, then strip any closing style tag.
        gen = out[0][ipt["input_ids"].shape[1]:]
        txt = tok.decode(gen, skip_special_tokens=True)
        txt = txt.replace(f"</style_{sid}>", "").strip()
        return txt
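

# Quick smoke test: route and generate for a few sample prompts.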
if __name__ == "__main__":
    ens = MikoEnsemble(".")
    for p in [
        "What’s your take on today’s price action?",
        "Another overhyped launch on CT. Thoughts?",
        "How would you reframe this narrative to be bullish?",
    ]:
        print("\nPrompt:", p)
        print("→", ens.generate(p))