File size: 3,229 Bytes
02c783d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import json
from tqdm import tqdm
from loguru import logger
from dataclasses import dataclass, asdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from models.Base import BaseModel
from dataloaders.ProblemState import ProblemState
from memories.Memory import BaseMemory



class BaseAgent:
    """Base class for agents that process one memory per dataset problem.

    A ``BaseMemory`` is built for every entry of ``dataset.problem_states``.
    ``run`` then calls ``run_single_pass`` on each memory, either sequentially
    or from a thread pool. Afterwards it can persist the dataset and the
    memories.
    """

    def __init__(self, model: BaseModel, dataset):
        """Store the model and dataset and build one memory per problem state.

        Args:
            model: Model handle for concrete agents; this base class only
                stores it.
            dataset: Object exposing ``problem_states``, ``__len__`` and
                ``write_file``. These are the members this class touches.
        """
        self.model = model
        self.dataset = dataset
        self.memories = self.memory_init()

    def memory_init(self):
        """Return a list with one ``BaseMemory`` per dataset problem state."""
        return [BaseMemory(ps) for ps in self.dataset.problem_states]

    def run_single_pass(self, mem: BaseMemory, verbose=False, temperature=0):
        """Process a single memory.

        The base implementation does nothing; concrete agents are expected to
        override it. When ``run(multi_thread=True)`` is used, this is called
        concurrently from worker threads.
        """
        pass

    def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None,
            mem_path=None, temperature=0, thread_num=3):
        """Run ``run_single_pass`` over the first ``datalen`` memories.

        Args:
            output_path: If given, ``dataset.write_file`` is called with it
                after the pass.
            multi_thread: Use a thread pool instead of a sequential loop.
            verbose: Forwarded to every ``run_single_pass`` call.
            datalen: Number of memories to process. A falsy value (``None`` or
                ``0``) means ``len(self.dataset)``.
            mem_path: If given, all memories are written there as JSON lines.
            temperature: Forwarded to every ``run_single_pass`` call.
            thread_num: Worker count for the thread pool (multi-thread only).
        """
        data_len = datalen if datalen else len(self.dataset)
        targets = self.memories[:data_len]
        self._run_pass(targets, multi_thread, verbose, temperature, thread_num)

        if output_path is not None:
            self.dataset.write_file(output_path)

        if mem_path is not None:
            self.write_memories(mem_path)

    def _run_pass(self, targets, multi_thread, verbose, temperature, thread_num):
        """Call ``run_single_pass`` once per memory in ``targets``, with a progress bar.

        In sequential mode an exception propagates immediately. In thread-pool
        mode a failing item is logged and the remaining items still run.
        """
        # Size the bar from the actual slice: it can be shorter than datalen.
        with tqdm(total=len(targets)) as pbar:
            if not multi_thread:
                for mem in targets:
                    self.run_single_pass(mem, verbose, temperature=temperature)
                    pbar.update(1)
                return

            with ThreadPoolExecutor(max_workers=thread_num) as executor:
                # Pass temperature by keyword so it cannot land in the `verbose` slot.
                futures = {
                    executor.submit(self.run_single_pass, mem, verbose, temperature=temperature): idx
                    for idx, mem in enumerate(targets)
                }
                for future in as_completed(futures):
                    try:
                        future.result()
                    except Exception:
                        # One failing item should not abort the others, but its
                        # error must still be reported rather than lost.
                        logger.exception("run_single_pass failed for memory #{}", futures[future])
                    pbar.update(1)

    def write_memories(self, file_path):
        """Write every memory (not only the processed prefix) to ``file_path`` as JSON lines.

        Each memory is converted with ``dataclasses.asdict``. This requires
        every ``BaseMemory`` to be a dataclass instance whose fields are
        JSON-serialisable.
        """
        with open(file_path, "w", encoding="utf-8") as f:
            for mem in self.memories:
                f.write(json.dumps(asdict(mem)) + "\n")

class SequentialBaseAgent(BaseAgent):
    """Agent that repeats the full pass over the memories ``iteration_num`` times.

    After each iteration the dataset is written to a per-iteration file
    derived from ``output_path``: ``<root>_<i><ext>``, e.g. ``out.jsonl``
    becomes ``out_0.jsonl``, ``out_1.jsonl``, and so on.
    """

    def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None,
            iteration_num=0, temperature=0, thread_num=3):
        """Run ``iteration_num`` passes of ``run_single_pass`` over the first ``datalen`` memories.

        Args:
            output_path: Base path for the per-iteration dataset dumps. When
                ``None``, nothing is written.
            multi_thread: Use a thread pool instead of a sequential loop.
            verbose: Forwarded to every ``run_single_pass`` call.
            datalen: Number of memories to process. A falsy value (``None`` or
                ``0``) means ``len(self.dataset)``.
            iteration_num: Number of passes. The default ``0`` runs nothing.
            temperature: Forwarded to every ``run_single_pass`` call.
            thread_num: Worker count for the thread pool (multi-thread only).
        """
        data_len = datalen if datalen else len(self.dataset)
        targets = self.memories[:data_len]

        # Only derive the per-iteration path when there is a path to derive it
        # from; os.path.splitext(None) raises TypeError.
        root = extension = None
        if output_path is not None:
            root, extension = os.path.splitext(output_path)

        for iteration in range(iteration_num):
            logger.info(f"\n=== Iteration {iteration} ===")
            with tqdm(total=len(targets)) as pbar:
                if multi_thread:
                    with ThreadPoolExecutor(max_workers=thread_num) as executor:
                        # Pass temperature by keyword so it cannot land in the `verbose` slot.
                        futures = {
                            executor.submit(self.run_single_pass, mem, verbose, temperature=temperature): idx
                            for idx, mem in enumerate(targets)
                        }
                        for future in as_completed(futures):
                            try:
                                future.result()
                            except Exception:
                                # One failing item should not abort the others,
                                # but its error must still be reported.
                                logger.exception("run_single_pass failed for memory #{}", futures[future])
                            pbar.update(1)
                else:
                    for mem in targets:
                        self.run_single_pass(mem, verbose, temperature=temperature)
                        pbar.update(1)

            if output_path is not None:
                self.dataset.write_file(f"{root}_{iteration}{extension}")