Spaces:
Runtime error
Runtime error
| ''' | |
| Author: [egrt] | |
| Date: 2022-08-14 09:37:12 | |
| LastEditors: [egrt] | |
| LastEditTime: 2022-08-17 20:34:36 | |
| Description: | |
| ''' | |
| import numpy as np | |
| import pandas as pd | |
| import pickle | |
| from sklearn.preprocessing import LabelEncoder | |
def show_config(**kwargs):
    """Print the given keyword arguments as a two-column ASCII table.

    Every key and value is converted with ``str()`` and right-aligned in a
    25-character and a 40-character column respectively; longer text is not
    truncated. The table is framed by 70-character dashed rules.
    """
    rule = '-' * 70
    row_fmt = '|%25s | %40s|'
    lines = ['Configurations:', rule, row_fmt % ('keys', 'values'), rule]
    lines.extend(row_fmt % (str(key), str(val)) for key, val in kwargs.items())
    lines.append(rule)
    for line in lines:
        print(line)
#--------------------------------------------------------------------------#
# To predict with your own trained model, point model_path at the pickled
# model and train_path at the CSV the encoders/artist lookup should use.
#--------------------------------------------------------------------------#
class Classification(object):
    """Predict an integer label for one artwork with a pickled model.

    Categorical inputs are label-encoded against the training CSV before the
    feature row is passed to the loaded model's ``predict`` method.
    """

    _defaults = {
        #--------------------------------------------------------------------------#
        # model_path: pickled model file loaded by generate().
        # train_path: CSV read in __init__; detect_one() uses it to look up the
        #             artist ID and to fit the label encoders.
        #--------------------------------------------------------------------------#
        "model_path" : 'model_data/automl_v2.pkl',
        "train_path" : 'datasets/archive/artworks.csv',
        #-------------------------------#
        # Whether to use CUDA (set to False when no GPU is available).
        # NOTE(review): not read anywhere in this class — confirm it is needed.
        #-------------------------------#
        "cuda" : False
    }

    @classmethod
    def get_defaults(cls, n):
        """Return the default value of option ``n``.

        Works on the class and on instances. For an unknown option, it returns
        the message string "Unrecognized attribute name '<n>'" instead of
        raising.
        """
        if n in cls._defaults:
            return cls._defaults[n]
        return "Unrecognized attribute name '" + n + "'"

    def __init__(self, **kwargs):
        """Build the predictor.

        Keyword arguments override (or add to) the entries in ``_defaults``.
        Reads the CSV at ``train_path``, loads the model via ``generate()``,
        then prints the effective configuration.
        """
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
        # Kept for interface compatibility; not read within this class.
        self.num_classes = 1
        self.train_data = pd.read_csv(self.train_path)
        self.generate()
        # Show the configuration actually in effect (defaults + overrides),
        # not just the class-level defaults.
        show_config(**{**self._defaults, **kwargs})

    def generate(self):
        """Load the pickled model from ``self.model_path`` into ``self.automl``."""
        # pickle.load can execute arbitrary code: only point model_path at
        # trusted files.
        with open(self.model_path, "rb") as f:
            self.automl = pickle.load(f)

    @staticmethod
    def _encode(column, value):
        """Label-encode ``value`` using the classes that occur in ``column``.

        Returns a length-1 array. LabelEncoder.transform raises ValueError if
        ``value`` never occurs in ``column``.
        """
        encoder = LabelEncoder()
        encoder.fit(column)
        return encoder.transform([value])

    def detect_one(self, name, date, level, classification, height, width):
        """Predict the label for a single artwork.

        Args:
            name: artist name; must occur in the training CSV's "Name" column.
            date: value for the "Date" feature, passed through unchanged.
            level: encoded against the training "Department" column.
            classification: encoded against the training "Classification" column.
            height: value for "Height (cm)", passed through unchanged.
            width: value for "Width (cm)", passed through unchanged.

        Returns:
            ``int`` of the model's first prediction.

        Raises:
            KeyError: if ``name`` does not occur in the training data.
            ValueError: if ``level``/``classification`` were not seen in the
                training data (from LabelEncoder.transform).
        """
        train_data = self.train_data
        matches = train_data.loc[train_data["Name"] == name, "Artist ID"]
        if matches.empty:
            raise KeyError("Unknown artist name: %r" % (name,))
        # Positional access: the filtered Series keeps the original row labels,
        # so label-based ``[0]`` fails unless the match is row label 0.
        artist_id = matches.iloc[0]
        # Column order matches the feature order used when building the model input.
        test_data = pd.DataFrame({
            'Artist ID': artist_id,
            'Date': date,
            'Department': self._encode(train_data["Department"], level),
            'Classification': self._encode(train_data["Classification"], classification),
            "Height (cm)": height,
            "Width (cm)": width,
        })
        pred = self.automl.predict(test_data)
        return int(pred[0])
if __name__ == "__main__":
    # Smoke run: load the training CSV and model, predict one sample and
    # print the result (previously the prediction was silently discarded).
    classifier = Classification()
    prediction = classifier.detect_one("陈冠夫", 1975, '国家级', '中国山水画', 50, 50)
    print(prediction)