| import spaces |
| import gradio as gr |
| import torch |
| from transformers import pipeline, BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer, AutoConfig |
| from peft import PeftModel |
|
|
# Hub IDs: the base checkpoint (the "bnb-4bit" suffix suggests a pre-quantized
# bitsandbytes build — confirm) and the Centaur adapter loaded on top of it.
MODEL_ID = "unsloth/Meta-Llama-3.1-70B-bnb-4bit"
ADAPTER_ID = "marcelbinz/Llama-3.1-Centaur-70B-adapter"

# Fetch the checkpoint's config so its RoPE settings can be replaced before the
# weights are built: YaRN scaling by a factor of 4 over an 8192-token base
# window, giving a 32768-token maximum context.
# NOTE(review): Llama 3.1 configs are generally published with their own
# `rope_scaling` (rope_type "llama3"); swapping in YaRN changes the positional
# encoding the weights were trained with, even for short prompts. This may be a
# workaround for an older transformers release — verify against this
# checkpoint's config.json and confirm the override is still intended.
cfg = AutoConfig.from_pretrained(MODEL_ID)
cfg.rope_scaling = dict(
    type="yarn",
    factor=4.0,
    original_max_position_embeddings=8192,
)
cfg.max_position_embeddings = 32768
|
|
# bitsandbytes 4-bit loading settings: NF4 weights, double (nested)
# quantization of the quantization constants, and bfloat16 as compute dtype.
# NOTE(review): if the checkpoint already carries its own quantization_config,
# transformers will prefer that one and these settings may be ignored — confirm.
bnb_4bit_config = BitsAndBytesConfig(
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)
|
|
# Tokenizer and base model. The base weights are built with the `cfg` config
# (RoPE override) and `bnb_4bit_config` quantization settings;
# device_map="auto" lets accelerate place layers across the available devices,
# and attn_implementation="flash_attention_2" needs the flash-attn package.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model_base = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    config=cfg,
    quantization_config=bnb_4bit_config,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
|
|
# Wrap the base model with the Centaur PEFT adapter weights from ADAPTER_ID.
model = PeftModel.from_pretrained(
    model=model_base,
    model_id=ADAPTER_ID,
    device_map="auto",
)

# Text-generation pipeline that infer() calls.
# NOTE(review): `model` is already instantiated and placed on devices, so this
# device_map is presumably ignored by transformers (it is applied when `model`
# is given as a string to load) — verify and consider dropping it.
pipe = pipeline(
    task="text-generation",
    tokenizer=tokenizer,
    model=model,
    device_map="auto",
)
|
|
@spaces.GPU
def infer(prompt: str) -> str:
    """Append one token sampled from the model to *prompt* and return it.

    Prompts are typically phrased so they end right after an opening ``<<``.
    The single sampled token is then the model's simulated choice. Sampling
    uses temperature 1.0, with top-p and top-k truncation explicitly disabled,
    so the token is drawn from the model's full next-token distribution.

    Args:
        prompt: Experiment transcript written in natural language.

    Returns:
        ``prompt`` followed by the decoded sampled token
        (``return_full_text=True``).
    """
    outputs = pipe(
        prompt,
        max_new_tokens=1,
        do_sample=True,
        temperature=1.0,
        # Without these, generation falls back to the checkpoint's
        # generation_config and transformers' defaults (top_k=50; Llama 3.1
        # generation configs typically also set top_p=0.9). That silently
        # truncates the distribution and biases simulated choices toward the
        # most likely option. top_p=1.0 and top_k=0 disable both filters.
        top_p=1.0,
        top_k=0,
        return_full_text=True,
    )
    return outputs[0]["generated_text"]
|
|
# Prompt pre-filled in the textbox: instructions for an odd-one-out triplet
# task, four completed trials, and a fifth trial left open at "<<" so the
# model supplies the choice.
default_experiment = "\n".join(
    (
        "You will be presented with triplets of objects, which will be assigned to the keys H, Y, and E.",
        "In each trial, please indicate which object you think is the odd one out by pressing the corresponding key.",
        "In other words, please choose the object that is the least similar to the other two.",
        "",
        "H: plant, Y: chainsaw, and E: periscope. You press <<H>>.",
        "H: tostada, Y: leaf, and E: sail. You press <<H>>.",
        "H: clock, Y: crystal, and E: grate. You press <<Y>>.",
        "H: barbed wire, Y: kale, and E: sweater. You press <<E>>.",
        "H: raccoon, Y: toothbrush, and E: ice. You press <<",
    )
)
|
|
# Single-page UI: header image, prompting guidelines, an editable prompt box
# and a Run button. Run writes infer()'s result back into the same textbox, so
# each click extends the prompt by (at most) one sampled token.
with gr.Blocks(
    css="""
#prompt-box textarea {height:256px}
""",
    fill_width=True,
) as demo:
    gr.Image(
        value="https://marcelbinz.github.io/imgs/centaur.png",
        height=180,
        show_label=False,
        container=False,
        elem_classes="mx-auto",
    )

    gr.Markdown(
        """
### How to prompt:
- We did not employ a particular prompt template – just phrase everything in natural language.
- Human choices are encapsulated by "<<" and ">>" tokens.
- Most experiments in the training data are framed in terms of button presses. If possible, it is recommended to use that style.
- You can find examples in the Supporting Information of our paper.
""",
        elem_id="info-box",
    )

    # The textbox is both the input and the output of the Run action.
    inp = gr.Textbox(
        value=default_experiment,
        label="Prompt",
        elem_id="prompt-box",
        lines=16,
        max_lines=16,
        scale=3,
    )

    run = gr.Button(value="Run")
    run.click(fn=infer, inputs=inp, outputs=inp)

# Enable request queuing, then start the server.
demo.queue().launch()