# Copyright 2023-2024 The SapientML Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import threading
from contextvars import ContextVar
from pathlib import Path
from uuid import UUID

from sapientml import SapientML

from .utils import convert_int64


class GenerateCodeThread(threading.Thread):
    """Worker thread that runs ``SapientML.fit`` and exports JSON artifacts.

    ``run`` never propagates an exception: any failure is stored on the
    instance and can be retrieved afterwards with ``get_exception``.
    """

    # Keys of ``config`` that are forwarded to ``SapientML.fit`` as keyword arguments.
    _FIT_ARGS = ("save_datasets_format", "csv_encoding", "ignore_columns", "output_dir", "test_data")

    # Upper bound on the number of per-candidate explainability files written.
    _MAX_CANDIDATE_EXPORTS = 3

    def __init__(
        self,
        sml: SapientML,
        config: dict,
        log_handler: logging.Handler,
        ctx_uuid: ContextVar[UUID],
        uuid: UUID,
    ):
        """Store the collaborators needed by ``run``.

        Args:
            sml: The SapientML instance whose ``fit`` is invoked by ``run``.
            config: Settings dict. Must contain ``"training_dataframe"`` and
                ``"output_dir"``; the keys listed in ``_FIT_ARGS`` are passed
                on to ``fit`` when present.
            log_handler: Logging handler kept on the instance (not used by
                this class itself).
            ctx_uuid: Context variable that ``run`` sets to ``uuid``.
            uuid: Identifier stored into ``ctx_uuid`` when the thread starts.
        """
        super().__init__()
        self.sml = sml
        self.config = config
        self.result = None
        self.exception = None
        self.log_handler = log_handler
        self.ctx_uuid = ctx_uuid
        self.uuid = uuid

    @staticmethod
    def _dump_json(path: Path, payload) -> None:
        """Write *payload* to *path* as indented JSON without ASCII escaping."""
        with open(path, "w") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)

    def run(self):
        """Fit the model, then write explainability and skeleton JSON files.

        Files that already exist in the output directory are not regenerated.
        Any exception is captured in ``self.exception`` instead of being raised.
        """
        try:
            self.ctx_uuid.set(self.uuid)

            fit_kwargs = {k: v for k, v in self.config.items() if k in self._FIT_ARGS}
            self.sml.fit(self.config["training_dataframe"], **fit_kwargs)

            # Accept both ``str`` and ``Path`` values for the output directory.
            output_dir = Path(self.config["output_dir"])
            generator = self.sml.generator

            if not (output_dir / "final_script_code_explainability.json").exists():
                self._dump_json(
                    output_dir / "script_code_explainability.json",
                    generator._best_pipeline.pipeline_json,
                )

                # Each candidate entry is indexable; its first item carries
                # ``pipeline_json``. Export at most the first three, and do
                # not fail when fewer candidates were produced.
                elements = [t[0] for t in generator._candidate_scripts]
                for rank, element in enumerate(elements[: self._MAX_CANDIDATE_EXPORTS], start=1):
                    self._dump_json(
                        output_dir / f"{rank}_script_code_explainability.json",
                        element.pipeline_json,
                    )

            skeleton_path = output_dir / ".skeleton.json"
            if not skeleton_path.exists():
                self._dump_json(skeleton_path, convert_int64(generator._best_pipeline.labels))
        except Exception as e:
            # Deliberately broad: the creator of the thread inspects the
            # failure through ``get_exception`` after the thread finishes.
            self.exception = e

    def get_result(self):
        """Return ``self.result`` (``None`` unless assigned from outside this class)."""
        return self.result

    def get_exception(self):
        """Return the exception captured by ``run``, or ``None`` if there was none."""
        return self.exception

    def get_sml(self):
        """Return the SapientML instance given to the constructor."""
        return self.sml

    def trigger_cancel(self):
        """Mark the attached cancel token, if any, as triggered.

        This class never assigns ``cancel_token`` itself, so the attribute is
        looked up defensively; the call is a no-op when no token was attached.
        """
        token = getattr(self, "cancel_token", None)
        if token is not None:
            token.isTriggered = True