Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torchvision.models as models | |
def load_model(model_path, num_classes=228):
    """Build a ResNet-50 multi-label classifier and load its weights from disk.

    Args:
        model_path: Path to a checkpoint readable by ``torch.load``. It is
            expected to contain a state dict (a mapping of parameter name to
            tensor). Keys may carry a leading ``"predictor."`` prefix, which is
            stripped; presumably this comes from a wrapper module used during
            training. TODO confirm.
        num_classes: Output size of the replacement ``fc`` head. Defaults to
            228, the value this loader previously hard-coded.

    Returns:
        The model with the checkpoint weights loaded, switched to eval mode.

    Raises:
        RuntimeError: Raised by ``load_state_dict`` (strict mode) when the
            checkpoint keys or tensor shapes do not match the architecture.
    """
    # Build the bare architecture without downloading pretrained weights;
    # every parameter is overwritten by the checkpoint below.
    model = models.resnet50(weights=None)
    # Replace the default classification head. Read the input width from the
    # existing layer instead of hard-coding 2048.
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load trusted checkpoints; on torch >= 1.13 consider passing
    # weights_only=True.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))

    prefix = "predictor."
    # Strip the prefix only at the start of each key. str.replace would also
    # delete any later occurrence of "predictor." inside the key.
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    model.load_state_dict(new_state_dict)
    # Inference mode: fixes BatchNorm statistics and disables dropout.
    model.eval()
    return model
def predict_attributes(model, input_tensor, *, threshold=0.5):
    """Return the indices of attributes whose sigmoid probability exceeds a threshold.

    Runs ``model`` on ``input_tensor`` without gradient tracking. Each output
    logit is treated as an independent binary attribute (multi-label), so a
    sigmoid is applied per logit.

    Args:
        model: Callable returning logits, either of shape ``(1, num_attrs)``
            (a single-sample batch) or ``(num_attrs,)``.
        input_tensor: Model input. It must describe exactly one sample; for a
            ResNet this is presumably ``(1, 3, H, W)``. Confirm against the
            caller.
        threshold: An attribute is selected when its probability is strictly
            greater than this value.

    Returns:
        Ascending list of ``int`` attribute indices. The list is empty when no
        attribute passes the threshold.

    Raises:
        ValueError: If the model output holds more than one sample.
    """
    with torch.no_grad():
        logits = model(input_tensor)
    # Move to CPU so this also works when the model lives on an accelerator.
    probs = torch.sigmoid(logits).cpu()

    if probs.dim() > 1:
        # The return value is one flat index list, so only a single sample can
        # be represented. Reject larger batches with a clear message.
        if probs.shape[0] != 1:
            raise ValueError(
                f"predict_attributes expects a single sample, got batch size {probs.shape[0]}"
            )
        probs = probs[0]

    # reshape(-1) guarantees a 1-D vector even for a single-logit output;
    # .squeeze() would collapse that case to a non-iterable 0-d scalar.
    probs = probs.reshape(-1)
    return (probs > threshold).nonzero(as_tuple=True)[0].tolist()