# SceneGraphNet / src / relationship_infer.py
# Kalp Kanungo
# Fix: corrected label_map inversion bug
# 6ccaf53
import torch
import cv2
import json
import os
from src.model import RelationshipNet
from huggingface_hub import hf_hub_download
# Hugging Face Hub location of the trained relationship classifier weights.
MODEL_REPO = "kalpkanungo/scenegraphnet-relationship-model"
MODEL_FILENAME = "relationship_model.pth"
# Local JSON file describing the class-index <-> relationship-name mapping.
LABEL_MAP_PATH = "data/relationship_dataset/label_map.json"

device = "cuda" if torch.cuda.is_available() else "cpu"


def _build_index_to_label(mapping):
    """Return an ``{int class index: label}`` dict built from *mapping*.

    Both orientations are accepted, entry by entry:

    * name -> index, e.g. ``{"on": 0}``
    * index -> name, e.g. ``{"0": "on"}`` (JSON object keys are always
      strings, so indices stored as keys arrive as ``"0"``, ``"1"``, ...)

    The value is tried as the index first, so a name -> index mapping is
    inverted exactly as a plain ``{v: k}`` inversion would. Entries where
    neither side converts to ``int`` are skipped with a warning.
    """
    index_to_label = {}
    for key, value in mapping.items():
        for index_candidate, label in ((value, key), (key, value)):
            try:
                index_to_label[int(index_candidate)] = label
            except (TypeError, ValueError):
                continue
            break
        else:
            print(
                "⚠️ Skipping label_map entry with no integer index: "
                f"{key!r}: {value!r}"
            )
    return index_to_label


if os.path.exists(LABEL_MAP_PATH):
    with open(LABEL_MAP_PATH, encoding="utf-8") as f:
        label_map = json.load(f)
else:
    print("⚠️ label_map.json not found, using fallback")
    label_map = {
        "0": "on",
        "1": "next to",
        "2": "under"
    }

# Lookup table used by predict(): integer class index -> relationship name.
# A naive {v: k} inversion of the fallback above yields {"on": "0", ...},
# which never matches the int produced by argmax, so every prediction
# would come back as "unknown". Normalising to int keys fixes that for
# either orientation of label_map.
inv_map = _build_index_to_label(label_map)
num_classes = len(label_map)

model = RelationshipNet(num_classes)
try:
    model_path = hf_hub_download(
        repo_id=MODEL_REPO,
        filename=MODEL_FILENAME
    )
    print("✅ Model downloaded from Hugging Face")
    # NOTE(review): torch.load unpickles the downloaded file; only safe as
    # long as MODEL_REPO is trusted. Consider weights_only=True if the
    # installed torch version supports it.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
except Exception as e:
    # Best-effort load: any failure (network, missing file, state-dict
    # mismatch) disables the model; predict() then returns its fixed
    # fallback label instead of raising.
    print(f"⚠️ Failed to load model from HF: {e}")
    model = None
def predict(image):
    """Classify the relationship shown in *image* and return its label.

    When the module-level ``model`` failed to load (``model is None``) the
    fixed label ``"next to"`` is returned without touching the image.

    Args:
        image: Array accepted by ``cv2.resize``. The result must have three
            dimensions with channels last, since it is permuted with
            ``(2, 0, 1)`` — presumably a colour image as loaded by OpenCV;
            verify against callers.

    Returns:
        The label stored in ``inv_map`` for the predicted class index, or
        ``"unknown"`` when that index has no entry.
    """
    if model is None:
        return "next to"

    # Preprocess: fixed 128x128 input, divide by 255, then shift/scale with
    # mean 0.5 and std 0.5.
    resized = cv2.resize(image, (128, 128))
    scaled = resized / 255.0
    normalized = (scaled - 0.5) / 0.5

    # Channels-last -> channels-first, then add a leading batch dimension.
    batch = (
        torch.tensor(normalized, dtype=torch.float32)
        .permute(2, 0, 1)
        .unsqueeze(0)
        .to(device)
    )

    with torch.no_grad():
        logits = model(batch)
    class_index = torch.argmax(logits, dim=1).item()

    return inv_map.get(class_index, "unknown")