Spaces:
Sleeping
Sleeping
| import os | |
| import pandas as pd | |
| import torch | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| from model import DistMult | |
| from PIL import Image | |
| from torchvision import transforms | |
| import json | |
| from tqdm import tqdm | |
# Per-channel statistics passed to transforms.Normalize for input images.
# The values match the standard ImageNet RGB mean / std used by torchvision.
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [
    0.485,
    0.456,
    0.406,
]
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [
    0.229,
    0.224,
    0.225,
]
def generate_target_list(data, entity2id):
    """Collect the distinct entity indices that appear as targets of image triples.

    Args:
        data: DataFrame with columns ``datatype_h``, ``datatype_t`` and ``t``.
            Only rows whose head type is ``"image"`` and tail type is ``"id"``
            are considered; their ``t`` values are the target keys.
        entity2id: mapping from entity key (string) to integer entity index.

    Returns:
        A ``torch.long`` tensor of shape ``(num_categories, 1)`` holding each
        distinct entity index once, in order of first appearance.

    Raises:
        KeyError: if a target key has no entry in ``entity2id``.
        ValueError: if a target value is not numeric.
    """
    mask = (data["datatype_h"] == "image") & (data["datatype_t"] == "id")
    tails = data.loc[mask, "t"]
    # int(float(...)) normalises values such as "5.0" / 5.0 to the key "5".
    # NOTE(review): presumably the column holds mixed str/float values; confirm.
    indices = (entity2id[str(int(float(item)))] for item in tails)
    # dict.fromkeys de-duplicates while keeping first-seen order in O(n);
    # a list-membership test here would be O(n^2) over the dataset rows.
    categories = list(dict.fromkeys(indices))
    return torch.tensor(categories, dtype=torch.long).unsqueeze(-1)
# Load the lookup tables and dataset, then build the model and restore its
# weights. All paths are relative to the current working directory.
with open('entity2id_subtree.json', 'r', encoding='utf-8') as fh:
    entity2id = json.load(fh)
# Reverse lookup: entity index -> entity key as stored in entity2id.
id2entity = {v: k for k, v in entity2id.items()}
datacsv = pd.read_csv('dataset_subtree.csv', low_memory=False)
num_ent_id = len(entity2id)
# Distinct target entity indices; the model's output column i corresponds
# to target_list[i].
target_list = generate_target_list(datacsv, entity2id)
with open('overall_id_to_name.json', 'r', encoding='utf-8') as fh:
    overall_id_to_name = json.load(fh)

# Build the model on CPU and switch it to inference mode.
model = DistMult(num_ent_id, target_list, torch.device('cpu'))
model.eval()
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
ckpt = torch.load('species_class_model.pt', map_location=torch.device('cpu'))
# strict=False tolerates key mismatches between checkpoint and model; report
# them so a partially-restored model does not go unnoticed.
load_result = model.load_state_dict(ckpt['model'], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print('state_dict mismatch: missing_keys={}, unexpected_keys={}'.format(
        load_result.missing_keys, load_result.unexpected_keys))
print('ckpt loaded...')
# Evaluation function used by the Gradio interface.

# Preprocessing pipeline for uploaded images; built once at import time
# instead of on every request (none of these transforms is random).
_EVAL_TRANSFORM = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((448, 448)),
    transforms.ToTensor(),
    transforms.Normalize(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN, _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD),
])

# Relation index fed to the model together with the image.
# NOTE(review): presumably the (image -> id) relation of the training vocabulary; confirm.
_IMAGE_TO_ID_RELATION = 3

# Number of candidate labels returned (clamped to the number of targets).
_TOP_K = 5


def evaluate(img):
    """Classify one image and return its most probable species labels.

    Args:
        img: image as delivered by the Gradio ``Image`` input; it must be
            accepted by ``transforms.ToPILImage`` (presumably an HxWxC uint8
            array — TODO confirm against the installed Gradio version).

    Returns:
        dict mapping species name -> softmax probability for the top
        ``min(5, number of targets)`` predictions.
    """
    h = _EVAL_TRANSFORM(img).unsqueeze(0)  # add a batch dimension
    r = torch.tensor([_IMAGE_TO_ID_RELATION]).unsqueeze(0)

    # Inference only: skip autograd graph construction to save memory/time.
    with torch.no_grad():
        outputs = F.softmax(model.forward_ce(h, r, triple_type=('image', 'id')), dim=-1)

    # torch.topk raises if k exceeds the size of the selected dimension.
    k = min(_TOP_K, outputs.size(-1))
    predictions = torch.topk(outputs, k=k, dim=-1).indices.squeeze(0).tolist()

    result = {}
    for i in predictions:
        # Output column i corresponds to target_list[i].
        pred_label = target_list[i].item()
        label = overall_id_to_name[str(id2entity[pred_label])]
        result[label] = outputs[0, i].item()
    return result
# Gradio interface.
# NOTE(review): gr.inputs.* is the Gradio 3.x API (deprecated, removed in 4.x).
species_model = gr.Interface(
    evaluate,
    gr.inputs.Image(shape=(200, 200)),
    outputs="label",
    title='Camera Trap Species Classification demo',
)

if __name__ == "__main__":
    # Launch only when run as a script, so importing this module does not
    # start (and block on) a server. GRADIO_SERVER_PORT is the variable
    # Gradio itself consults for the port; fall back to the original 8977.
    species_model.launch(
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", 8977)),
        share=True,
        debug=True,
    )