import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel


model_name = "tiiuae/falcon-7b"
model_id = "personachat-finetuned-3000-steps"
with open("template.txt", "r") as f:
    template = f.read()
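# template.txt must expose {personality} and {history} fields (see the
# .format calls below). A hypothetical sketch of its contents, for
# illustration only -- the real file ships with the project:
#
#   Persona:
#   {personality}
#   Conversation:
#   {history}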
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
)
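# Load the base model in 8-bit (requires the bitsandbytes package); with
# device_map="auto", accelerate places the weights across available devices.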
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    load_in_8bit=True,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
)
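# Attach the LoRA adapter trained on PersonaChat. Note that PEFT injects the
# adapter layers into base_model itself, so the two handles share weights.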
tuned_model = PeftModel.from_pretrained(
    base_model,
    model_id,
)


def parse_response(encoded_output, prompt):
    """Decode a generation, strip the prompt, and cut at Falcon's EOS marker."""
    decoded_output = tokenizer.batch_decode(encoded_output)[0]
    decoded_output = decoded_output.replace(prompt, "")
    return decoded_output.split("<|endoftext|>", 1)[0].strip()


def generate(personality, user_input, state=None):
    # Gradio's legacy "state" input passes None on a session's first call;
    # initialize the per-session chat histories here instead of relying on a
    # shared mutable default argument.
    if state is None:
        state = {"base_state": [], "tune_state": []}
    # One persona fact per line, as the PersonaChat-style template expects.
    personality = "\n".join(personality.split("."))
    state["base_state"].append(user_input)
    state["tune_state"].append(user_input)
    base_prompt = template.format(
        personality=personality,
        history="\n".join(state["base_state"]),
    )
    tune_prompt = template.format(
        personality=personality,
        history="\n".join(state["tune_state"]),
    )
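    # Print both prompts to the server log so the conversation context fed to
    # each model can be inspected.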
| print("****************************") |
| print(base_prompt) |
| print("****************************") |
| print(tune_prompt) |
| print("****************************") |
| base_input_ids = tokenizer(base_prompt, return_tensors="pt").to("cuda") |
| tune_input_ids = tokenizer(tune_prompt, return_tensors="pt").to("cuda") |
    kwargs = {
        "top_k": 0,
        "top_p": 0.9,
        "do_sample": True,
        "temperature": 0.5,
        "max_new_tokens": 50,
        "repetition_penalty": 1.1,
        "num_return_sequences": 1,
    }
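    # tuned_model shares its weights with base_model (PEFT patches the
    # modules in place), so the "base" response is generated with the LoRA
    # adapter temporarily disabled rather than from a separate model copy.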
    with tuned_model.disable_adapter():
        base_model_response = parse_response(
            tuned_model.generate(
                input_ids=base_input_ids["input_ids"],
                attention_mask=base_input_ids["attention_mask"],
                **kwargs,
            ),
            base_prompt,
        )
    # The fine-tuned response: same settings, adapter active.
    tune_model_response = parse_response(
        tuned_model.generate(
            input_ids=tune_input_ids["input_ids"],
            attention_mask=tune_input_ids["attention_mask"],
            **kwargs,
        ),
        tune_prompt,
    )
    state["base_state"].append(base_model_response)
    state["tune_state"].append(tune_model_response)
    return base_model_response, tune_model_response, state


gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="user personality", placeholder="Enter your personality"),
        gr.Textbox(label="user chat", placeholder="Enter your message"),
        "state",
    ],
    outputs=[
        gr.Textbox(label="base model response"),
        gr.Textbox(label="fine tuned model response"),
        "state",
    ],
    theme="gradio/seafoam",
).launch(share=True)
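# Assumed environment: gradio, transformers, peft, accelerate, and
# bitsandbytes installed, plus a CUDA GPU with enough memory for Falcon-7B
# in 8-bit (roughly 10 GB of VRAM).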