Spaces:
Runtime error
Runtime error
| import json | |
| import os | |
| from pprint import pprint | |
| import bitsandbytes as bnb | |
| import pandas as pd | |
| import torch | |
| import torch.nn as nn | |
| import transformers | |
| from datasets import load_dataset | |
| from huggingface_hub import notebook_login | |
| from peft import ( | |
| LoraConfig, | |
| PeftConfig, | |
| PeftModel, | |
| get_peft_model, | |
| prepare_model_for_kbit_training, | |
| ) | |
| from transformers import ( | |
| AutoConfig, | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| BitsAndBytesConfig, | |
| ) | |
| import gradio as gr | |
# 4-bit NF4 quantisation settings (double quantisation, bfloat16 compute).
# NOTE(review): nothing in the visible code passes ``bnb_config`` to
# ``from_pretrained`` -- the base model below is loaded without it. Confirm
# whether that is intentional before wiring it in.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    load_in_4bit=True,
)

# Hub id of the PEFT adapter; its config names the base model to load.
PEFT_MODEL = "cdy3870/Falcon-Fetch-Bot"
config = PeftConfig.from_pretrained(PEFT_MODEL)

# The tokenizer comes from the base model and reuses EOS as the padding token.
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token

# Load the base causal LM and wrap it with the adapter weights in one step.
model = PeftModel.from_pretrained(
    AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        return_dict=True,
        device_map="auto",
        trust_remote_code=True,
        load_in_8bit=False,
        offload_folder="offload",
    ),
    PEFT_MODEL,
)

# Shared generation settings; the UI callbacks mutate this object in place.
generation_config = model.generation_config
for _option, _setting in {
    "max_new_tokens": 150,
    "temperature": 0.6,
    "top_p": 0.7,
    "num_return_sequences": 1,
    "pad_token_id": tokenizer.eos_token_id,
    "eos_token_id": tokenizer.eos_token_id,
}.items():
    setattr(generation_config, _option, _setting)

# Text-generation pipeline used by the chat callback.
pipeline = transformers.pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
)
def main():
    """Build the Gradio chat interface and start serving it.

    The UI has a chat log, two sliders that tune the module-level
    ``generation_config`` (temperature and maximum number of new tokens), a
    text box for the question and a button that clears the log. Submitting a
    question appends it to the history; the model's answer is then streamed
    into the chat window one character at a time.
    """
    # Function-scope import: the streaming loop needs ``time`` and it is not
    # imported at module level in this file.
    import time

    with gr.Blocks() as demo:

        def update_temp(temp):
            """Store the "Creativity" slider value as the temperature."""
            generation_config.temperature = temp

        def update_tokens(tokens):
            """Store the "Length" slider value as ``max_new_tokens``."""
            # Cast so a float delivered by the slider is never stored as a
            # token count.
            generation_config.max_new_tokens = int(tokens)

        chatbot = gr.Chatbot(label="Fetch Rewards Chatbot")

        # Sliders start from the live generation settings so the values shown
        # always match what the model will actually use.
        temperature = gr.Slider(
            0,
            1,
            value=generation_config.temperature,
            step=0.1,
            label="Creativity",
            interactive=True,
        )
        temperature.change(fn=update_temp, inputs=temperature)
        tokens = gr.Slider(
            50,
            200,
            value=generation_config.max_new_tokens,
            step=50,
            label="Length",
            interactive=True,
        )
        tokens.change(fn=update_tokens, inputs=tokens)

        msg = gr.Textbox(label="", placeholder="Ask anything about Fetch!")
        clear = gr.Button("Clear Log")

        def user(user_message, history):
            """Clear the text box and append the question as an open turn."""
            return "", (history or []) + [[user_message, None]]

        def bot(history):
            """Generate the answer for the last turn and stream it back."""
            message = history[-1][0]
            prompt = f"<human>: {message}\n<assistant>:"
            result = pipeline(
                prompt,
                generation_config=generation_config,
            )
            reply = _extract_reply(result[0]["generated_text"], prompt)

            history[-1][1] = ""
            if not reply:
                # Nothing to stream: still push one update so the open turn
                # is rendered as an empty answer.
                yield history
                return
            for character in reply:
                history[-1][1] += character
                time.sleep(0.01)  # small delay gives a typing effect
                yield history

        # First record the user's turn immediately, then stream the bot reply.
        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, chatbot, chatbot
        )
        clear.click(lambda: None, None, chatbot, queue=False)

    # Queueing is enabled because ``bot`` is a generator callback.
    demo.queue()
    demo.launch()


def _extract_reply(generated_text, prompt):
    """Return the first line of the assistant's answer.

    Args:
        generated_text: Text returned by the generation pipeline; normally
            the prompt followed by the model's completion.
        prompt: The exact prompt that was sent to the pipeline.

    Returns:
        The first line of the completion with leading whitespace removed, or
        an empty string when the model produced nothing.
    """
    marker = "<assistant>:"
    if generated_text.startswith(prompt):
        # Slice the prompt off by length so a marker typed by the user inside
        # the question cannot be mistaken for the start of the answer.
        completion = generated_text[len(prompt):]
    else:
        _, found, tail = generated_text.partition(marker)
        completion = tail if found else generated_text
    # Stop at a further assistant marker, if the model emitted one.
    completion = completion.split(marker, 1)[0]
    # Strip leading whitespace instead of blindly dropping one character, so
    # the first letter survives when the model emits no leading space.
    return completion.lstrip().split("\n", 1)[0]


if __name__ == "__main__":
    main()