# Author: JunhanCai
# Commit: 6918c6b - "Initial commit with GEMS model and Dockerfile"
# File size at that commit: 4.12 kB
import numpy as np
def ensure_2d_spectra(spectra: np.ndarray) -> np.ndarray:
    """Return *spectra* as a 2D array, flattening 3D input that has a singleton axis.

    A 3D array with at least one axis of length 1 is reshaped to
    ``(shape[0], -1)``; e.g. ``(N, 1, L)`` and ``(N, L, 1)`` both become
    ``(N, L)``. Because the first axis is always kept, ``(1, N, L)`` becomes
    ``(1, N * L)``. A 2D array is returned as-is.

    Raises:
        ValueError: If the input is neither 2D nor a 3D array with a
            singleton axis.
    """
    data = np.asarray(spectra)
    has_singleton_axis = 1 in data.shape

    if data.ndim == 3 and has_singleton_axis:
        n_rows = data.shape[0]
        return data.reshape(n_rows, -1)

    if data.ndim == 2:
        return data

    raise ValueError(f"spectral array must be 2D [N, L], got shape={data.shape}")
def build_target_wavenumbers(
    target_len: int = 3500,
    low_cm: float = 0.0,
    high_cm: float = 3500.0,
) -> np.ndarray:
    """Return an evenly spaced float32 wavenumber grid.

    The defaults give 3500 samples spanning [0, 3500] cm^-1, matching the
    requested paper workflow.

    Args:
        target_len: Number of grid points.
        low_cm: First grid value (inclusive).
        high_cm: Last grid value (inclusive).

    Returns:
        1D float32 array of length ``target_len`` running from ``low_cm`` to
        ``high_cm``.
    """
    return np.linspace(low_cm, high_cm, target_len, dtype=np.float32)
def preprocess_raman_spectra(
    spectra: np.ndarray,
    wavenumbers: np.ndarray,
    target_len: int = 3500,
    low_cm: float = 0.0,
    high_cm: float = 3500.0,
    eps_fill: float = 1e-8,
) -> tuple[np.ndarray, np.ndarray]:
    """Resample spectra onto the target grid and min-max normalize each row.

    Steps: drop non-finite wavenumbers, sort the axis ascending, crop it to
    ``[low_cm, high_cm]``, drop duplicated wavenumbers, linearly interpolate
    every spectrum onto the grid from ``build_target_wavenumbers(target_len)``
    and scale each row to ``[0, 1]``.

    Args:
        spectra: Intensities, one spectrum per row (made 2D through
            ``ensure_2d_spectra``).
        wavenumbers: Wavenumber axis; flattened to 1D, one value per column
            of ``spectra``.
        target_len: Number of points of the output grid.
        low_cm: Lower bound of the crop window (inclusive).
        high_cm: Upper bound of the crop window (inclusive).
        eps_fill: Value assigned to grid points that lie outside the
            retained wavenumber axis.

    Returns:
        ``(normalized, target_w)``: float32 array of shape
        ``(N, target_len)`` and the target grid.

    Raises:
        ValueError: If the number of spectral columns differs from the number
            of wavenumbers, or fewer than two usable wavenumbers remain.
    """
    spectra = ensure_2d_spectra(np.asarray(spectra, dtype=np.float32))
    wavenumbers = np.asarray(wavenumbers, dtype=np.float32).reshape(-1)

    # Fail early with a clear message; otherwise the boolean-mask indexing
    # below raises an opaque IndexError from NumPy.
    if spectra.shape[1] != wavenumbers.size:
        raise ValueError(
            "spectra/wavenumbers length mismatch: "
            f"{spectra.shape[1]} vs {wavenumbers.size}"
        )

    # The grid is requested with target_len only, so it does not follow
    # low_cm/high_cm; those bounds only control the crop below.
    target_w = build_target_wavenumbers(target_len=target_len)

    valid = np.isfinite(wavenumbers)
    w = wavenumbers[valid]
    x = spectra[:, valid]
    if w.size < 2:
        raise ValueError("wavenumbers must contain at least 2 finite values")

    # np.interp requires an increasing x-axis.
    order = np.argsort(w)
    w = w[order]
    x = x[:, order]

    # Crop to the requested window. When no sample falls inside it, the
    # whole axis is kept rather than failing.
    in_range = (w >= low_cm) & (w <= high_cm)
    if np.any(in_range):
        w = w[in_range]
        x = x[:, in_range]
    if w.size < 2:
        raise ValueError(
            f"wavenumbers in [{low_cm:g}, {high_cm:g}] are insufficient for interpolation"
        )

    # Duplicated wavenumbers keep the column of their first occurrence.
    w_unique, unique_idx = np.unique(w, return_index=True)
    x = x[:, unique_idx]

    interpolated = np.empty((x.shape[0], target_len), dtype=np.float32)
    for row_out, row_in in zip(interpolated, x):
        row_out[:] = np.interp(
            target_w,
            w_unique,
            row_in,
            left=eps_fill,
            right=eps_fill,
        )

    # Per-row min-max scaling; (near-)constant rows divide by 1 to avoid 0/0.
    mins = interpolated.min(axis=1, keepdims=True)
    maxs = interpolated.max(axis=1, keepdims=True)
    denom = np.where((maxs - mins) < 1e-12, 1.0, maxs - mins)
    normalized = (interpolated - mins) / denom
    return normalized.astype(np.float32), target_w
