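"""Miko ensemble inference: a style router plus per-style LoRA adapters.

A Qwen3-14B backbone embeds the prompt; a small projection head and a set of
style prototypes score the embedding against the known styles. The winning
style's LoRA adapter is attached to a (preferably 4-bit) base model for
generation, and a tiny LRU cache keeps the most recently used
(model, tokenizer) pairs loaded.

Expected repository layout (inferred from the paths used below):
    README_METADATA.json
    base_model/
    lora_adapters/style_<id>_lora/
    router/router_state.pt
"""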
import gc
import json
from collections import OrderedDict
from pathlib import Path

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
class LRU:
    """Tiny LRU cache for (model, tokenizer) pairs; evictions free GPU memory."""

    def __init__(self, capacity=2):
        self.cap = capacity
        self._data = OrderedDict()

    def get(self, k):
        if k in self._data:
            self._data.move_to_end(k)
            return self._data[k]
        return None

    def put(self, k, v):
        self._data[k] = v
        self._data.move_to_end(k)
        while len(self._data) > self.cap:
            # Evict the least recently used entry and release its memory.
            _, (model, tok) = self._data.popitem(last=False)
            del model, tok
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
class MikoEnsemble:
    def __init__(self, root_dir: str = "."):
        self.root = Path(root_dir)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Paths
        self.base_dir = self.root / "base_model"
        self.lora_dir = self.root / "lora_adapters"
        self.router_dir = self.root / "router"
        # Metadata
        with open(self.root / "README_METADATA.json", "r") as f:
            md = json.load(f)
        self.style_ids = md.get("style_ids", [])
        self.num_styles = md.get("num_styles", len(self.style_ids))
        # Router backbone (feature extractor)
        self.router_backbone_id = "Qwen/Qwen3-14B"
        self.router_model = AutoModelForCausalLM.from_pretrained(
            self.router_backbone_id,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.router_tokenizer = AutoTokenizer.from_pretrained(self.router_backbone_id)
        if self.router_tokenizer.pad_token is None:
            self.router_tokenizer.pad_token = self.router_tokenizer.eos_token
        self.hidden_size = self.router_model.config.hidden_size
        # Router head (projection + prototypes)
        state = torch.load(self.router_dir / "router_state.pt", map_location=self.device)
        self.style_projection = nn.Sequential(
            nn.Linear(self.hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, self.num_styles),
        ).to(self.device, dtype=torch.float32)
        self.style_projection.load_state_dict(state["style_projection_state"])
        self.style_prototypes = state["style_prototypes"].to(self.device, dtype=torch.float32)
        # Generation cache (small)
        self.model_cache = LRU(capacity=2)
        # 4-bit quant config (standard QLoRA options from HF docs)
        self.bnb4 = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    def _route(self, text: str, temperature: float = 0.7) -> int:
        """Score the prompt against the style prototypes and pick a style id."""
        inp = self.router_tokenizer(
            text, return_tensors="pt", truncation=True, max_length=128, padding=True
        ).to(self.device)
        with torch.no_grad():
            outs = self.router_model(**inp, output_hidden_states=True)
        last = outs.hidden_states[-1]  # [B, T, H] (fp16)
        # Mean-pool over non-padding tokens.
        attn = inp["attention_mask"]
        pooled = (last * attn.unsqueeze(-1)).sum(1) / attn.sum(1, keepdim=True)
        pooled = pooled.to(torch.float32)
        # Combine prototype similarity with the projection head's logits.
        proto = torch.matmul(pooled, self.style_prototypes.T)
        logits = self.style_projection(pooled)
        scores = proto + logits
        if temperature and temperature > 0:
            # Sample a style for variety; temperature <= 0 falls back to argmax.
            probs = torch.softmax(scores / temperature, dim=-1)
            sid = torch.multinomial(probs, 1).item()
        else:
            sid = torch.argmax(scores, dim=-1).item()
        return sid
    def _load_adapter_with_base(self, style_id: int):
        cache_key = f"style:{style_id}"
        cached = self.model_cache.get(cache_key)
        if cached is not None:
            return cached
        adir = self.lora_dir / f"style_{style_id}_lora"
        if not adir.exists():
            # Fallback: no adapter for this style, use the base model only.
            base = AutoModelForCausalLM.from_pretrained(
                str(self.base_dir),
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
            )
            tok = AutoTokenizer.from_pretrained(str(self.base_dir))
            self.model_cache.put(cache_key, (base, tok))
            return base, tok
        # Read adapter_config.json to get the true base model id.
        with open(adir / "adapter_config.json", "r") as f:
            ac = json.load(f)
        base_id = (ac.get("base_model_name_or_path") or "").strip() or "Qwen/Qwen3-14B"
        # Hub ids contain a "/"; anything else resolves to the local base_model dir.
        base_src = base_id if "/" in base_id else str(self.base_dir)
        # Load the base in 4-bit when bitsandbytes is available; fall back to fp16.
        try:
            base = AutoModelForCausalLM.from_pretrained(
                base_src,
                device_map="auto",
                quantization_config=self.bnb4,
                trust_remote_code=True,
            )
        except Exception:
            base = AutoModelForCausalLM.from_pretrained(
                base_src,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
            )
        tok = AutoTokenizer.from_pretrained(base_src)
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        # Attach the style's LoRA adapter.
        model = PeftModel.from_pretrained(base, str(adir))
        model.eval()
        self.model_cache.put(cache_key, (model, tok))
        return model, tok
    def generate(self, prompt: str, max_new_tokens: int = 120, temperature: float = 0.8, router_temperature: float = 0.7):
        # Route the prompt to a style, then prepend that style's tag.
        sid = self._route(prompt, router_temperature)
        styled = f"<style_{sid}>{prompt}"
        model, tok = self._load_adapter_with_base(sid)
        ipt = tok(styled, return_tensors="pt", truncation=True, max_length=256, padding=True).to(model.device)
        with torch.no_grad():
            out = model.generate(
                **ipt,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                top_p=0.95,
                pad_token_id=tok.pad_token_id,
                eos_token_id=tok.eos_token_id,
            )
        # Decode only the newly generated tokens so the prompt never leaks into
        # the output (string-replacing the re-decoded prompt is fragile).
        new_tokens = out[0][ipt["input_ids"].shape[1]:]
        txt = tok.decode(new_tokens, skip_special_tokens=True)
        return txt.replace(f"</style_{sid}>", "").strip()
if __name__ == "__main__":
    ens = MikoEnsemble(".")
    for p in [
        "What’s your take on today’s price action?",
        "Another overhyped launch on CT. Thoughts?",
        "How would you reframe this narrative to be bullish?",
    ]:
        print("\nPrompt:", p)
        print("→", ens.generate(p))