import torch
from citekit.prompt.prompt import Prompt
import re
from citekit.utils.utils import one_paragraph, first_sentence, make_as
import random
import os
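
# Module: the base building block of a citekit pipeline. It holds a Prompt maker
# and per-module prompt fields, attaches itself to a pipeline, routes its output
# to downstream modules via per-destination conditions and post-processing, and
# can optionally write its output back into the shared pipeline head.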
class Module:
    module_count = 1
    def __init__(self, prompt_maker: Prompt = None, pipeline=None, self_prompt={}, iterative=False, merge=False, max_turn=6, output_as=None, parallel=False) -> None:
        self.self_prompt = self_prompt
        self.use_head_prompt = True
        self.connect_to(pipeline)
        self.prompt_maker = prompt_maker
        self.last_message = ''
        self.destinations = []
        self.conditions = {}
        self.head_key = None
        self.parallel = parallel
        self.iterative = iterative
        self.merge = merge
        self.head_process = one_paragraph
        self.max_turn = max_turn
        self.multi_process = False
        self.output_cond = {}  # {cond: {'post_processing': post, 'end': end}}
        self.count = Module.module_count
        Module.module_count += 1
        self.if_add_output_to_head = False
        self.turns = 0
        self.end = False
    def __str__(self) -> str:
        # Subclasses set a class-level model_type (e.g. LLM.model_type = 'Generator').
        if getattr(self, 'model_type', None):
            return f'{self.model_type}-[{self.count}]'
        else:
            return f'Unknown-type module-[{self.count}]'
    def get_json_config(self, config):
        print('get_json_config:', config)
        available_mapping = {
            'max turn': 'max_turn',
            'prompt': 'prompt',
            'destination': 'destination',
            'global prompt': 'head_key',
        }
        if config == 'prompt':
            prompt_info = {
                'template': self.prompt_maker.template,
                'components': self.prompt_maker.components
            }
            self_info = self.self_prompt
            return {
                'prompt_info': prompt_info,
                'self_info': self_info
            }
        elif config == 'destination':
            return {
                'destination': str(self.destinations[0])
            }
        elif config in ['max turn', 'global prompt']:
            config = available_mapping[config]
            print('getting the config:', config)
            return getattr(self, config)
        else:
            raise NotImplementedError(f'get_json_config for {config} is not implemented')
    def get_destinations(self):
        return self.destinations
    def update(self, config, update_info):
        if config == 'prompt':
            template = update_info['template']
            components = update_info['components']
            self_prompt = update_info['self_prompt']
            import copy
            # avoid changing the original prompt_maker
            self.prompt_maker = copy.deepcopy(self.prompt_maker)
            self.prompt_maker.update(template=template, components=components)
            self.self_prompt = self_prompt
        elif config == 'destination':
            print('update destination:', update_info[0], 'post_processing:', update_info[1])
            if update_info[1] == 'None':
                self.set_target(update_info[0])
            else:
                self.set_target(update_info[0], post_processing=make_as(update_info[1]))
        elif config == 'delete_destination':
            for i, d in enumerate(self.destinations):
                if str(d) == str(update_info):
                    self.destinations.remove(d)
                    del self.conditions[d]
                    break
        elif config == 'header':
            self.add_to_head(update_info, sub=True)
        elif config == 'max turn':
            self.max_turn = update_info
        else:
            raise NotImplementedError(f'update for {config} is not implemented')
    def end_multi(self):
        return
    def set_use_head_prompt(self, use):
        assert isinstance(use, bool)
        self.use_head_prompt = use
    def reset(self):
        self.end = False
        self.turns = 0
    def change_to_multi_process(self, bool_value):
        if bool_value:
            self.last_message = []
        else:
            self.last_message = ''
        self.multi_process = bool_value
    def get_use_head_prompt(self):
        return self.use_head_prompt
    def generate(self, head_prompt: dict = {}, dynamic_prompt: dict = {}):
        raise NotImplementedError
    def send(self):
        for destination in self.destinations:
            cond = self.conditions[destination]['condition']
            if cond(self):
                return destination
        return None
    def set_target(self, destination, condition=lambda self: True, post_processing=lambda x: x) -> None:
        self.conditions[destination] = {'condition': condition, 'post_processing': post_processing}
        self.destinations = [destination] + self.destinations
        destination.connect_to(self.pipeline)
    def clear_destination(self):
        self.destinations = []
        self.conditions = {}
    def add_output_to_head(self, outputs):
        if self.if_add_output_to_head:
            if not self.head_sub:
                if self.head_key not in self.pipeline.head.keys():
                    self.pipeline.head.update({self.head_key: self.head_process(outputs)})
                else:
                    self.pipeline.head[self.head_key] += '\n'
                    self.pipeline.head[self.head_key] += self.head_process(outputs)
            else:
                self.pipeline.head[self.head_key] = self.head_process(outputs)
    def connect_to(self, pipeline=None) -> None:
        self.pipeline = pipeline
        if pipeline:
            pipeline.module.append(self)
    def output(self):
        outed = False
        for cond, post_and_end in self.output_cond.items():
            if cond(self):
                if not outed:
                    if not self.merge:
                        self.pipeline.output.append(post_and_end['post_processing'](self.last_message))
                    else:
                        self.pipeline.output.append(post_and_end['post_processing'](''.join(self.last_message)))
                    outed = True
                if post_and_end['end']:
                    self.end = True
    def set_output(self, cond=lambda self: True, post_processing=lambda x: x, end=True):
        self.output_cond[cond] = {'post_processing': post_processing, 'end': end}
    def get_first_module(self):
        return self
    def add_to_head(self, datakey, sub=False, process=None):
        self.if_add_output_to_head = True
        self.head_key = datakey
        self.head_sub = sub
        if process:
            self.head_process = process
def load_model(model_name_or_path, dtype=torch.float16):
    from transformers import AutoModelForCausalLM, AutoTokenizer
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=dtype,
        device_map='auto',
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model.eval()
    return model, tokenizer
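
# LLM: a Generator module. It drives either an OpenAI chat model (when the model
# name contains "gpt") or a local Hugging Face causal LM (optionally shared with
# another module via share_model_with), and can append citation markers taken
# from the prompt's document field when auto_cite is enabled.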
class LLM(Module):
    model_type = 'Generator'
    def __init__(self, model=None, prompt_maker: Prompt = None, pipeline=None, post_processing=None, self_prompt={}, device='cpu', temperature=0.5, stop=None, max_turn=6, share_model_with=None, iterative=False, auto_cite=False, output=None, merge=False, noisy=True, parallel=False, output_as='Answer', auto_cite_from='docs') -> None:
        super().__init__(prompt_maker, pipeline, self_prompt, iterative, merge, parallel=parallel)
        self.max_turn = max_turn
        if post_processing:
            self.post_processing = post_processing
        else:
            self.post_processing = lambda x: {output_as: x}
        if model:
            self.model_name = model
        self.stop = stop
        self.multi_process = False
        self.noisy = noisy
        self.head_process = one_paragraph
        self.auto_cite = auto_cite
        if auto_cite:
            self.cite_from = auto_cite_from
        if model:
            if 'gpt' not in model.lower():
                if not share_model_with:
                    print('loading model...')
                    self.model, self.tokenizer = self.load_model(model)
                else:
                    print('sharing model...')
                    self.model, self.tokenizer = share_model_with.model, share_model_with.tokenizer
                self.temperature = temperature
                self.device = device
            else:
                self.openai_key = os.getenv('OPENAI_API_KEY')
        self.output_cond = {}  # {cond: {'post_processing': post, 'end': end}}
        self.if_add_output_to_head = False
        self.token_used = 0
    def reset(self):
        self.end = False
        self.turns = 0
        self.token_used = 0
    def __str__(self) -> str:
        # model_name is only set when a model was given at construction time.
        if getattr(self, 'model_name', None):
            return f'{self.model_name}-[{self.count}]'
        else:
            return 'unknown model'
    def __repr__(self) -> str:
        return (f'{self.prompt_maker}\n|\n|\nV\n{self}\n|\n|\nV\n' + '/'.join([str(des) for des in self.destinations] + ['output']))
    def load_model(self, model_name_or_path, dtype=torch.float16):
        from transformers import AutoModelForCausalLM, AutoTokenizer
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=dtype,
            device_map='auto',
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        model.eval()
        return model, tokenizer
    def set_cite(self, key):
        self.cite_from = key
        self.auto_cite = True
    def generate_content(self, prompt):
        if 'gpt' in self.model_name.lower():
            import openai
            openai.api_key = self.openai_key
            prompt = [
                {'role': 'system',
                 'content': "You are a good helper who follows the instructions"},
                {'role': 'user', 'content': prompt}
            ]
            response = openai.ChatCompletion.create(
                model=self.model_name,
                messages=prompt,
                max_tokens=500,
                stop=self.stop
            )
            self.token_used += response['usage']['completion_tokens'] + response['usage']['prompt_tokens']
            return response['choices'][0]['message']['content']
        else:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            stop = [] if self.stop is None else self.stop  # note: stop strings are not applied to local HF generation
            outputs = self.model.generate(
                **inputs,
                do_sample=True,
                max_new_tokens=200,
                temperature=self.temperature
            )
            self.token_used += len(outputs[0])
            outputs = self.tokenizer.decode(outputs[0][inputs['input_ids'].size(1):], skip_special_tokens=True)
            return one_paragraph(outputs)
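    # generate(): assembles the prompt from the pipeline head prompt, the module's
    # self_prompt and the caller-supplied dynamic prompt, queries the model, then
    # optionally appends citation markers, records the output, writes it to the head,
    # and applies the post-processing of the chosen destination (or the default one).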
    def generate(self, head_prompt: dict = {}, dynamic_prompt: dict = {}):
        if self.use_head_prompt:
            prompt = self.prompt_maker(head_prompt, self.self_prompt, dynamic_prompt)
        else:
            prompt = self.prompt_maker(self.self_prompt, dynamic_prompt)
        if self.noisy:
            print(f'prompt to {str(self)}:\n', prompt, '\n\n')
        self.turns += 1
        outputs = self.generate_content(prompt)
        if self.noisy:
            print('OUTPUT:')
            print(outputs)
        if self.auto_cite:
            outputs = self.cite_from_prompt({**head_prompt, **self.self_prompt, **dynamic_prompt}, outputs)
        if self.multi_process:
            self.last_message.append(outputs)
        else:
            self.last_message = outputs
        self.add_output_to_head(outputs)
        destination = self.send()
        if self.turns > self.max_turn:
            self.end = True
        if destination in self.conditions:
            return self.conditions[destination]['post_processing'](outputs)
        else:
            return self.post_processing(outputs)
    def add_output_to_head(self, outputs):
        if self.if_add_output_to_head:
            if not self.head_sub:
                if self.head_key not in self.pipeline.head.keys():
                    self.pipeline.head.update({self.head_key: self.head_process(outputs)})
                else:
                    self.pipeline.head[self.head_key] += '\n'
                    self.pipeline.head[self.head_key] += self.head_process(outputs)
            else:
                self.pipeline.head[self.head_key] = self.head_process(outputs)
    def output(self):
        outed = False
        for cond, post_and_end in self.output_cond.items():
            if cond(self):
                if not outed:
                    if not self.merge and not self.iterative:
                        self.pipeline.output.append(post_and_end['post_processing'](self.last_message))
                    else:
                        self.pipeline.output.append(post_and_end['post_processing'](' '.join(self.last_message)))
                    outed = True
                if post_and_end['end']:
                    self.end = True
    def set_output(self, cond=lambda self: True, post_processing=lambda x: x, end=True):
        self.output_cond[cond] = {'post_processing': post_processing, 'end': end}
    def cite_from_prompt(self, prompt_dict, input):
        input = first_sentence(input)
        cite_docs = prompt_dict[self.cite_from]
        refs = re.findall(r'\[\d+\]', cite_docs)
        pattern = r'([.!?])\s*$'
        if refs:
            cite = ''.join(refs)
        else:
            cite = ''
        output = re.sub(pattern, rf' {cite}\1 ', input)
        if cite not in output:
            output += cite
        return output
    def add_to_head(self, datakey, sub=False, process=None):
        self.if_add_output_to_head = True
        self.head_key = datakey
        self.head_sub = sub
        if process:
            self.head_process = process
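
# TestLLM: a stub generator for testing pipeline wiring; generate_content ignores
# the prompt and returns a fixed (or user-supplied) answer string.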
class TestLLM(LLM):
    def __init__(self, model='gpt-4', prompt_maker: Prompt = None, pipeline=None, post_processing=lambda x: x, self_prompt={}, device='cpu', temperature=0.5, stop=None, max_turn=6, share_model_with=None, iterative=False, ans=None) -> None:
        super().__init__(model, prompt_maker, pipeline, self_prompt=self_prompt, share_model_with=share_model_with, iterative=iterative)
        self.max_turn = max_turn
        self.post_processing = post_processing
        self.model_name = model
        self.last_message = ''
        self.stop = stop
        self.output_cond = {}  # {cond: {'post_processing': post, 'end': end}}
        self.if_add_output_to_head = False
        self.token_used = 0
        self.ans = 'Strain[1], turns:, heat[2][4]. Sent2[5]. Sent3.\n\n rdd' if not ans else ans
    def generate_content(self, prompt):
        return self.ans
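
# AutoAISLLM: an entailment judge (AutoAIS-style attribution check). Given a
# premise and a claim, it prompts the underlying model to answer with a single
# digit indicating whether the premise entails the claim.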
class AutoAISLLM(LLM):
    def __init__(self, model=None, prompt_maker: Prompt = None, pipeline=None, post_processing=None, self_prompt={}, device='cpu', temperature=0.5, stop=None, max_turn=6, share_model_with=None, iterative=False, auto_cite=False, output=None, merge=False, noisy=False, output_as='Answer') -> None:
        # pass output_as by keyword so it does not land on the parallel parameter
        super().__init__(model, prompt_maker, pipeline, post_processing, self_prompt, device, temperature, stop, max_turn, share_model_with, iterative, auto_cite, output, merge, noisy, output_as=output_as)
        self.prompt_maker = Prompt('<INST><premise><claim>\n Answer: ', components={
            'INST': '{INST}\n\n',
            'premise': 'Premise: {premise}\n\n',
            'claim': 'Claim: {claim}\n',
        })
        self.self_prompt = {'INST': 'In this task, you will be presented a premise and a claim. If the premise entails the claim, output "1", otherwise output "0". Your answer should only contain one number without any other letters or punctuation.'}
    def generate(self, premise, claim):
        dict_answer = super().generate({'premise': premise, 'claim': claim})
        return dict_answer.get('Answer')
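
# Minimal smoke test: builds a prompt template and constructs an LLM module
# without attaching it to a pipeline.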
if __name__ == '__main__':
    prompt = Prompt(template='<INST><Question><Docs><feedback><Answer>',
                    components={'INST': '{INST}\n\n',
                                'Question': 'Question:{Question}\n\n',
                                'Docs': '{Docs}\n',
                                'feedback': 'Here is the feedback of your last response:{feedback}\n',
                                'Answer': 'Here is the answer and you have to give feedback:{Answer}'})
    m = LLM('gpt')