# app.py
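"""Gradio chat demo serving the Be.FM-8B PEFT adapter on top of Meta-Llama-3.1-8B-Instruct."""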
import os
import torch
import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")

BASE_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
PEFT_MODEL_ID = "befm/Be.FM-8B"

# Use /data for persistent storage to avoid re-downloading models
CACHE_DIR = "/data" if os.path.exists("/data") else None

USE_PEFT = True
try:
    from peft import PeftModel, PeftConfig  # noqa
except Exception:
    USE_PEFT = False
    print("[WARN] 'peft' not installed; running base model only.")

def load_model_and_tokenizer():
    if HF_TOKEN is None:
        raise RuntimeError(
            "HF_TOKEN is not set. Add it in Space → Settings → Secrets. "
            "Also ensure your account has access to the gated base model."
        )
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    tok = AutoTokenizer.from_pretrained(
        BASE_MODEL_ID,
        token=HF_TOKEN,
        cache_dir=CACHE_DIR  # Use persistent storage
    )
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        device_map="auto" if torch.cuda.is_available() else None,
        torch_dtype=dtype,
        token=HF_TOKEN,
        cache_dir=CACHE_DIR  # Use persistent storage
    )

    print(f"[INFO] Using cache directory: {CACHE_DIR}")

    if USE_PEFT:
        try:
            _ = PeftConfig.from_pretrained(
                PEFT_MODEL_ID,
                token=HF_TOKEN,
                cache_dir=CACHE_DIR  # Use persistent storage
            )
            model = PeftModel.from_pretrained(
                base,
                PEFT_MODEL_ID,
                token=HF_TOKEN,
                cache_dir=CACHE_DIR  # Use persistent storage
            )
            print(f"[INFO] Loaded PEFT adapter: {PEFT_MODEL_ID}")
            return model, tok
        except Exception as e:
            print(f"[WARN] Failed to load PEFT adapter: {e}")
            return base, tok
    return base, tok

# Lazy load model and tokenizer
_model = None
_tokenizer = None

def get_model_and_tokenizer():
    global _model, _tokenizer
    if _model is None:
        _model, _tokenizer = load_model_and_tokenizer()
    return _model, _tokenizer

# Request a GPU for the duration of each call (used on ZeroGPU Spaces).
@spaces.GPU
@torch.inference_mode()
def generate_response(messages, max_new_tokens=512, temperature=0.7) -> str:
    model, tokenizer = get_model_and_tokenizer()
    device = model.device

    # Apply Llama 3.1 chat template
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    enc = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    enc = {k: v.to(device) for k, v in enc.items()}

    input_length = enc['input_ids'].shape[1]
    out = model.generate(
        **enc,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens
    generated_text = tokenizer.decode(out[0][input_length:], skip_special_tokens=True)
    return generated_text.strip()

def chat_fn(message, history, system_prompt, _prompt_reference, max_new_tokens, temperature):
    # `_prompt_reference` receives the value of the informational Markdown component and is unused.
    # Build the conversation in Llama 3.1 chat format
    messages = []

    # Add system prompt (use default if not provided)
    if not system_prompt:
        system_prompt = (
            "Your are a Be.FM assistant. Be.FM is a family of open foundation models "
            "designed for human behavior modeling. Built on Llama 3.1 and fine-tuned on "
            "diverse behavioral datasets, Be.FM models are designed to enhance the "
            "understanding and prediction of human decision-making."
        )
    messages.append({"role": "system", "content": system_prompt})

    # Handle Gradio 6.0 history format
    # History format: [{"role": "user", "content": [{"type": "text", "text": "..."}]}, ...]
    for msg in (history or []):
        role = msg.get("role", "user")
        content = msg.get("content", "")

        # Extract text from structured content
        if isinstance(content, list):
            # Gradio 6.0 format: content is a list of dicts
            text_parts = [c.get("text", "") for c in content if c.get("type") == "text"]
            content = " ".join(text_parts)

        if content:
            messages.append({"role": role, "content": content})

    if message:
        # Handle message (could be string or dict in Gradio 6.0)
        if isinstance(message, dict):
            text = message.get("text", "")
        else:
            text = message

        if text:
            messages.append({"role": "user", "content": text})

    reply = generate_response(
        messages,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    return reply

demo = gr.ChatInterface(
    fn=chat_fn,
    chatbot=gr.Chatbot(
        label="Chat with BeFM",
        show_label=True,
        avatar_images=(None, None),  # Use default avatars or provide custom image paths
    ),
    additional_inputs=[
        gr.Textbox(
            label="System prompt (optional)",
            placeholder=(
                "Your are a Be.FM assistant. Be.FM is a family of open foundation models "
                "designed for human behavior modeling. Built on Llama 3.1 and fine-"
                "tuned on diverse behavioral datasets, Be.FM models are designed to "
                "enhance the understanding and prediction of human decision-making."
            ),
            lines=2,
        ),
        gr.Markdown(
            "For system and user prompts in a variety of behavioral tasks, please refer "
            "to the appendix in our [paper](https://arxiv.org/abs/2505.23058)."
        ),
        gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="temperature"),
    ],
    title="Be.FM: Open Foundation Models for Human Behavior (8B)",
)

if __name__ == "__main__":
    demo.launch(share=True)
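
# Dependencies implied by the imports above (an assumed requirements.txt sketch;
# exact versions/pins are not specified by this file):
#   torch
#   transformers
#   peft
#   gradio
#   spaces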