Spaces:
Running
Running
| import torch | |
| import cv2 | |
| import json | |
| import os | |
| from src.model import RelationshipNet | |
| from huggingface_hub import hf_hub_download | |
# Hugging Face Hub repository and file name of the relationship-model weights.
MODEL_REPO = "kalpkanungo/scenegraphnet-relationship-model"
MODEL_FILENAME = "relationship_model.pth"
# Optional local JSON file describing the relationship classes.
LABEL_MAP_PATH = "data/relationship_dataset/label_map.json"

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


def _index_to_name(mapping):
    """Return a ``{class_index: class_name}`` dict keyed by ``int``.

    The label map may be stored in either orientation:

    * name -> index, e.g. ``{"on": 0}`` (integer values), or
    * index -> name, e.g. ``{"0": "on"}`` (keys are digit strings because
      JSON object keys are always strings).

    A plain ``{v: k}`` inversion only works for the first form; for the
    second it yields ``{"on": "0"}``, which can never be matched by the
    integer class index produced by ``argmax``. Entries that fit neither
    form are inverted as-is, matching the previous behaviour.
    """
    lookup = {}
    for key, value in mapping.items():
        if isinstance(value, int):
            # name -> index entry
            lookup[value] = key
        elif isinstance(key, str) and key.isdecimal():
            # index -> name entry with a stringified index
            lookup[int(key)] = value
        else:
            lookup[value] = key
    return lookup


if os.path.exists(LABEL_MAP_PATH):
    with open(LABEL_MAP_PATH, encoding="utf-8") as f:
        label_map = json.load(f)
else:
    print("⚠️ label_map.json not found, using fallback")
    label_map = {
        "0": "on",
        "1": "next to",
        "2": "under",
    }

# Integer class index -> human-readable relationship name.
inv_map = _index_to_name(label_map)
num_classes = len(label_map)
# Instantiate the network sized by the number of entries in the label map;
# its weights are filled in from the downloaded checkpoint below.
model = RelationshipNet(num_classes)

try:
    # Fetch the checkpoint from the Hub repo; the returned path is handed
    # straight to torch.load.
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
    print("✅ Model downloaded from Hugging Face")

    # NOTE(review): torch.load unpickles the file, and the file comes from a
    # remote repository. Consider `weights_only=True` once the minimum
    # supported torch version is confirmed to accept it.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    # Inference only: switch the module to evaluation mode.
    model.eval()
except Exception as exc:
    # Best-effort startup: any download or load failure leaves `model` as
    # None, which predict() checks before running inference.
    print(f"⚠️ Failed to load model from HF: {exc}")
    model = None
def predict(image):
    """Classify the relationship shown in ``image`` and return its label.

    Args:
        image: Array accepted by ``cv2.resize`` with the channel axis last
            (the ``permute(2, 0, 1)`` below requires exactly three axes).
            Presumably 8-bit pixel values, given the division by 255 —
            TODO confirm against callers.

    Returns:
        The relationship name for the highest-scoring model output,
        ``"unknown"`` when the predicted index is not in the label map, or
        the fixed string ``"next to"`` when the model failed to load.
    """
    if model is None:
        # Weights could not be loaded at import time; return a fixed answer
        # instead of raising.
        return "next to"

    resized = cv2.resize(image, (128, 128))
    # Scale by 1/255, then shift/scale by 0.5 (maps [0, 255] to [-1, 1]).
    scaled = resized / 255.0
    normalized = (scaled - 0.5) / 0.5

    batch = (
        torch.tensor(normalized, dtype=torch.float32)
        .permute(2, 0, 1)  # channels-last -> channels-first
        .unsqueeze(0)      # add a leading batch axis of size 1
        .to(device)
    )

    with torch.no_grad():
        output = model(batch)
    pred = torch.argmax(output, dim=1).item()

    label = inv_map.get(pred)
    if label is not None:
        return label

    # `pred` is an int, but when the label map is stored index -> name its
    # keys are digit strings (JSON object keys are always strings; the
    # built-in fallback map has the same shape). In that case the inverted
    # map cannot contain `pred`, so consult the forward map directly.
    candidate = label_map.get(str(pred))
    return candidate if isinstance(candidate, str) else "unknown"