import numpy as np


def ensure_2d_spectra(spectra: np.ndarray) -> np.ndarray:
    """Return ``spectra`` as a 2-D array.

    A 3-D input that has at least one axis of length 1 is reshaped to
    ``(shape[0], -1)``; a 2-D input is returned as-is.

    NOTE: when the singleton axis is the *first* one, the reshape flattens
    the remaining two axes into a single row of length ``shape[1] * shape[2]``.

    Raises:
        ValueError: if the input is neither 2-D nor a 3-D array with a
            singleton axis.
    """
    arr = np.asarray(spectra)
    if arr.ndim == 3 and 1 in arr.shape:
        return arr.reshape(arr.shape[0], -1)
    if arr.ndim != 2:
        raise ValueError(f"spectral array must be 2D [N, L], got shape={arr.shape}")
    return arr


def build_target_wavenumbers(target_len: int = 3500) -> np.ndarray:
    """Return ``target_len`` evenly spaced float32 values from 0 to 3500 inclusive."""
    # 3500 samples spanning [0, 3500] cm^-1, matching the requested paper workflow.
    return np.linspace(0.0, 3500.0, target_len, dtype=np.float32)


def preprocess_raman_spectra(
    spectra: np.ndarray,
    wavenumbers: np.ndarray,
    target_len: int = 3500,
    low_cm: float = 0.0,
    high_cm: float = 3500.0,
    eps_fill: float = 1e-8,
) -> tuple[np.ndarray, np.ndarray]:
    """Resample spectra onto the fixed target axis and min-max scale each row.

    Steps:
      1. Drop columns whose wavenumber is not finite.
      2. Sort columns by wavenumber.
      3. Keep only columns with ``low_cm <= wavenumber <= high_cm``. If *no*
         column falls inside that range the uncropped axis is used instead.
      4. Collapse duplicate wavenumbers, keeping the first occurrence.
      5. Linearly interpolate every row onto ``build_target_wavenumbers``;
         target points outside the measured axis are set to ``eps_fill``.
      6. Min-max normalise every row; rows whose range is below ``1e-12``
         are divided by 1 instead, so they come out as all zeros.

    Args:
        spectra: Array accepted by ``ensure_2d_spectra``; one spectrum per row.
        wavenumbers: Axis values, flattened to 1-D; one per spectrum column.
        target_len: Number of samples on the output axis.
        low_cm: Lower bound (inclusive) of the cropping range.
        high_cm: Upper bound (inclusive) of the cropping range.
        eps_fill: Value used for target points outside the measured axis.

    Returns:
        ``(normalized, target_w)``: a float32 array of shape
        ``(n_spectra, target_len)`` and the float32 target axis.

    Raises:
        ValueError: if fewer than two finite wavenumbers exist, or if cropping
            leaves fewer than two points to interpolate between.
    """
    spectra = ensure_2d_spectra(np.asarray(spectra, dtype=np.float32))
    wavenumbers = np.asarray(wavenumbers, dtype=np.float32).reshape(-1)
    target_w = build_target_wavenumbers(target_len=target_len)

    valid = np.isfinite(wavenumbers)
    w = wavenumbers[valid]
    x = spectra[:, valid]
    if w.size < 2:
        raise ValueError("wavenumbers must contain at least 2 finite values")

    # Stable sort so that, among duplicate wavenumbers, input order is kept and
    # the de-duplication below deterministically keeps the first one.
    order = np.argsort(w, kind="stable")
    w = w[order]
    x = x[:, order]

    in_range = (w >= low_cm) & (w <= high_cm)
    # Crop only when something is in range; otherwise fall back to the full axis.
    if np.any(in_range):
        w = w[in_range]
        x = x[:, in_range]
    if w.size < 2:
        # Report the range actually requested; with the default bounds this
        # renders as "[0, 3500]".
        raise ValueError(
            f"wavenumbers in [{low_cm:g}, {high_cm:g}] "
            "are insufficient for interpolation"
        )

    # np.interp requires a strictly increasing x-axis.
    w_unique, unique_idx = np.unique(w, return_index=True)
    x = x[:, unique_idx]

    interpolated = np.empty((x.shape[0], target_len), dtype=np.float32)
    for i, row in enumerate(x):
        interpolated[i] = np.interp(
            target_w,
            w_unique,
            row,
            left=eps_fill,
            right=eps_fill,
        )

    mins = interpolated.min(axis=1, keepdims=True)
    maxs = interpolated.max(axis=1, keepdims=True)
    span = maxs - mins
    # Guard against division by (near) zero for flat rows.
    denom = np.where(span < 1e-12, 1.0, span)
    normalized = (interpolated - mins) / denom
    return normalized.astype(np.float32, copy=False), target_w


