class SearchMood:
    """Score mood prompts against a shelved prompt database by embedding similarity.

    On construction this loads a SentenceTransformer model, reads the
    ``'sample'`` entry of ``cx_sample.db`` into memory and keeps
    ``database.db`` open; its ``'database'`` entry is indexed like a list by
    the methods below.  Call :meth:`close` when done to release that shelf.
    """

    def __init__(self, mood_prompt, prior_init):
        """Load the model and data and draw the initial prior samples.

        Args:
            mood_prompt: Text that ``loss_fn`` compares against database entries.
            prior_init: Stored as ``self.prior_init``; not read elsewhere in
                this class.
        """
        self.prior_init = prior_init
        self.mood_prompt = mood_prompt
        self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        # Encode both texts as torch tensors so they can be fed to pytorch_cos_sim.
        self.embeding = lambda mood_prompt, mood_state: (
            self.model.encode(mood_prompt, convert_to_tensor=True),
            self.model.encode(mood_state, convert_to_tensor=True),
        )
        self.similar = lambda similarx, similary: util.pytorch_cos_sim(similarx, similary)
        # Read the sample eagerly and release the shelf right away (it was
        # previously left open for the lifetime of the process).
        with shelve.open('cx_sample.db') as cx_shelf:
            self.cx_sample = cx_shelf['sample']
        # Kept open because mood_dist/loss_fn read from it; see close().
        self.database = shelve.open('database.db')
        # [mean, std] of the normal distribution the prior samples are drawn
        # from.  Per-instance: it used to be assigned on the class, so every
        # new instance reset a tensor that all instances shared.
        self.prior_component = torch.tensor([0., 1.])
        self.prior_sample = torch.normal(
            self.prior_component[0], self.prior_component[1], size=(5,)
        )
        self.sample_losses = None

    def close(self):
        """Close the ``database.db`` shelf opened in ``__init__``."""
        self.database.close()

    def embedings(self, samplex, sampley):
        """Return the similarity between two texts.

        Both texts are embedded with ``self.embeding`` and compared with
        ``self.similar`` (``util.pytorch_cos_sim`` as configured in ``__init__``).
        """
        emb = self.embeding(samplex, sampley)
        return self.similar(emb[0], emb[1])

    def mood_dist(self, data_sample=False, mood_prompt=False, search=True):
        """Find best database matches, or score a single text pair.

        Args:
            data_sample: Text scored against ``mood_prompt`` when ``search`` is false.
            mood_prompt: Text scored against ``data_sample`` when ``search`` is false.
            search: If true, for each entry of ``self.cx_sample`` find the
                index of the most similar entry in ``self.database['database']``.

        Returns:
            When searching, a tensor of float indices (one per ``cx_sample``
            entry); otherwise a tensor holding the single similarity value.
        """
        cx_index = []
        if search:
            # Load once: each shelf lookup unpickles the stored value again.
            prompts = self.database['database']
            for mood_state in self.cx_sample:
                # Start at -inf so the true argmax wins even when every
                # similarity is negative (cosine similarity can be).
                best_score = float('-inf')
                best_index = 0
                for index, candidate in enumerate(prompts):
                    score = self.embedings(mood_state, candidate)
                    if best_score < score:
                        best_score = score
                        best_index = index
                cx_index.append(float(best_index))
        else:
            cx_index.append(self.embedings(mood_prompt, data_sample))
        return torch.tensor(cx_index)

    def loss_fn(self):
        """Score database entries selected by the prior samples against ``self.mood_prompt``.

        Each value in ``self.prior_sample`` is rounded to an index into
        ``self.database['database']``, and the selected entry's similarity to
        ``self.mood_prompt`` is printed.  Iteration stops early once a
        similarity reaches 1.

        Returns:
            A tensor holding the negated similarity of the last entry scored.
            NOTE(review): it is built with ``torch.tensor`` and so carries no
            autograd history.
        """
        prompts = self.database['database']
        last_index = len(prompts) - 1
        for sample in self.prior_sample:
            # Normal draws can be negative or exceed the database size; clamp
            # instead of wrapping around (negative index) or raising IndexError.
            index = max(0, min(round(sample.item()), last_index))
            data_sample = prompts[index]
            samp_loss = self.mood_dist(data_sample, self.mood_prompt, search=False)
            print(samp_loss)
            if samp_loss.item() >= 1.:
                print('test')
                break
        return torch.tensor([samp_loss * -1])

    def search_compose(self):
        """Run 100 Adagrad steps on ``self.prior_component`` using ``loss_fn``.

        NOTE(review): ``prior_component`` is created without ``requires_grad``
        and ``loss_fn`` returns a tensor with no autograd history.  As a
        result no gradients reach the parameters, and each step leaves them
        unchanged.  ``prior_sample`` is also never redrawn here.
        """
        # The params are views into prior_component, which the optimizer
        # updates in place, so no copy-back is needed.  (The old copy-back
        # read state_dict()['param_groups'][0]['params'], which holds integer
        # parameter ids, not values.)  The optimizer is built once so that
        # Adagrad's accumulated state persists across steps.
        optimizer = optim.Adagrad((self.prior_component[0], self.prior_component[1]))
        for _ in range(100):
            optimizer.step(closure=self.loss_fn)