Spaces:
Paused
Paused
| import copy | |
| import torch | |
| from evaluate_params import eval_func_param_names | |
| from gen import get_score_model, get_model, evaluate, check_locals | |
| from prompter import non_hf_types | |
| from utils import clear_torch_cache, NullContext, get_kwargs | |
| def run_cli( # for local function: | |
| base_model=None, lora_weights=None, inference_server=None, | |
| debug=None, | |
| examples=None, memory_restriction_level=None, | |
| # for get_model: | |
| score_model=None, load_8bit=None, load_4bit=None, low_bit_mode=None, load_half=None, | |
| load_gptq=None, load_exllama=None, use_safetensors=None, revision=None, | |
| use_gpu_id=None, tokenizer_base_model=None, | |
| gpu_id=None, n_jobs=None, local_files_only=None, resume_download=None, use_auth_token=None, | |
| trust_remote_code=None, offload_folder=None, rope_scaling=None, max_seq_len=None, compile_model=None, | |
| llamacpp_dict=None, | |
| # for some evaluate args | |
| stream_output=None, async_output=None, num_async=None, | |
| prompt_type=None, prompt_dict=None, system_prompt=None, | |
| temperature=None, top_p=None, top_k=None, num_beams=None, | |
| max_new_tokens=None, min_new_tokens=None, early_stopping=None, max_time=None, repetition_penalty=None, | |
| num_return_sequences=None, do_sample=None, chat=None, | |
| langchain_mode=None, langchain_action=None, langchain_agents=None, | |
| document_subset=None, document_choice=None, | |
| top_k_docs=None, chunk=None, chunk_size=None, | |
| pre_prompt_query=None, prompt_query=None, | |
| pre_prompt_summary=None, prompt_summary=None, | |
| image_loaders=None, | |
| pdf_loaders=None, | |
| url_loaders=None, | |
| jq_schema=None, | |
| visible_models=None, | |
| h2ogpt_key=None, | |
| add_search_to_context=None, | |
| chat_conversation=None, | |
| text_context_list=None, | |
| docs_ordering_type=None, | |
| min_max_new_tokens=None, | |
| # for evaluate kwargs | |
| captions_model=None, | |
| caption_loader=None, | |
| doctr_loader=None, | |
| pix2struct_loader=None, | |
| image_loaders_options0=None, | |
| pdf_loaders_options0=None, | |
| url_loaders_options0=None, | |
| jq_schema0=None, | |
| keep_sources_in_context=None, | |
| src_lang=None, tgt_lang=None, concurrency_count=None, save_dir=None, sanitize_bot_response=None, | |
| model_state0=None, | |
| max_max_new_tokens=None, | |
| is_public=None, | |
| max_max_time=None, | |
| raise_generate_gpu_exceptions=None, load_db_if_exists=None, use_llm_if_no_docs=None, | |
| my_db_state0=None, selection_docs_state0=None, dbs=None, langchain_modes=None, langchain_mode_paths=None, | |
| detect_user_path_changes_every_query=None, | |
| use_openai_embedding=None, use_openai_model=None, | |
| hf_embedding_model=None, migrate_embedding_model=None, auto_migrate_db=None, | |
| cut_distance=None, | |
| answer_with_sources=None, | |
| append_sources_to_answer=None, | |
| show_accordions=None, | |
| top_k_docs_max_show=None, | |
| show_link_in_sources=None, | |
| add_chat_history_to_context=None, | |
| context=None, iinput=None, | |
| db_type=None, first_para=None, text_limit=None, verbose=None, cli=None, | |
| use_cache=None, | |
| auto_reduce_chunks=None, max_chunks=None, headsize=None, | |
| model_lock=None, force_langchain_evaluate=None, | |
| model_state_none=None, | |
| # unique to this function: | |
| cli_loop=None, | |
| ): | |
| # avoid noisy command line outputs | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
| import logging | |
| logging.getLogger("torch").setLevel(logging.ERROR) | |
| logging.getLogger("transformers").setLevel(logging.ERROR) | |
| check_locals(**locals()) | |
| score_model = "" # FIXME: For now, so user doesn't have to pass | |
| n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0 | |
| device = 'cpu' if n_gpus == 0 else 'cuda' | |
| context_class = NullContext if n_gpus > 1 or n_gpus == 0 else torch.device | |
| with context_class(device): | |
| from functools import partial | |
| # get score model | |
| smodel, stokenizer, sdevice = get_score_model(reward_type=True, | |
| **get_kwargs(get_score_model, exclude_names=['reward_type'], | |
| **locals())) | |
| model, tokenizer, device = get_model(reward_type=False, | |
| **get_kwargs(get_model, exclude_names=['reward_type'], **locals())) | |
| model_dict = dict(base_model=base_model, tokenizer_base_model=tokenizer_base_model, lora_weights=lora_weights, | |
| inference_server=inference_server, prompt_type=prompt_type, prompt_dict=prompt_dict, | |
| visible_models=None, h2ogpt_key=None) | |
| model_state = dict(model=model, tokenizer=tokenizer, device=device) | |
| model_state.update(model_dict) | |
| requests_state0 = {} | |
| fun = partial(evaluate, model_state, my_db_state0, selection_docs_state0, requests_state0, | |
| **get_kwargs(evaluate, exclude_names=['model_state', | |
| 'my_db_state', | |
| 'selection_docs_state', | |
| 'requests_state'] + eval_func_param_names, | |
| **locals())) | |
| example1 = examples[-1] # pick reference example | |
| all_generations = [] | |
| if not context: | |
| context = '' | |
| while True: | |
| clear_torch_cache() | |
| instruction = input("\nEnter an instruction: ") | |
| if instruction == "exit": | |
| break | |
| eval_vars = copy.deepcopy(example1) | |
| eval_vars[eval_func_param_names.index('instruction')] = \ | |
| eval_vars[eval_func_param_names.index('instruction_nochat')] = instruction | |
| eval_vars[eval_func_param_names.index('iinput')] = \ | |
| eval_vars[eval_func_param_names.index('iinput_nochat')] = iinput | |
| eval_vars[eval_func_param_names.index('context')] = context | |
| # grab other parameters, like langchain_mode | |
| for k in eval_func_param_names: | |
| if k in locals(): | |
| eval_vars[eval_func_param_names.index(k)] = locals()[k] | |
| gener = fun(*tuple(eval_vars)) | |
| outr = '' | |
| res_old = '' | |
| for gen_output in gener: | |
| res = gen_output['response'] | |
| extra = gen_output['sources'] | |
| if base_model not in non_hf_types or base_model in ['llama']: | |
| if not stream_output: | |
| print(res) | |
| else: | |
| # then stream output for gradio that has full output each generation, so need here to show only new chars | |
| diff = res[len(res_old):] | |
| print(diff, end='', flush=True) | |
| res_old = res | |
| outr = res # don't accumulate | |
| else: | |
| outr += res # just is one thing | |
| if extra: | |
| # show sources at end after model itself had streamed to std rest of response | |
| print('\n\n' + extra, flush=True) | |
| all_generations.append(outr + '\n') | |
| if not cli_loop: | |
| break | |
| return all_generations | |