def preprocess_raman_dataset(
    spectra: np.ndarray,
    labels: np.ndarray,
    wavenumbers: np.ndarray,
    target_len: int = 3500,
    low_cm: float = 0.0,
    high_cm: float = 3500.0,
    eps_fill: float = 1e-8,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Preprocess spectra and check that they line up with ``labels``.

    The spectra are passed through ``preprocess_raman_spectra`` with the given
    options; ``labels`` is converted with ``np.asarray`` and otherwise returned
    untouched.

    Returns:
        ``(spectra, labels, target_w)`` where ``spectra`` is float32.

    Raises:
        ValueError: if the number of preprocessed spectra differs from the
            number of labels (plus anything ``preprocess_raman_spectra`` raises).
    """
    labels = np.asarray(labels)
    spectra, target_w = preprocess_raman_spectra(
        spectra,
        wavenumbers,
        target_len=target_len,
        low_cm=low_cm,
        high_cm=high_cm,
        eps_fill=eps_fill,
    )
    if spectra.shape[0] != labels.shape[0]:
        raise ValueError(
            f"spectral/labels length mismatch: {spectra.shape[0]} vs {labels.shape[0]}"
        )
    return spectra.astype(np.float32, copy=False), labels, target_w


def augment_small_trainset(
    x_train: np.ndarray,
    y_train: np.ndarray,
    target_per_class: int = 100,
    seed: int = 42,
) -> tuple[np.ndarray, np.ndarray]:
    """Top up under-represented classes with synthetic samples, then shuffle.

    For every class with fewer than ``target_per_class`` samples, synthetic
    samples are generated until the class reaches that count. Each synthetic
    sample is a randomly chosen real sample of the class that is

      * multiplied by a factor drawn uniformly from [0.95, 1.05),
      * perturbed with Gaussian noise (mean 0, std 0.01),
      * shifted by an integer in [-3, 3] using ``np.roll`` (the shift is
        circular: values pushed past one end re-enter at the other),
      * clipped to [0, 1].

    The original samples are kept unchanged (they are not clipped).

    Args:
        x_train: Training samples, one per entry along axis 0.
        y_train: Class label for each sample.
        target_per_class: Minimum sample count per class after augmentation.
        seed: Seed for ``np.random.default_rng``; fixes the generated samples
            and the final shuffle.

    Returns:
        ``(x_aug, y_aug)``: originals plus synthetic samples, jointly permuted.
    """
    rng = np.random.default_rng(seed)
    x_train = np.asarray(x_train, dtype=np.float32)
    y_train = np.asarray(y_train)

    out_x = [x_train]
    out_y = [y_train]

    for cls in np.unique(y_train):
        cls_samples = x_train[np.where(y_train == cls)[0]]
        n_have = cls_samples.shape[0]
        if n_have >= target_per_class:
            continue

        synth = []
        # The order of the RNG draws below determines the seeded output.
        for _ in range(target_per_class - n_have):
            src = cls_samples[rng.integers(0, n_have)]
            noise = rng.normal(0.0, 0.01, size=src.shape).astype(np.float32)
            scale = rng.uniform(0.95, 1.05)
            shift = int(rng.integers(-3, 4))
            # `src * scale + noise` allocates a new array, so `src` (a view
            # into the training data) is never modified.
            aug = np.roll(src * scale + noise, shift)
            synth.append(np.clip(aug, 0.0, 1.0))

        # At least one sample was generated, since n_have < target_per_class.
        synth_arr = np.asarray(synth, dtype=np.float32)
        out_x.append(synth_arr)
        out_y.append(np.full((synth_arr.shape[0],), cls, dtype=y_train.dtype))

    x_aug = np.concatenate(out_x, axis=0)
    y_aug = np.concatenate(out_y, axis=0)
    perm = rng.permutation(len(x_aug))
    return x_aug[perm], y_aug[perm]