# xrag-compression-probe / probe_clf.py
# wexumin's picture
# Update probe_clf.py
# b8fc425 verified
import json
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from sklearn.preprocessing import StandardScaler
from safetensors.torch import save_file, load_file
class LinearProbeTorch(PyTorchModelHubMixin):
    """Binary linear probe: a single ``nn.Linear(input_dim, 1)`` + sigmoid.

    The probe optionally applies a fitted scaler (``self.scaler``) to the
    features before the linear layer.  Weights are persisted as
    ``model.safetensors``, hyper-parameters as ``config.json`` and the scaler
    (when present) as ``scaler.pkl``.

    Attributes:
        normalize: Whether ``predict_proba`` applies ``self.scaler``.
        device: Device string; falls back to ``"cpu"`` when CUDA is unavailable.
        random_state: Stored seed (not used by any method in this class body).
        hidden_dim: Stored and written to ``config.json`` (not used to size
            the layer; the layer is sized from ``input_dim``).
        model: The ``nn.Linear`` layer, or ``None`` until created/loaded.
        scaler: Object exposing ``transform(X)``, or ``None``.
    """

    def __init__(
        self,
        normalize=True,
        device="cpu",
        random_state=42,
        hidden_dim=1024,
    ):
        self.normalize = normalize
        # Any requested device is honoured only when CUDA is available.
        self.device = device if torch.cuda.is_available() else "cpu"
        self.random_state = random_state
        self.hidden_dim = hidden_dim
        self.model = None
        self.scaler = None

    def _create_model(self, input_dim, positive_class_proportion=0.5):
        """Build the linear layer with zero weights and a prior-matching bias.

        Args:
            input_dim: Number of input features.
            positive_class_proportion: Prior probability of the positive
                class; clipped to ``[0.01, 0.99]`` so the logit stays finite.

        Returns:
            An ``nn.Linear(input_dim, 1)`` whose initial output is the logit
            of the (clipped) positive-class proportion.
        """
        model = nn.Linear(input_dim, 1)
        nn.init.zeros_(model.weight)
        p = float(np.clip(positive_class_proportion, 0.01, 0.99))
        # bias = logit(p), so sigmoid(bias) == p before any training.
        nn.init.constant_(model.bias, float(np.log(p / (1 - p))))
        return model

    def predict_proba(self, X):
        """Return class probabilities for ``X``.

        Args:
            X: Array-like or ``torch.Tensor`` of features, one row per sample.

        Returns:
            ``np.ndarray`` with two columns: ``[P(class 0), P(class 1)]``.

        Raises:
            ValueError: If no model has been created or loaded yet.
        """
        if self.model is None:
            raise ValueError("Model not fitted yet. Call fit() first.")
        if isinstance(X, torch.Tensor):
            # detach() so tensors that track gradients can be converted too.
            X = X.detach().cpu().numpy()
        X = np.asarray(X)
        if self.normalize and self.scaler is not None:
            X = self.scaler.transform(X)

        # Run on the device the weights actually live on; self.device is only
        # a fallback for a parameter-less model.
        first_param = next(self.model.parameters(), None)
        run_device = self.device if first_param is None else first_param.device

        self.model.eval()
        with torch.no_grad():
            inputs = torch.as_tensor(X, dtype=torch.float32).to(run_device)
            logits = self.model(inputs).squeeze(-1)
            proba_pos = torch.sigmoid(logits).cpu().numpy()
        return np.column_stack([1 - proba_pos, proba_pos])

    def predict(self, X):
        """Return hard 0/1 labels by thresholding P(class 1) at 0.5."""
        return (self.predict_proba(X)[:, 1] >= 0.5).astype(int)

    def _save_pretrained(self, save_directory):
        """Write weights, config and (if present) the scaler to a directory.

        Args:
            save_directory: Target directory; created if missing.

        Raises:
            ValueError: If there is no model to save.
        """
        if self.model is None:
            raise ValueError("Model not fitted yet. Nothing to save.")
        os.makedirs(save_directory, exist_ok=True)
        save_file(
            self.model.state_dict(),
            os.path.join(save_directory, "model.safetensors"),
        )
        config = {
            "hidden_dim": self.hidden_dim,
            "normalize": self.normalize,
        }
        with open(os.path.join(save_directory, "config.json"), "w") as f:
            json.dump(config, f)
        if self.scaler is not None:
            with open(os.path.join(save_directory, "scaler.pkl"), "wb") as f:
                pickle.dump(self.scaler, f)

    @classmethod
    def _from_pretrained(
        cls,
        model_id,
        *args,
        config=None,
        cache_dir=None,
        force_download=False,
        revision=None,
        **kwargs,
    ):
        """Load a probe from a Hub repo id or a local directory.

        Args:
            model_id: Hub repository id, or path to a local directory holding
                ``model.safetensors`` and ``config.json``.
            config: Accepted for interface compatibility; the on-disk
                ``config.json`` is what is actually used.
            cache_dir: Forwarded to ``hf_hub_download``.
            force_download: Forwarded to ``hf_hub_download``.
            revision: Forwarded to ``hf_hub_download``.
            **kwargs: ``token`` and ``local_files_only`` are forwarded to
                ``hf_hub_download`` when given; everything else is ignored.

        Returns:
            A ``LinearProbeTorch`` in eval mode, placed on CUDA when available.

        Raises:
            FileNotFoundError: If a required file is missing from a local
                directory.
        """
        import warnings

        from huggingface_hub import hf_hub_download
        from huggingface_hub.utils import EntryNotFoundError

        model_id = str(model_id)
        is_local_dir = os.path.isdir(model_id)

        download_kwargs = {
            "revision": revision,
            "cache_dir": cache_dir,
            "force_download": force_download,
        }
        # Only forward optional Hub arguments the caller actually supplied.
        for key in ("token", "local_files_only"):
            if kwargs.get(key) is not None:
                download_kwargs[key] = kwargs[key]

        def resolve(filename, required=True):
            """Return a local path for *filename*, or None if optional and absent."""
            if is_local_dir:
                candidate = os.path.join(model_id, filename)
                if os.path.isfile(candidate):
                    return candidate
                if required:
                    raise FileNotFoundError(
                        f"{filename} not found in {model_id}"
                    )
                return None
            try:
                return hf_hub_download(model_id, filename, **download_kwargs)
            except EntryNotFoundError:
                if required:
                    raise
                return None

        weights_path = resolve("model.safetensors")
        config_path = resolve("config.json")
        with open(config_path) as f:
            cfg = json.load(f)

        model = cls(**cfg)
        state_dict = load_file(weights_path)
        # The layer is sized from the stored weight matrix, not from config.
        input_dim = state_dict["weight"].shape[1]
        model.model = model._create_model(input_dim)
        model.model.load_state_dict(state_dict)

        # The scaler is optional: a missing file silently leaves scaler=None.
        # Other failures stay best-effort but are surfaced as a warning
        # instead of being swallowed.
        try:
            scaler_path = resolve("scaler.pkl", required=False)
            if scaler_path is not None:
                # SECURITY: pickle.load can execute arbitrary code. Only load
                # probes from repositories you trust.
                with open(scaler_path, "rb") as f:
                    model.scaler = pickle.load(f)
        except Exception as err:
            warnings.warn(
                f"Could not load scaler.pkl ({err!r}); "
                "continuing without a scaler.",
                RuntimeWarning,
            )

        model.model.eval()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.device = device
        model.model = model.model.to(device)
        return model