Delete finlora_hf_submission/trytry1.py
Browse files- finlora_hf_submission/trytry1.py +0 -208
finlora_hf_submission/trytry1.py
DELETED
|
@@ -1,208 +0,0 @@
|
|
| 1 |
-
# ===== FinLoRA evaluation on LLaMA-3.1-8B (LoRA adapters, 8-bit base) | JSONL inputs =====
|
| 2 |
-
import os, gc, psutil, json, torch, torch.nn as nn
|
| 3 |
-
from typing import List, Tuple
|
| 4 |
-
from sklearn.metrics import accuracy_score, f1_score
|
| 5 |
-
|
| 6 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
| 7 |
-
from peft import PeftModel
|
| 8 |
-
|
| 9 |
-
# --------- CONFIG ----------
|
| 10 |
-
# Device used for the classification heads and the tokenized inputs.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Use the SAME local LLaMA snapshot you trained with
BASE_DIR = "d04e592bb4f6aa9cfee91e2e20afa771667e1d4b"
ADAPTER_DIR = "finlora_lora_ckpt_llama_8bit_r8"  # from training
HEADS_PATH = "finlora_heads_llama_8bit_r8.pt"  # from training

# Your JSONL eval files
EVAL_FILES = ["fiqa_test.jsonl", "fpb_test.jsonl"]

# Tokenization / eval params
MAXLEN = 256
INIT_BATCH = 64  # will auto-shrink on OOM
|
| 23 |
-
|
| 24 |
-
# ---------------- Memory helpers ----------------
|
| 25 |
-
def print_mem(tag: str = ""):
    """Print a one-line snapshot of host and GPU memory usage.

    Args:
        tag: Free-form label included in the output line to identify the
            point in the script at which the snapshot was taken.
    """
    v = psutil.virtual_memory()
    cpu = f"CPU used: {(v.total - v.available)/1e9:.1f}/{v.total/1e9:.1f} GB"
    if torch.cuda.is_available():
        # mem_get_info() returns (free, total) in bytes for the current device.
        free, total = torch.cuda.mem_get_info()
        gpu = f"GPU used: {(total - free)/1e9:.1f}/{total/1e9:.1f} GB"
    else:
        gpu = "GPU: n/a"
    print(f"[MEM] {tag} | {cpu} | {gpu}")
|
| 34 |
-
|
| 35 |
-
def memory_guard():
    """Release unreferenced Python objects and cached CUDA memory.

    Runs the Python garbage collector and, when CUDA is available, asks
    PyTorch to drop its cached allocator blocks and collect IPC handles.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
|
| 40 |
-
|
| 41 |
-
# ---------------- Label/text helpers ----------------
|
| 42 |
-
# Maps raw label spellings (string or int) to the 3-way class index:
# 0 = negative, 1 = neutral, 2 = positive.
LBL_MAP_3 = {
    "-1": 0, "neg": 0, "negative": 0, -1: 0,
    "0": 1, "neu": 1, "neutral": 1, 0: 1,
    "1": 2, "pos": 2, "positive": 2, 1: 2,
}
# Candidate field names, tried in order, for the text and label of a JSONL row.
TEXT_KEYS = ["context", "text", "sentence", "content", "Title", "question_title", "Input", "review"]
LABEL_KEYS = ["label", "sentiment", "Sentiment", "class", "target", "y"]
|
| 49 |
-
|
| 50 |
-
def _find_key(d: dict, candidates: List[str]) -> str:
|
| 51 |
-
keys_lower = {k.lower(): k for k in d.keys()}
|
| 52 |
-
for c in candidates:
|
| 53 |
-
if c in d: return c
|
| 54 |
-
if c.lower() in keys_lower: return keys_lower[c.lower()]
|
| 55 |
-
return None
|
| 56 |
-
|
| 57 |
-
def _norm_label(v) -> int:
    """Normalize a raw label value to a 3-way class index.

    The value is stringified, stripped and lower-cased, then looked up in
    ``LBL_MAP_3``. Numeric spellings that are not literal map keys (for
    example ``1.0``, ``"+1"`` or ``"01"``) are parsed and mapped through their
    integer value, so a float-typed label is not silently read as neutral.

    Args:
        v: Raw label from a data row (string, number or ``None``).

    Returns:
        The mapped class index from ``LBL_MAP_3``; ``1`` (the neutral index)
        for ``None`` and for anything that cannot be mapped.
    """
    if v is None:
        return 1
    s = str(v).strip().lower()
    if s in LBL_MAP_3:
        return LBL_MAP_3[s]
    try:
        num = float(s)
    except ValueError:
        return 1
    # is_integer() is False for nan/inf and for fractional scores, which all
    # keep the neutral fallback.
    if num.is_integer():
        return LBL_MAP_3.get(int(num), 1)
    return 1
|
| 65 |
-
|
| 66 |
-
def load_eval_jsonl(path: str) -> Tuple[List[str], List[int]]:
    """Load (text, label) pairs from a JSONL evaluation file.

    Each non-blank line is parsed as JSON. The text and label fields are
    located with ``_find_key`` using ``TEXT_KEYS`` / ``LABEL_KEYS`` and then a
    short list of extra field names. Lines that are not valid JSON, are not
    JSON objects, or lack a recognizable text or label field are skipped; the
    number of skipped lines is printed so dropped data is visible.

    Args:
        path: Path to the JSONL file.

    Returns:
        A ``(texts, labels)`` tuple of equal-length lists, where labels have
        been normalized by ``_norm_label``.

    Raises:
        FileNotFoundError: If *path* does not exist.
        ValueError: If no usable row was found in the file.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Eval file not found: {path}")
    texts, labels = [], []
    skipped = 0
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                ex = json.loads(line)
            except ValueError:  # json.JSONDecodeError is a ValueError
                skipped += 1
                continue
            if not isinstance(ex, dict):
                # A bare list/number/string row has no fields to search.
                skipped += 1
                continue
            # try a couple more common fields when the primary lists miss
            t_key = _find_key(ex, TEXT_KEYS) or _find_key(ex, ["Sentence", "question", "title"])
            y_key = _find_key(ex, LABEL_KEYS) or _find_key(ex, ["Label", "SentimentLabel"])
            if t_key is None or y_key is None:
                skipped += 1
                continue
            texts.append(str(ex.get(t_key, "")))
            labels.append(_norm_label(ex.get(y_key, None)))
    if skipped:
        print(f"[WARN] {path}: skipped {skipped} malformed or unrecognized rows")
    if not texts:
        raise ValueError(f"No (text,label) rows found in {path}. Check field names.")
    return texts, labels
