import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from peft import PeftModel  # peft must be installed so the adapter can load

# ---------------------------------------------------------------
# 1. Load the fine-tuned model
#    (transformers attaches a LoRA adapter automatically when the
#    repo ships an adapter_config.json and peft is installed)
# ---------------------------------------------------------------
MODEL_ID = "Siddharth63/LFM2.5-forestWHY"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    device_map="auto",
    dtype=torch.bfloat16,  # older transformers releases use torch_dtype=
)

# Optional: merge the adapter into the base weights for faster inference
# (merging temporarily uses more VRAM). merge_and_unload() is a
# peft.PeftModel method, so load base + adapter explicitly as sketched below.
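# Minimal sketch of the explicit base + adapter path, assuming the repo is a
# LoRA adapter. "BASE_REPO_ID" is a hypothetical placeholder, not the actual
# base checkpoint name -- look it up in the adapter's config:
# base = AutoModelForImageTextToText.from_pretrained(
#     "BASE_REPO_ID", device_map="auto", dtype=torch.bfloat16
# )
# model = PeftModel.from_pretrained(base, MODEL_ID)
# model = model.merge_and_unload()  # bakes the LoRA deltas into the weights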
# ---------------------------------------------------------------
# 2. ForestWHY prompt constants (copied from training notebook)
# ---------------------------------------------------------------
PANEL_KEYS = [
    "img_rgb_before", "img_rgb_after",
    "img_nir_false_color_before", "img_nir_false_color_after",
    "img_swir_composite_before", "img_swir_composite_after",
    "img_delta_ndvi", "img_delta_nbr",
    "img_attention_multi", "img_embedding_change",
    "img_cropa_roads", "img_delta_attn_role",
    "img_head_disagreement", "img_pca_semantic",
]
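# Hypothetical convenience check (not part of the original pipeline): the
# prompt numbers panels 1-14 positionally, so a row missing any panel would
# silently shift the numbering. Verify completeness before inference:
def missing_panels(sample):
    """Return the PANEL_KEYS that are absent or None in a sample."""
    return [k for k in PANEL_KEYS if sample.get(k) is None]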
SYSTEM_PROMPT = """You are an expert remote sensing analyst and tropical ecologist
specializing in forest cover change detection from Sentinel-2 satellite imagery.
You have access to both raw spectral data AND outputs from a trained I-JEPA
Vision Transformer encoder (ViT-L/8, 24 layers, trained on 1.57M Sentinel-2 patches).
For each observation you receive 14 image panels:
SPECTRAL (1-8): RGB before/after, NIR before/after, SWIR before/after, DELTA_NDVI, DELTA_NBR
JEPA ENCODER (9-14): Multi-scale attention, Embedding change, CroPA roads,
Delta attention role, Head disagreement, PCA semantic clusters
Critical rules:
- Provide detailed 10-step reasoning citing specific panel numbers
- JEPA panels supersede spectral panels for ambiguous cases
- Panel 10 embedding change is more reliable than DELTA_NDVI for degradation detection
- Panel 11 CroPA road presence is the strongest predictor of continued deforestation
- Panel 14 PCA cluster split is ground truth for genuine land cover transition
- Write detailed prose for each step - minimum 4 sentences per step"""
def build_user_text(sample):
    return (
        f"Analyze this Sentinel-2 satellite observation showing land cover change.\n\n"
        f"Location: {sample.get('region','unknown')}, "
        f"{sample.get('lat',0):.3f} N {sample.get('lon',0):.3f} E\n"
        f"Period: {sample.get('year_before',2021)} to {sample.get('year_after',2025)} "
        f"({sample.get('year_gap',4)}-year span) | "
        f"Patch: 64x64 px at 19m/px approx 1.2km x 1.2km\n"
        f"Measured: DELTA_NDVI={sample.get('delta_ndvi',0):+.3f} "
        f"DELTA_NBR={sample.get('delta_nbr',0):+.3f}\n\n"
        f"SPECTRAL PANELS (images 1-8):\n"
        f" 1. RGB Before ({sample.get('year_before',2021)})\n"
        f" 2. RGB After ({sample.get('year_after',2025)})\n"
        f" 3. NIR False Color Before\n"
        f" 4. NIR False Color After\n"
        f" 5. SWIR Composite Before\n"
        f" 6. SWIR Composite After\n"
        f" 7. DELTA_NDVI Change Map\n"
        f" 8. DELTA_NBR Change Map\n\n"
        f"JEPA ENCODER PANELS (images 9-14) - I-JEPA ViT-L trained on 1.57M S2 patches:\n"
        f" 9. Multi-Scale Attention (R=fine/roads G=mid/fields B=landscape)\n"
        f" 10. Embedding Change Map (cosine distance before to after per patch)\n"
        f" 11. CroPA Road Map (cross-patch token correlation for linear structures)\n"
        f" 12. Delta Attention Role (red=more salient after, blue=less salient)\n"
        f" 13. Head Disagreement Map (hot=ambiguous, cool=clear semantic signal)\n"
        f" 14. PCA Semantic Clusters (left=before, right=after)\n\n"
        f"Follow the 10-step reasoning protocol. Use JEPA panels to go beyond "
        f"what spectral indices alone can detect."
    )
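# Example usage with hypothetical metadata: build_user_text accepts any dict
# carrying the same keys as a dataset row, so it also works for your own
# imagery (values below are illustrative, not from the dataset):
# custom_meta = {
#     "region": "Southeast Asia", "lat": 12.345, "lon": 104.567,
#     "year_before": 2021, "year_after": 2025, "year_gap": 4,
#     "delta_ndvi": -0.312, "delta_nbr": -0.274,
# }
# print(build_user_text(custom_meta))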
# ---------------------------------------------------------------
# 3. Run inference on a dataset row (or your own 14 PIL images + metadata dict)
# ---------------------------------------------------------------
from datasets import load_dataset
dataset = load_dataset("Siddharth63/forestwhy-training-v1", split="train")
sample = dataset[0]
images = [sample[k] for k in PANEL_KEYS if sample.get(k) is not None]
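# To run on your own imagery instead (hypothetical paths): load 14 PIL
# images in PANEL_KEYS order -- the prompt refers to panels by position.
# from PIL import Image
# images = [Image.open(f"panels/{key}.png") for key in PANEL_KEYS]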
conversation = [
    {"role": "system",
     "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
    {"role": "user",
     "content": [{"type": "image", "image": img} for img in images]
                + [{"type": "text", "text": build_user_text(sample)}]},
]
inputs = processor.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
    tokenize=True,
).to(model.device)
with torch.inference_mode():
    outputs = model.generate(
        **inputs,
        max_new_tokens=4096,
        do_sample=True,
        temperature=1.0,
        min_p=0.1,
        use_cache=True,
    )
# Decode only the newly generated tokens
gen = outputs[:, inputs["input_ids"].shape[1]:]
print(processor.batch_decode(gen, skip_special_tokens=True)[0])
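# Optional: stream the answer to stdout as it is generated instead of
# decoding at the end (TextStreamer ships with transformers):
# from transformers import TextStreamer
# streamer = TextStreamer(processor.tokenizer, skip_prompt=True,
#                         skip_special_tokens=True)
# with torch.inference_mode():
#     model.generate(**inputs, max_new_tokens=4096, do_sample=True,
#                    temperature=1.0, min_p=0.1, streamer=streamer)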