| | from argparse import ArgumentParser |
| |
|
| | from lagent.llms import HFTransformer |
| | from lagent.llms.meta_template import INTERNLM2_META as META |
| |
|
| |
|
def parse_args(argv=None):
    """Parse command-line arguments for the chatbot demo.

    Args:
        argv (list[str] | None): Argument list to parse. ``None`` (the
            default) falls back to ``sys.argv[1:]``, preserving the
            original call-site behavior ``parse_args()``.

    Returns:
        argparse.Namespace: Parsed arguments with ``path`` (model path)
        and ``mode`` (``'chat'`` or ``'generate'``).
    """
    parser = ArgumentParser(description='chatbot')
    parser.add_argument(
        '--path',
        type=str,
        default='internlm/internlm2-chat-20b',
        help='The path to the model')
    parser.add_argument(
        '--mode',
        type=str,
        default='chat',
        # Validate up front: the help text promises exactly these two
        # modes, and any other value would silently behave like 'chat'.
        choices=('chat', 'generate'),
        help='Completion through chat or generate')
    args = parser.parse_args(argv)
    return args
| |
|
| |
|
def main():
    """Run an interactive streaming chat loop against an InternLM2 model.

    Reads multi-line prompts from stdin until the user types ``exit``.
    In ``chat`` mode the full conversation history is sent each turn;
    in ``generate`` mode only the current prompt is sent (stateless
    completion).
    """
    args = parse_args()
    model = HFTransformer(
        path=args.path,
        meta_template=META,
        max_new_tokens=1024,
        top_p=0.8,
        top_k=None,
        temperature=0.1,
        repetition_penalty=1.0,
        stop_words=['<|im_end|>'])

    def input_prompt():
        """Collect multi-line user input; an empty line ends the input."""
        print('\ndouble enter to end input >>> ', end='', flush=True)
        sentinel = ''
        # iter(input, sentinel) keeps calling input() until it returns ''.
        return '\n'.join(iter(input, sentinel))

    history = []
    while True:
        try:
            prompt = input_prompt()
        except UnicodeDecodeError:
            # Undecodable terminal input: report and re-prompt rather
            # than crash the loop.
            print('UnicodeDecodeError')
            continue
        if prompt == 'exit':
            # Plain return instead of exit(0): exit() is injected by the
            # `site` module (not guaranteed present) and raises
            # SystemExit even when main() is called programmatically.
            return
        history.append(dict(role='user', content=prompt))
        if args.mode == 'generate':
            # Stateless completion: discard prior turns, keep only the
            # current prompt.
            history = [dict(role='user', content=prompt)]
        print('\nInternLm2:', end='')
        current_length = 0
        # Stream the reply; each step prints only the newly generated
        # suffix so the output appears incrementally.
        for _, response, _ in model.stream_chat(history):
            print(response[current_length:], end='', flush=True)
            current_length = len(response)
        history.append(dict(role='assistant', content=response))
        print('')
| |
|
| |
|
# Script entry point: start the interactive chat loop only when this
# file is executed directly, not when imported as a module.
if __name__ == '__main__':
    main()
| |
|