def preprocess_raman_dataset(
    spectra: np.ndarray,
    labels: np.ndarray,
    wavenumbers: np.ndarray,
    target_len: int = 3500,
    low_cm: float = 0.0,
    high_cm: float = 3500.0,
    eps_fill: float = 1e-8,
):
    """Preprocess spectra and check that they line up with their labels.

    The spectra are passed to ``preprocess_raman_spectra`` with the given
    options; the labels are only converted to an array and returned unchanged.

    Returns:
        ``(spectra, labels, target_w)``: the preprocessed float32 spectra, the
        label array and the target wavenumber grid.

    Raises:
        ValueError: If the number of preprocessed spectra differs from the
            number of labels.
    """
    label_arr = np.asarray(labels)

    options = {
        "target_len": target_len,
        "low_cm": low_cm,
        "high_cm": high_cm,
        "eps_fill": eps_fill,
    }
    processed, grid = preprocess_raman_spectra(spectra, wavenumbers, **options)

    n_spectra = processed.shape[0]
    n_labels = label_arr.shape[0]
    if n_spectra != n_labels:
        raise ValueError(
            f"spectral/labels length mismatch: {n_spectra} vs {n_labels}"
        )

    return processed.astype(np.float32), label_arr, grid
def augment_small_trainset(
    x_train: np.ndarray,
    y_train: np.ndarray,
    target_per_class: int = 100,
    seed: int = 42,
) -> tuple[np.ndarray, np.ndarray]:
    """Top up under-represented classes with jittered copies of their samples.

    For every class with fewer than ``target_per_class`` samples, synthetic
    samples are generated until the class reaches that count. Each synthetic
    sample is a randomly chosen sample of the class that is scaled by a factor
    in [0.95, 1.05), perturbed with Gaussian noise (sigma 0.01), circularly
    shifted by -3..3 positions and clipped to [0, 1]. Classes that already
    have enough samples are left untouched. The original samples are kept and
    the combined set is shuffled.

    Args:
        x_train: Training samples, one per entry along the first axis.
        y_train: Class label of every training sample.
        target_per_class: Minimum number of samples per class after
            augmentation.
        seed: Seed of the random generator; results are reproducible for a
            fixed seed.

    Returns:
        ``(x_aug, y_aug)``: shuffled float32 samples and their labels.

    Raises:
        ValueError: If ``x_train`` and ``y_train`` differ in length.
    """
    rng = np.random.default_rng(seed)
    x_train = np.asarray(x_train, dtype=np.float32)
    y_train = np.asarray(y_train)

    # Without this check a mismatch only surfaces after all synthesis work,
    # as an unrelated IndexError.
    if x_train.ndim < 1 or y_train.ndim < 1 or x_train.shape[0] != y_train.shape[0]:
        raise ValueError(
            "x_train/y_train length mismatch: "
            f"shapes {x_train.shape} vs {y_train.shape}"
        )

    out_x = [x_train]
    out_y = [y_train]

    for cls in np.unique(y_train):
        cls_samples = x_train[np.where(y_train == cls)[0]]
        n_have = cls_samples.shape[0]
        if n_have >= target_per_class:
            continue

        need = target_per_class - n_have
        synth = np.empty((need,) + cls_samples.shape[1:], dtype=np.float32)
        for k in range(need):
            # The order of the random draws is kept fixed so that a given
            # seed always produces the same augmented set.
            src = cls_samples[rng.integers(0, n_have)]
            noise = rng.normal(0.0, 0.01, size=src.shape).astype(np.float32)
            scale = rng.uniform(0.95, 1.05)
            shift = int(rng.integers(-3, 4))
            # np.roll is circular: values shifted past one end re-enter at
            # the other.
            synth[k] = np.clip(np.roll(src * scale + noise, shift), 0.0, 1.0)

        out_x.append(synth)
        out_y.append(np.full((need,), cls, dtype=y_train.dtype))

    x_aug = np.concatenate(out_x, axis=0)
    y_aug = np.concatenate(out_y, axis=0)
    perm = rng.permutation(len(x_aug))
    return x_aug[perm], y_aug[perm]