# app_vllm.py - Faster inference using vLLM
import os

import spaces
import gradio as gr
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
BASE_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
PEFT_MODEL_ID = "befm/Be.FM-8B"

# Use /data for persistent storage to avoid re-downloading models
CACHE_DIR = "/data" if os.path.exists("/data") else None

def load_model():
    if HF_TOKEN is None:
        raise RuntimeError(
            "HF_TOKEN is not set. Add it in Space → Settings → Secrets. "
            "Also ensure your account has access to the gated base model."
        )

    # Initialize vLLM with LoRA (PEFT) support
    llm = LLM(
        model=BASE_MODEL_ID,
        tokenizer=BASE_MODEL_ID,
        enable_lora=True,
        max_lora_rank=64,
        dtype="float16",
        gpu_memory_utilization=0.7,  # Reduced from 0.9 to avoid OOM on T4 GPU
        trust_remote_code=True,
        download_dir=CACHE_DIR,  # Use persistent storage
    )
    print(f"[INFO] vLLM loaded base model: {BASE_MODEL_ID}")
    print(f"[INFO] Using cache directory: {CACHE_DIR}")

    # vLLM expects a local path for the LoRA adapter, so fetch it from the Hub first
    adapter_path = snapshot_download(
        repo_id=PEFT_MODEL_ID,
        token=HF_TOKEN,
        cache_dir=CACHE_DIR,
    )
    lora_request = LoRARequest(
        lora_name="befm",
        lora_int_id=1,
        lora_path=adapter_path,
    )
    print(f"[INFO] PEFT adapter prepared: {PEFT_MODEL_ID}")
    return llm, lora_request

# Lazy load model and tokenizer (initialized on first request)
_llm = None
_lora_request = None
_tokenizer = None


def get_model_and_tokenizer():
    global _llm, _lora_request, _tokenizer
    if _llm is None:
        _llm, _lora_request = load_model()
        _tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL_ID,
            token=HF_TOKEN,
            cache_dir=CACHE_DIR,  # Use persistent storage
        )
    return _llm, _lora_request, _tokenizer

@spaces.GPU
def generate_response(messages, max_new_tokens=512, temperature=0.7) -> str:
    llm, lora_request, tokenizer = get_model_and_tokenizer()

    # Apply the Llama 3.1 chat template
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=0.9,
        max_tokens=max_new_tokens,
    )

    # Generate with vLLM, routing through the Be.FM LoRA adapter
    outputs = llm.generate(
        prompts=[prompt],
        sampling_params=sampling_params,
        lora_request=lora_request,
    )
    return outputs[0].outputs[0].text
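
# Minimal, hypothetical example of calling generate_response directly (outside the
# Gradio UI); the message content below is illustrative, not taken from the Be.FM paper:
#
#   generate_response(
#       [
#           {"role": "system", "content": "You are a Be.FM assistant."},
#           {"role": "user", "content": "Predict the most likely offer in a $10 dictator game."},
#       ],
#       max_new_tokens=128,
#       temperature=0.7,
#   )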

def chat_fn(message, history, system_prompt, _prompt_reference, max_new_tokens, temperature):
    # Build the conversation in Llama 3.1 chat format
    messages = []

    # Add the system prompt (use the default Be.FM description if none is provided)
    if not system_prompt:
        system_prompt = (
            "You are a Be.FM assistant. Be.FM is a family of open foundation models "
            "designed for human behavior modeling. Built on Llama 3.1 and fine-tuned on "
            "diverse behavioral datasets, Be.FM models are designed to enhance the "
            "understanding and prediction of human decision-making."
        )
    messages.append({"role": "system", "content": system_prompt})

    # History is already in dict format: [{"role": "user", "content": "..."}, ...]
    for msg in (history or []):
        messages.append(msg)

    if message:
        messages.append({"role": "user", "content": message})

    reply = generate_response(
        messages,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    return reply

demo = gr.ChatInterface(
    fn=chat_fn,
    type="messages",  # pass history as a list of {"role": ..., "content": ...} dicts
    additional_inputs=[
        gr.Textbox(
            label="System prompt (optional)",
            placeholder=(
                "You are a Be.FM assistant. Be.FM is a family of open foundation models "
                "designed for human behavior modeling. Built on Llama 3.1 and fine-"
                "tuned on diverse behavioral datasets, Be.FM models are designed to "
                "enhance the understanding and prediction of human decision-making."
            ),
            lines=2,
        ),
        gr.Markdown(
            "For system and user prompts in a variety of behavioral tasks, please refer "
            "to the appendix in our [paper](https://arxiv.org/abs/2505.23058)."
        ),
        gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens"),
        gr.Slider(0.1, 1.5, value=0.6, step=0.05, label="temperature"),
    ],
    title="Be.FM: Open Foundation Models for Human Behavior (8B)",
)

if __name__ == "__main__":
    demo.launch()
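
# Hypothetical local run (assumes a GPU with enough memory for the 8B base model
# and a token with access to the gated repos):
#   HF_TOKEN=<your token> python app_vllm.py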