File size: 1,525 Bytes
c858478
 
 
9e9f576
c858478
9e9f576
c858478
6ccaf53
9e9f576
 
c858478
9e9f576
c858478
9e9f576
c858478
6ccaf53
9e9f576
 
 
 
 
 
 
 
 
 
c858478
6ccaf53
c858478
 
6ccaf53
c858478
 
9e9f576
 
 
 
 
 
 
 
 
 
c858478
9e9f576
 
 
 
6ccaf53
c858478
6ccaf53
9e9f576
 
 
c858478
 
 
 
 
 
 
 
 
 
 
9e9f576
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import cv2
import json
import os
from src.model import RelationshipNet
from huggingface_hub import hf_hub_download


MODEL_REPO = "kalpkanungo/scenegraphnet-relationship-model"
MODEL_FILENAME = "relationship_model.pth"

LABEL_MAP_PATH = "data/relationship_dataset/label_map.json"

device = "cuda" if torch.cuda.is_available() else "cpu"


if os.path.exists(LABEL_MAP_PATH):
    with open(LABEL_MAP_PATH) as f:
        label_map = json.load(f)
else:
    print("⚠️ label_map.json not found, using fallback")
    label_map = {
        "0": "on",
        "1": "next to",
        "2": "under"
    }

inv_map = {v: k for k, v in label_map.items()}
num_classes = len(label_map)


# Build the network with one output per entry in the label map, then try to
# populate it with weights fetched from the Hugging Face Hub. Runs once at
# import time; on any failure `model` is set to None.
model = RelationshipNet(num_classes)

try:
    # The returned value is passed to torch.load below as the checkpoint path.
    model_path = hf_hub_download(
        repo_id=MODEL_REPO,
        filename=MODEL_FILENAME
    )
    print("✅ Model downloaded from Hugging Face")

    # NOTE(review): torch.load unpickles the downloaded file. On PyTorch
    # versions where `weights_only` defaults to False this can execute
    # arbitrary code from the checkpoint; consider passing weights_only=True
    # if the installed version supports it.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    # Presumably an nn.Module: eval() switches it to inference behaviour.
    model.eval()

except Exception as e:
    # Deliberately broad: a download error, an unreadable checkpoint, or a
    # state-dict mismatch all degrade to "no model" rather than failing the
    # import.
    print(f"⚠️ Failed to load model from HF: {e}")
    model = None


def predict(image):
    """Return the relationship label predicted for ``image``.

    Args:
        image: Image accepted by ``cv2.resize`` with a trailing channel axis
            (presumably an H x W x C NumPy array; channel count and order are
            whatever the model was trained on -- TODO confirm with caller).

    Returns:
        The ``inv_map`` entry for the arg-max class index, ``"unknown"`` when
        that index has no entry, or the fixed string ``"next to"`` when no
        model is loaded.
    """
    # No weights available: answer with a fixed label instead of failing.
    if model is None:
        return "next to"

    # Resize to the 128x128 input size, scale by 1/255, then centre with
    # mean 0.5 / std 0.5 (maps 8-bit pixel values into [-1, 1]).
    resized = cv2.resize(image, (128, 128))
    normalized = (resized / 255.0 - 0.5) / 0.5

    # Channels-last -> channels-first, plus a leading batch axis of size 1.
    batch = torch.tensor(normalized, dtype=torch.float32)
    batch = batch.permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        scores = model(batch)
        class_index = scores.argmax(dim=1).item()

    return inv_map.get(class_index, "unknown")