File size: 1,296 Bytes
e13de2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""
example_usage.py — Minimal inference with BeMAE-Hα.

This script downloads a checkpoint snapshot from the HuggingFace Hub,
builds a dummy Hα-centred test spectrum, and computes its embedding
z_halpha with the bundled SpectralEncoderHalpha.
"""
import json
import sys

import numpy as np
import torch
from huggingface_hub import snapshot_download

MODEL_ID = "anonym-submit-26/bemae-halpha-v1"

# 1. Download the snapshot and make the code shipped with it importable.
ckpt_dir = snapshot_download(MODEL_ID)
sys.path.insert(0, ckpt_dir)

from model import SpectralEncoderHalpha, ModelConfig  # noqa: E402

# 2. Load the config and the weights.
with open(f"{ckpt_dir}/config.json") as f:
    meta = json.load(f)
cfg = ModelConfig(**meta["model_config"])

encoder = SpectralEncoderHalpha(cfg)
# weights_only=True: the file comes from the network; restrict unpickling to
# plain tensors/containers instead of arbitrary Python objects.
state = torch.load(
    f"{ckpt_dir}/pytorch_model.bin", map_location="cpu", weights_only=True
)
encoder.load_state_dict(state)
encoder.eval()

# 3. Prepare a dummy spectrum (Hα-centred, 128 bins, pseudo-continuum at 1.0)
B = 4  # batch size
flux = torch.ones(B, 128) + 0.3 * torch.randn(B, 128) * 0.01  # ~0.3% Gaussian noise
# Wavelength grid in Å, spanning 100 Å around Hα (6562.8 Å), shared across the batch.
wavelengths = torch.linspace(6512.8, 6612.8, 128).unsqueeze(0).expand(B, -1)
validity = torch.ones(B, 128)  # every bin marked valid

# 4. Inference — the encoder returns z_halpha plus auxiliary outputs we discard.
with torch.no_grad():
    z_halpha, *_ = encoder(flux, wavelengths, validity, mask=None)

print(f"z_halpha shape : {tuple(z_halpha.shape)}")
print(f"z_halpha[0, :8] : {z_halpha[0, :8].numpy()}")