|
| 91 |
-
|
| 92 |
-
# ---------------- Load LLaMA base + tokenizer (8-bit) ----------------
print_mem("before load")

# NOTE(review): trust_remote_code=True allows Python code shipped with the
# checkpoint to run; presumably unnecessary for a stock LLaMA snapshot —
# confirm before removing.
tok = AutoTokenizer.from_pretrained(BASE_DIR, use_fast=True, trust_remote_code=True)
if tok.pad_token_id is None:
    # Reuse EOS as the padding token when the tokenizer defines none.
    tok.pad_token = tok.eos_token
tok.padding_side = "left"

bnb = BitsAndBytesConfig(
    load_in_8bit=True,
)
base = AutoModelForCausalLM.from_pretrained(
    BASE_DIR,
    quantization_config=bnb,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="auto",
    trust_remote_code=True,
)
# No generation happens here, so the KV cache is disabled.
base.config.use_cache = False

print_mem("after base load")

# ---------------- Attach LoRA adapters ----------------
enc = PeftModel.from_pretrained(base, ADAPTER_DIR)
enc.eval()
print_mem("after PEFT attach")

# ---------------- Rebuild heads & load (256-d proj, 3-way cls) ----------------
hid = enc.config.hidden_size  # LLaMA-3.1-8B -> 4096
proj = nn.Sequential(nn.Linear(hid, hid), nn.Tanh(), nn.Linear(hid, 256)).to(DEVICE).eval()
cls = nn.Linear(hid, 3).to(DEVICE).eval()

