Spaces:
Running
Running
File size: 1,525 Bytes
c858478 9e9f576 c858478 9e9f576 c858478 6ccaf53 9e9f576 c858478 9e9f576 c858478 9e9f576 c858478 6ccaf53 9e9f576 c858478 6ccaf53 c858478 6ccaf53 c858478 9e9f576 c858478 9e9f576 6ccaf53 c858478 6ccaf53 9e9f576 c858478 9e9f576 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 | import torch
import cv2
import json
import os
from src.model import RelationshipNet
from huggingface_hub import hf_hub_download
# Hugging Face Hub location of the trained relationship classifier weights.
MODEL_REPO = "kalpkanungo/scenegraphnet-relationship-model"
MODEL_FILENAME = "relationship_model.pth"
# Optional local label vocabulary; a built-in map is used when it is absent.
LABEL_MAP_PATH = "data/relationship_dataset/label_map.json"

# Run on the GPU when torch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the label vocabulary; fall back to a built-in three-class map so the
# module still imports when the dataset directory is not shipped.
if os.path.exists(LABEL_MAP_PATH):
    with open(LABEL_MAP_PATH, encoding="utf-8") as f:
        label_map = json.load(f)
else:
    print("⚠️ label_map.json not found, using fallback")
    label_map = {
        "0": "on",
        "1": "next to",
        "2": "under",
    }


def _as_class_index(token):
    """Return *token* as an ``int`` class index, or ``None`` if it is not one.

    Accepts ``int`` values and decimal-digit strings such as ``"2"``. ``bool``
    is rejected even though it subclasses ``int``.
    """
    if isinstance(token, bool):
        return None
    if isinstance(token, int):
        return token
    # isdecimal (not isdigit) so that int() is guaranteed to accept the string.
    if isinstance(token, str) and token.isdecimal():
        return int(token)
    return None


def _index_to_label(mapping):
    """Return ``{class_index: label}`` built from *mapping* in either orientation.

    ``{"0": "on"}`` (index -> label, the layout of the built-in fallback) and
    ``{"on": 0}`` (label -> index) both become ``{0: "on"}``. Entries where
    neither side is a class index are skipped.
    """
    result = {}
    for key, value in mapping.items():
        index = _as_class_index(key)
        if index is not None:
            result[index] = value
            continue
        index = _as_class_index(value)
        if index is not None:
            result[index] = key
    return result


# Keyed by int because predict() looks labels up with argmax(...).item().
# Simply inverting label_map would turn the fallback into {"on": "0"}, so
# every int lookup would miss and predict() would always return "unknown".
inv_map = _index_to_label(label_map)
num_classes = len(label_map)
# Build the network, then try to fetch its trained weights from the Hugging
# Face Hub. Any exception while downloading or loading is reported and leaves
# ``model`` as None, so predict() returns its default answer instead of the
# app failing at import time.
model = RelationshipNet(num_classes)
try:
    model_path = hf_hub_download(
        repo_id=MODEL_REPO,
        filename=MODEL_FILENAME
    )
    print("✅ Model downloaded from Hugging Face")
    # The checkpoint comes from the network. weights_only=True limits
    # torch.load's unpickler to tensors and plain containers, so a tampered
    # file cannot run arbitrary code. (Available since torch 1.13; it is the
    # default from torch 2.6 on.)
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
except Exception as e:
    # Deliberate best-effort boundary: keep the app usable without weights.
    print(f"⚠️ Failed to load model from HF: {e}")
    model = None
def predict(image):
    """Classify the relationship depicted in *image* and return its label.

    The picture is resized to 128x128 and rescaled with x/255 followed by
    (x - 0.5)/0.5. It is reordered from HWC to CHW and sent to the model as a
    batch of one. The arg-max class index is then mapped to a name via the
    module-level ``inv_map``.

    NOTE(review): assumes *image* is an H x W x C array with 0-255 values; a
    2-D (grayscale) array would fail at ``permute``. No BGR/RGB conversion is
    done here, so confirm the channel order matches training.

    Args:
        image: Picture in a form accepted by ``cv2.resize``.

    Returns:
        The predicted relationship label. Returns ``"unknown"`` when the
        ``int`` class index is not a key of ``inv_map``, and ``"next to"``
        when no model was loaded.
    """
    # No weights available (loading failed at import): fixed default answer.
    if model is None:
        return "next to"

    resized = cv2.resize(image, (128, 128))
    # Identical arithmetic to the two-step x/255 then (x - 0.5)/0.5.
    scaled = (resized / 255.0 - 0.5) / 0.5
    # HWC -> CHW, add a leading batch axis, move to the inference device.
    batch = (
        torch.tensor(scaled, dtype=torch.float32)
        .permute(2, 0, 1)
        .unsqueeze(0)
        .to(device)
    )

    with torch.no_grad():
        scores = model(batch)
    class_index = scores.argmax(dim=1).item()
    return inv_map.get(class_index, "unknown")