from transformers import AutoProcessor, AutoModelForImageTextToText
from peft import PeftModel
import torch

# ---------------------------------------------------------------
# 1. Load processor + fine-tuned model
# ---------------------------------------------------------------
BASE_MODEL = "Siddharth63/LFM2.5-forestWHY"

# The processor bundles the image preprocessor and the chat-template tokenizer.
processor = AutoProcessor.from_pretrained(BASE_MODEL)

# bfloat16 weights, sharded across available devices by accelerate.
# NOTE(review): `PeftModel` is imported at the top of the file but never used —
# presumably the repo already ships merged weights. Confirm before enabling the
# commented merge below; `merge_and_unload()` only exists on a PeftModel.
model = AutoModelForImageTextToText.from_pretrained(
    BASE_MODEL, device_map="auto", dtype=torch.bfloat16
)

# Optional: merge adapter into base for faster inference (uses more VRAM during merge)
# model = model.merge_and_unload()

# ---------------------------------------------------------------
# 2. ForestWHY prompt constants (copied from training notebook)
# ---------------------------------------------------------------
# Canonical ordering of the 14 image panels fed to the model.
PANEL_KEYS = [
    # spectral panels (1-8)
    "img_rgb_before",
    "img_rgb_after",
    "img_nir_false_color_before",
    "img_nir_false_color_after",
    "img_swir_composite_before",
    "img_swir_composite_after",
    "img_delta_ndvi",
    "img_delta_nbr",
    # JEPA encoder panels (9-14)
    "img_attention_multi",
    "img_embedding_change",
    "img_cropa_roads",
    "img_delta_attn_role",
    "img_head_disagreement",
    "img_pca_semantic",
]

# System prompt is a runtime string the model was trained against — do not edit.
SYSTEM_PROMPT = """You are an expert remote sensing analyst and tropical ecologist
specializing in forest cover change detection from Sentinel-2 satellite imagery.
You have access to both raw spectral data AND outputs from a trained I-JEPA
Vision Transformer encoder (ViT-L/8, 24 layers, trained on 1.57M Sentinel-2 patches).

For each observation you receive 14 image panels:
SPECTRAL (1-8): RGB before/after, NIR before/after, SWIR before/after, DELTA_NDVI, DELTA_NBR
JEPA ENCODER (9-14): Multi-scale attention, Embedding change, CroPA roads,
                     Delta attention role, Head disagreement, PCA semantic clusters

Critical rules:
- Provide detailed 10-step reasoning citing specific panel numbers
- JEPA panels supersede spectral panels for ambiguous cases
- Panel 10 embedding change is more reliable than DELTA_NDVI for degradation detection
- Panel 11 CroPA road presence is the strongest predictor of continued deforestation
- Panel 14 PCA cluster split is ground truth for genuine land cover transition
- Write detailed prose for each step - minimum 4 sentences per step"""

def build_user_text(sample):
    """Render the user-turn prompt text for one observation.

    ``sample`` is a mapping with optional metadata keys (``region``, ``lat``,
    ``lon``, ``year_before``, ``year_after``, ``year_gap``, ``delta_ndvi``,
    ``delta_nbr``); every lookup falls back to the same defaults the training
    notebook used, so a bare ``{}`` still yields a well-formed prompt.
    """
    region = sample.get("region", "unknown")
    lat = sample.get("lat", 0)
    lon = sample.get("lon", 0)
    year_before = sample.get("year_before", 2021)
    year_after = sample.get("year_after", 2025)
    year_gap = sample.get("year_gap", 4)
    delta_ndvi = sample.get("delta_ndvi", 0)
    delta_nbr = sample.get("delta_nbr", 0)

    # NOTE(review): coordinates are always labelled "N"/"E", even for negative
    # (southern/western hemisphere) values — presumably this mirrors the
    # training-time prompt format, so the behavior is preserved as-is.
    lines = [
        "Analyze this Sentinel-2 satellite observation showing land cover change.",
        "",
        f"Location: {region}, {lat:.3f} N {lon:.3f} E",
        f"Period: {year_before} to {year_after} ({year_gap}-year span) | "
        f"Patch: 64x64 px at 19m/px approx 1.2km x 1.2km",
        f"Measured: DELTA_NDVI={delta_ndvi:+.3f}  DELTA_NBR={delta_nbr:+.3f}",
        "",
        "SPECTRAL PANELS (images 1-8):",
        f"  1. RGB Before ({year_before})",
        f"  2. RGB After  ({year_after})",
        "  3. NIR False Color Before",
        "  4. NIR False Color After",
        "  5. SWIR Composite Before",
        "  6. SWIR Composite After",
        "  7. DELTA_NDVI Change Map",
        "  8. DELTA_NBR Change Map",
        "",
        "JEPA ENCODER PANELS (images 9-14) - I-JEPA ViT-L trained on 1.57M S2 patches:",
        "  9.  Multi-Scale Attention (R=fine/roads G=mid/fields B=landscape)",
        "  10. Embedding Change Map (cosine distance before to after per patch)",
        "  11. CroPA Road Map (cross-patch token correlation for linear structures)",
        "  12. Delta Attention Role (red=more salient after, blue=less salient)",
        "  13. Head Disagreement Map (hot=ambiguous, cool=clear semantic signal)",
        "  14. PCA Semantic Clusters (left=before, right=after)",
        "",
        "Follow the 10-step reasoning protocol. Use JEPA panels to go beyond "
        "what spectral indices alone can detect.",
    ]
    return "\n".join(lines)

# ---------------------------------------------------------------
# 3. Run inference on a dataset row (or your own 14 PIL images + metadata dict)
# ---------------------------------------------------------------
from datasets import load_dataset

dataset = load_dataset("Siddharth63/forestwhy-training-v1", split="train")
sample = dataset[0]

# Collect the panels present on this row, preserving the canonical panel order.
images = []
for key in PANEL_KEYS:
    panel = sample.get(key)
    if panel is not None:
        images.append(panel)

# User turn: all image panels first, then the textual briefing.
user_content = [{"type": "image", "image": img} for img in images]
user_content.append({"type": "text", "text": build_user_text(sample)})

conversation = [
    {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
    {"role": "user", "content": user_content},
]

# Tokenize the chat and move the resulting tensors to the model's device.
inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

# Length of the prompt, used below to strip it from the generated sequence.
prompt_len = inputs["input_ids"].shape[1]

with torch.inference_mode():
    generated = model.generate(
        **inputs,
        max_new_tokens=4096,
        do_sample=True,
        temperature=1.0,
        min_p=0.1,
        use_cache=True,
    )

# Keep only the newly generated tokens (everything after the prompt) and decode.
new_tokens = generated[:, prompt_len:]
print(processor.batch_decode(new_tokens, skip_special_tokens=True)[0])
# --- Model-card page residue (kept as a comment so the file stays valid Python) ---
# Downloads last month: 44
# Safetensors — Model size: 2B params — Tensor type: F16
# Inference Providers (NEW): This model isn't deployed by any Inference Provider.
# 🙋 Ask for provider support.