# NOTE(review): torch.load unpickles the file; only load HEADS_PATH from a
# trusted source (or pass weights_only=True on a torch version that has it).
state = torch.load(HEADS_PATH, map_location="cpu")
# strict=True makes a missing key or a shape mismatch fail loudly here.
_ = proj.load_state_dict(state["proj"], strict=True)
_ = cls.load_state_dict(state["cls"], strict=True)
|
| 129 |
-
|
| 130 |
-
# ---------------- Pooling over LLaMA hidden states ----------------
|
| 131 |
-
@torch.no_grad()
|
| 132 |
-
def _mean_pool(last_hidden_state: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
|
| 133 |
-
mask = attn_mask.unsqueeze(-1).type_as(last_hidden_state) # [B,T,1]
|
| 134 |
-
summed = (last_hidden_state * mask).sum(dim=1) # [B,H]
|
| 135 |
-
denom = mask.sum(dim=1).clamp(min=1e-6) # [B,1]
|
| 136 |
-
return summed / denom
|
| 137 |
-
|
| 138 |
-
# make sure your tokenizer has a pad token & left padding for LLaMA
|
| 139 |
-
if tok.pad_token_id is None:
|
| 140 |
-
tok.pad_token = tok.eos_token
|
| 141 |
-
tok.padding_side = "left"
|
| 142 |
-
|
| 143 |
-
def _mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
| 144 |
-
mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
|
| 145 |
-
summed = (last_hidden_state * mask).sum(dim=1)
|
| 146 |
-
denom = mask.sum(dim=1).clamp(min=1e-6)
|
| 147 |
-
return summed / denom
|
| 148 |
-
|
| 149 |
-
@torch.inference_mode()
def encode_cls(batch):
    """Encode a tokenized batch into one pooled feature vector per example.

    Args:
        batch: Mapping of tokenizer output tensors; must include
            ``"attention_mask"``. Every tensor is moved to ``DEVICE``.

    Returns:
        The mean-pooled top-layer hidden states, one row per example.

    Raises:
        RuntimeError: If the model output carries no hidden states. Falling
            back to the first output element would pool the causal LM's
            logits, which the classifier head cannot consume.
    """
    batch = {k: v.to(DEVICE, non_blocking=True) for k, v in batch.items()}
    # ask the model to return hidden states
    out = enc(**batch, output_hidden_states=True)
    hidden_states = getattr(out, "hidden_states", None)
    if hidden_states is None:
        raise RuntimeError(
            "Model output has no hidden_states; cannot pool features "
            "(was the model called with return_dict=False?)"
        )
    # for causal LM, take the top hidden layer
    return _mean_pool(hidden_states[-1], batch["attention_mask"])
|
| 158 |
-
|
| 159 |
-
@torch.inference_mode()
def logits_for_texts(texts, maxlen=MAXLEN):
    """Compute classifier logits for a list of raw texts.

    Args:
        texts: Input strings, tokenized together as one padded batch.
        maxlen: Maximum token length; longer inputs are truncated.

    Returns:
        The output of the ``cls`` head for the batch, one row per text.
    """
    encd = tok(texts, padding=True, truncation=True, max_length=maxlen, return_tensors="pt")
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=torch.cuda.is_available()):
        h = encode_cls(encd)
        # The head is built in default precision; cast the pooled features to
        # its dtype so the call also works when autocast is disabled.
        return cls(h.to(cls.weight.dtype))
|
| 165 |
-
|
| 166 |
-
# ---------------- OOM-safe evaluation ----------------
|
| 167 |
-
def evaluate_set(texts: List[str], labels: List[int], batch: int = INIT_BATCH, maxlen: int = MAXLEN):
    """Predict a class for every text and score the predictions.

    Texts are processed in slices of at most *batch* items. When a slice
    fails with an out-of-memory error the slice size is halved and the same
    slice is retried; the reduced size is kept for the remaining slices.

    Args:
        texts: Input strings to classify.
        labels: Gold class indices aligned with *texts*.
        batch: Initial slice size.
        maxlen: Maximum token length passed to ``logits_for_texts``.

    Returns:
        A dict with ``"accuracy"`` and ``"macro_f1"``.

    Raises:
        RuntimeError: Any non-OOM runtime error, or an OOM error that still
            occurs with a slice size of 1.
    """
    # Older torch builds lack this class; an empty tuple makes isinstance False.
    oom_type = getattr(torch.cuda, "OutOfMemoryError", ())
    preds = []
    i, n = 0, len(texts)
    while i < n:
        cur_bs = min(batch, n - i)
        while True:
            try:
                batch_logits = logits_for_texts(texts[i:i + cur_bs], maxlen=maxlen)
                preds.extend(batch_logits.argmax(dim=1).cpu().tolist())
                break
            except RuntimeError as err:
                is_oom = isinstance(err, oom_type) or "out of memory" in str(err).lower()
                if not is_oom or cur_bs <= 1:
                    raise
            # Recover outside the except clause: while the handler is active
            # the exception's traceback keeps the failed tensors alive, so
            # freeing memory there would not release them.
            memory_guard()
            cur_bs = max(1, cur_bs // 2)
            print(f"[OOM] shrinking batch to {cur_bs}")
        i += cur_bs
        batch = cur_bs
    return {
        "accuracy": accuracy_score(labels, preds),
        "macro_f1": f1_score(labels, preds, average="macro"),
    }
|
| 196 |
-
|
| 197 |
-
# ---------------- Run JSONL evaluations ----------------
|
| 198 |
-
print_mem("before JSONL eval")
# Metrics per evaluation file, keyed by the file path.
results = {}
for jpath in EVAL_FILES:
    texts, labels = load_eval_jsonl(jpath)
    print(f"Loaded {jpath}: {len(texts)} rows")
    metrics = evaluate_set(texts, labels, batch=INIT_BATCH, maxlen=MAXLEN)
    results[jpath] = metrics
    print(f"{jpath} -> Acc: {metrics['accuracy']:.4f} | Macro-F1: {metrics['macro_f1']:.4f}")

print("Summary:", results)
print_mem("done")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|