Spaces:
Build error
Build error
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
class BaseTrainer:
    """Shared bookkeeping for trainers: options, run folder, logging, config dump.

    The class stores the option dict, picks a fresh ``run_<n>`` output folder,
    appends log lines to a file inside that folder and can dump the options as
    a TSV file. ``train`` and ``load`` are no-op placeholders (presumably
    overridden by concrete trainers).
    """

    def __init__(self, opt):
        """Store *opt* and derive the basic run flags from it.

        Args:
            opt: Mutable mapping of options. Must contain the key ``"cuda"``.
                The presence of the key ``"OFFICIAL"`` (whatever its value)
                switches logging off. The entry ``opt["logFile"]`` is
                overwritten with ``"log.txt"``.

        Raises:
            KeyError: If ``opt`` has no ``"cuda"`` entry.
        """
        self.opt = opt
        # Equality (not truthiness) is kept on purpose: only values that compare
        # equal to True (True, 1) enable CUDA; e.g. a non-empty string does not.
        self.use_cuda = bool(self.opt["cuda"] == True)  # noqa: E712
        print("Using Cuda\n" if self.use_cuda else "Using CPU\n")

        self.is_official = "OFFICIAL" in self.opt
        self.opt["logFile"] = "log.txt"
        self.saveFolder = None  # set by getSaveFolder()
        self.logFileHandle = None  # opened lazily by log()
        self.tb_writer = None

    def log(self, s):
        """Print *s* and append it to the log file of the current run folder.

        Nothing is printed or written when ``self.is_official`` is set. While
        no save folder has been chosen yet, the message only goes to stdout.
        File problems are reported on stdout and never raised, so logging
        cannot abort the caller.

        Args:
            s: The message; a newline is appended when writing to the file.
        """
        # In official case, the program does not output logs
        if self.is_official:
            return

        try:
            if self.logFileHandle is None and self.saveFolder is not None:
                # UTF-8 keeps the log consistent with saveConf() and makes
                # non-ASCII messages independent of the platform's locale.
                self.logFileHandle = open(
                    os.path.join(self.saveFolder, self.opt["logFile"]),
                    "a",
                    encoding="utf-8",
                )
            if self.logFileHandle is not None:
                self.logFileHandle.write(s + "\n")
                # The handle stays open for the object's lifetime, so flush
                # now; otherwise buffered lines are lost on an abnormal exit.
                self.logFileHandle.flush()
        except Exception as e:  # deliberate best-effort: never fail the caller
            print("ERROR while writing log file:", e)
        print(s)

    def getSaveFolder(self):
        """Create the first unused ``run_<n>`` folder and store it in ``self.saveFolder``.

        Folders are looked for under ``<datadir>/<basename>_conf~/`` with
        ``n`` counting up from 1. The chosen path is announced on stdout.

        Raises:
            KeyError: If ``opt`` lacks ``"datadir"`` or ``"basename"``.
            OSError: If the folder cannot be created for a reason other than
                already existing.
        """
        parent = os.path.join(self.opt["datadir"], self.opt["basename"] + "_conf~")
        runid = 1
        while True:
            candidate = os.path.join(parent, "run_" + str(runid))
            try:
                # Creating the directory IS the existence test: a separate
                # os.path.exists() check would let two concurrent runs pick
                # the same folder.
                os.makedirs(candidate)
            except FileExistsError:
                runid += 1
                continue
            self.saveFolder = candidate
            print("Saving logs, model and evaluation in " + self.saveFolder)
            return

    def saveConf(self):
        """Save a copy of the options as ``conf_copy.tsv`` in the save folder.

        Each option is written as one ``key<TAB>value`` line, UTF-8 encoded.
        Requires ``self.saveFolder`` to be set (see ``getSaveFolder``).
        """
        target = os.path.join(self.saveFolder, "conf_copy.tsv")
        with open(target, "w", encoding="utf-8") as fw:
            for key, value in self.opt.items():
                fw.write("{0}\t{1}\n".format(key, value))

    def train(self):
        """Placeholder; does nothing in the base class."""

    def load(self):
        """Placeholder; does nothing in the base class."""