# app.py
import os
import torch
import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
BASE_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
PEFT_MODEL_ID = "befm/Be.FM-8B"
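# Be.FM-8B is published as a PEFT adapter; it is applied on top of the (gated)
# Llama 3.1 base model at load time — see load_model_and_tokenizer() below.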
# Use /data for persistent storage to avoid re-downloading models
CACHE_DIR = "/data" if os.path.exists("/data") else None
USE_PEFT = True
try:
    from peft import PeftModel, PeftConfig  # noqa
except Exception:
    USE_PEFT = False
    print("[WARN] 'peft' not installed; running base model only.")

def load_model_and_tokenizer():
    if HF_TOKEN is None:
        raise RuntimeError(
            "HF_TOKEN is not set. Add it in Space → Settings → Secrets. "
            "Also ensure your account has access to the gated base model."
        )
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    tok = AutoTokenizer.from_pretrained(
        BASE_MODEL_ID,
        token=HF_TOKEN,
        cache_dir=CACHE_DIR,  # Use persistent storage
    )
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        device_map="auto" if torch.cuda.is_available() else None,
        torch_dtype=dtype,
        token=HF_TOKEN,
        cache_dir=CACHE_DIR,  # Use persistent storage
    )
    print(f"[INFO] Using cache directory: {CACHE_DIR}")
    if USE_PEFT:
        try:
            _ = PeftConfig.from_pretrained(
                PEFT_MODEL_ID,
                token=HF_TOKEN,
                cache_dir=CACHE_DIR,  # Use persistent storage
            )
            model = PeftModel.from_pretrained(
                base,
                PEFT_MODEL_ID,
                token=HF_TOKEN,
                cache_dir=CACHE_DIR,  # Use persistent storage
            )
            print(f"[INFO] Loaded PEFT adapter: {PEFT_MODEL_ID}")
            return model, tok
        except Exception as e:
            print(f"[WARN] Failed to load PEFT adapter: {e}")
            return base, tok
    return base, tok

# Lazy load model and tokenizer
_model = None
_tokenizer = None

def get_model_and_tokenizer():
    global _model, _tokenizer
    if _model is None:
        _model, _tokenizer = load_model_and_tokenizer()
    return _model, _tokenizer
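
# Note: loading is deferred until the first request so the Space starts up
# quickly. On a ZeroGPU Space (which the @spaces.GPU decorator below suggests,
# though this is an assumption), GPU time is only allocated while the
# decorated function runs.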

@spaces.GPU
@torch.inference_mode()
def generate_response(messages, max_new_tokens=512, temperature=0.7) -> str:
    model, tokenizer = get_model_and_tokenizer()
    device = model.device
    # Apply the Llama 3.1 chat template
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    enc = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    enc = {k: v.to(device) for k, v in enc.items()}
    input_length = enc["input_ids"].shape[1]
    out = model.generate(
        **enc,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens
    generated_text = tokenizer.decode(out[0][input_length:], skip_special_tokens=True)
    return generated_text.strip()
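
# A minimal sketch of calling generate_response directly, e.g. for
# smoke-testing outside the Gradio UI (hypothetical usage, not part of the app):
#   msgs = [{"role": "user", "content": "What is Be.FM?"}]
#   print(generate_response(msgs, max_new_tokens=64, temperature=0.7))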

def chat_fn(message, history, system_prompt, _prompt_reference, max_new_tokens, temperature):
    # Build the conversation in Llama 3.1 chat format
    messages = []
    # Add the system prompt (fall back to the default if none is provided)
    if not system_prompt:
        system_prompt = (
            "You are a Be.FM assistant. Be.FM is a family of open foundation models "
            "designed for human behavior modeling. Built on Llama 3.1 and fine-tuned on "
            "diverse behavioral datasets, Be.FM models are designed to enhance the "
            "understanding and prediction of human decision-making."
        )
    messages.append({"role": "system", "content": system_prompt})
    # Handle the Gradio 6.0 history format:
    # [{"role": "user", "content": [{"type": "text", "text": "..."}]}, ...]
    for msg in (history or []):
        role = msg.get("role", "user")
        content = msg.get("content", "")
        # Extract text from structured content
        if isinstance(content, list):
            # Gradio 6.0 format: content is a list of dicts
            text_parts = [c.get("text", "") for c in content if c.get("type") == "text"]
            content = " ".join(text_parts)
        if content:
            messages.append({"role": role, "content": content})
    if message:
        # The incoming message may be a plain string or a dict (Gradio 6.0)
        if isinstance(message, dict):
            text = message.get("text", "")
        else:
            text = message
        if text:
            messages.append({"role": "user", "content": text})
    reply = generate_response(
        messages,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    return reply
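
# For reference, a hypothetical Gradio 6.0-style history entry and the flat
# form chat_fn reduces it to (assumed shape, matching the parsing above):
#   {"role": "user", "content": [{"type": "text", "text": "Hi"}]}
#     -> {"role": "user", "content": "Hi"}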

demo = gr.ChatInterface(
    fn=chat_fn,
    chatbot=gr.Chatbot(
        label="Chat with BeFM",
        show_label=True,
        avatar_images=(None, None),  # Use default avatars or provide custom image paths
    ),
    additional_inputs=[
        gr.Textbox(
            label="System prompt (optional)",
            placeholder=(
                "You are a Be.FM assistant. Be.FM is a family of open foundation models "
                "designed for human behavior modeling. Built on Llama 3.1 and fine-"
                "tuned on diverse behavioral datasets, Be.FM models are designed to "
                "enhance the understanding and prediction of human decision-making."
            ),
            lines=2,
        ),
        gr.Markdown(
            "For system and user prompts in a variety of behavioral tasks, please refer "
            "to the appendix in our [paper](https://arxiv.org/abs/2505.23058)."
        ),
        gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="temperature"),
    ],
    title="Be.FM: Open Foundation Models for Human Behavior (8B)",
)

if __name__ == "__main__":
    demo.launch(share=True)