import os

import spaces
import gradio as gr
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from transformers import AutoTokenizer

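# Read the access token from either commonly used environment variable name.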
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")

BASE_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
PEFT_MODEL_ID = "befm/Be.FM-8B"

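# Cache downloads on the Space's persistent /data volume when one is mounted.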
CACHE_DIR = "/data" if os.path.exists("/data") else None


def load_model():
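    """Build the vLLM engine for the base model and prepare the Be.FM LoRA request."""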
    if HF_TOKEN is None:
        raise RuntimeError(
            "HF_TOKEN is not set. Add it in Space → Settings → Secrets. "
            "Also ensure your account has access to the gated base model."
        )

    llm = LLM(
        model=BASE_MODEL_ID,
        tokenizer=BASE_MODEL_ID,
        enable_lora=True,
        max_lora_rank=64,
        dtype="float16",
        gpu_memory_utilization=0.7,
        trust_remote_code=True,
        download_dir=CACHE_DIR,
    )

    print(f"[INFO] vLLM loaded base model: {BASE_MODEL_ID}")
    print(f"[INFO] Using cache directory: {CACHE_DIR}")

    lora_request = LoRARequest(
        lora_name="befm",
        lora_int_id=1,
        lora_path=PEFT_MODEL_ID,
    )
    print(f"[INFO] PEFT adapter prepared: {PEFT_MODEL_ID}")

    return llm, lora_request


# Lazily initialized singletons so the engine, adapter, and tokenizer are created only once.
_llm = None
_lora_request = None
_tokenizer = None


def get_model_and_tokenizer():
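    """Load the engine, LoRA request, and tokenizer on first use and reuse them afterwards."""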
    global _llm, _lora_request, _tokenizer
    if _llm is None:
        _llm, _lora_request = load_model()
        _tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL_ID,
            token=HF_TOKEN,
            cache_dir=CACHE_DIR,
        )
    return _llm, _lora_request, _tokenizer


@spaces.GPU
def generate_response(messages, max_new_tokens=512, temperature=0.7) -> str:
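    """Run a single chat completion through vLLM with the Be.FM LoRA adapter applied."""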
    llm, lora_request, tokenizer = get_model_and_tokenizer()

    # Render the chat history into a single Llama 3.1 prompt string.
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=0.9,
        max_tokens=max_new_tokens,
    )

    # Passing lora_request applies the Be.FM adapter on top of the base weights.
    outputs = llm.generate(
        prompts=[prompt],
        sampling_params=sampling_params,
        lora_request=lora_request,
    )

    return outputs[0].outputs[0].text


def chat_fn(message, history, system_prompt, _prompt_reference, max_new_tokens, temperature):
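    """Build the message list for the current turn and return the model's reply."""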
    messages = []

    # Fall back to the default Be.FM system prompt when the field is left empty.
    if not system_prompt:
        system_prompt = (
            "You are a Be.FM assistant. Be.FM is a family of open foundation models "
            "designed for human behavior modeling. Built on Llama 3.1 and fine-tuned on "
            "diverse behavioral datasets, Be.FM models are designed to enhance the "
            "understanding and prediction of human decision-making."
        )
    messages.append({"role": "system", "content": system_prompt})

    # History arrives as {"role": ..., "content": ...} dicts because the
    # ChatInterface below uses type="messages".
    for msg in (history or []):
        messages.append(msg)

    if message:
        messages.append({"role": "user", "content": message})

    reply = generate_response(
        messages,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    return reply


demo = gr.ChatInterface(
    fn=chat_fn,
    type="messages",  # chat_fn expects history entries as {"role", "content"} dicts
    additional_inputs=[
        gr.Textbox(
            label="System prompt (optional)",
            placeholder=(
                "You are a Be.FM assistant. Be.FM is a family of open foundation models "
                "designed for human behavior modeling. Built on Llama 3.1 and fine-"
                "tuned on diverse behavioral datasets, Be.FM models are designed to "
                "enhance the understanding and prediction of human decision-making."
            ),
            lines=2,
        ),
        gr.Markdown(
            "For system and user prompts in a variety of behavioral tasks, please refer "
            "to the appendix in our [paper](https://arxiv.org/abs/2505.23058)."
        ),
        gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens"),
        gr.Slider(0.1, 1.5, value=0.6, step=0.05, label="temperature"),
    ],
    title="Be.FM: Open Foundation Models for Human Behavior (8B)",
)


if __name__ == "__main__":
    demo.launch()