# app.py
"""Gradio Space serving Be.FM-8B: Llama-3.1-8B-Instruct plus an optional
PEFT adapter, exposed through a ChatInterface.

The model is gated on the Hub, so an HF token must be configured as a
Space secret. Weights are cached under /data when persistent storage is
mounted, and loading is deferred until the first chat request.
"""
import os

import torch
import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hub token — required because the Llama 3.1 base model is gated.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")

BASE_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
PEFT_MODEL_ID = "befm/Be.FM-8B"

# Use /data for persistent storage to avoid re-downloading models.
CACHE_DIR = "/data" if os.path.exists("/data") else None

# 'peft' is optional: when it is missing we degrade to the base model.
USE_PEFT = True
try:
    from peft import PeftModel, PeftConfig  # noqa
except Exception:
    USE_PEFT = False
    print("[WARN] 'peft' not installed; running base model only.")


def load_model_and_tokenizer():
    """Load the base model and tokenizer, then try to attach the Be.FM adapter.

    Returns:
        tuple: ``(model, tokenizer)`` where ``model`` is the PEFT-wrapped
        model when the adapter loads, otherwise the bare base model.

    Raises:
        RuntimeError: if no HF token is configured (the base model is
        gated and cannot be downloaded anonymously).
    """
    if HF_TOKEN is None:
        raise RuntimeError(
            "HF_TOKEN is not set. Add it in Space → Settings → Secrets. "
            "Also ensure your account has access to the gated base model."
        )

    # fp16 on GPU; fp32 on CPU, where half-precision is slow/unsupported.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    tok = AutoTokenizer.from_pretrained(
        BASE_MODEL_ID,
        token=HF_TOKEN,
        cache_dir=CACHE_DIR,  # use persistent storage
    )
    # Llama tokenizers ship without a pad token; reuse EOS for padding.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        device_map="auto" if torch.cuda.is_available() else None,
        torch_dtype=dtype,
        token=HF_TOKEN,
        cache_dir=CACHE_DIR,  # use persistent storage
    )
    print(f"[INFO] Using cache directory: {CACHE_DIR}")

    if USE_PEFT:
        # Best-effort: any adapter failure falls through to the base model.
        try:
            _ = PeftConfig.from_pretrained(
                PEFT_MODEL_ID,
                token=HF_TOKEN,
                cache_dir=CACHE_DIR,
            )
            model = PeftModel.from_pretrained(
                base,
                PEFT_MODEL_ID,
                token=HF_TOKEN,
                cache_dir=CACHE_DIR,
            )
            print(f"[INFO] Loaded PEFT adapter: {PEFT_MODEL_ID}")
            return model, tok
        except Exception as e:
            print(f"[WARN] Failed to load PEFT adapter: {e}")

    return base, tok


# Lazily loaded singletons: keeps Space startup fast; the first chat
# request pays the model-loading cost.
_model = None
_tokenizer = None


def get_model_and_tokenizer():
    """Return the cached (model, tokenizer) pair, loading them on first use."""
    global _model, _tokenizer
    if _model is None:
        _model, _tokenizer = load_model_and_tokenizer()
    return _model, _tokenizer


@spaces.GPU
@torch.inference_mode()
def generate_response(messages, max_new_tokens=512, temperature=0.7) -> str:
    """Generate one sampled assistant reply for a chat-message list.

    Args:
        messages: list of ``{"role": ..., "content": ...}`` dicts in
            Llama 3.1 chat order (system first, then alternating turns).
        max_new_tokens: generation budget in tokens.
        temperature: sampling temperature (sampling is always enabled).

    Returns:
        The newly generated text only (prompt echo removed), stripped.
    """
    model, tokenizer = get_model_and_tokenizer()
    device = model.device

    # Render the conversation with the Llama 3.1 chat template.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    enc = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    enc = {k: v.to(device) for k, v in enc.items()}
    input_length = enc["input_ids"].shape[1]

    out = model.generate(
        **enc,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens (slice off the prompt).
    generated_text = tokenizer.decode(out[0][input_length:], skip_special_tokens=True)
    return generated_text.strip()


def chat_fn(message, history, system_prompt, _prompt_reference, max_new_tokens, temperature):
    """ChatInterface callback: rebuild the conversation and generate a reply.

    Args:
        message: latest user turn — a plain string or a Gradio 6.0 dict.
        history: prior turns in Gradio "messages" format.
        system_prompt: optional override for the default Be.FM system prompt.
        _prompt_reference: unused; absorbs the gr.Markdown additional input.
        max_new_tokens / temperature: generation controls from the sliders.

    Returns:
        The assistant's reply text.
    """
    messages = []

    # Add system prompt (use default if not provided).
    if not system_prompt:
        system_prompt = (
            "You are a Be.FM assistant. Be.FM is a family of open foundation models "
            "designed for human behavior modeling. Built on Llama 3.1 and fine-tuned on "
            "diverse behavioral datasets, Be.FM models are designed to enhance the "
            "understanding and prediction of human decision-making."
        )
    messages.append({"role": "system", "content": system_prompt})

    # Handle Gradio 6.0 history format:
    # [{"role": "user", "content": [{"type": "text", "text": "..."}]}, ...]
    for msg in (history or []):
        role = msg.get("role", "user")
        content = msg.get("content", "")
        if isinstance(content, list):
            # Structured content: keep only the text parts, joined together.
            text_parts = [c.get("text", "") for c in content if c.get("type") == "text"]
            content = " ".join(text_parts)
        if content:
            messages.append({"role": role, "content": content})

    if message:
        # The incoming message may be a string or a dict (Gradio 6.0).
        text = message.get("text", "") if isinstance(message, dict) else message
        if text:
            messages.append({"role": "user", "content": text})

    return generate_response(
        messages,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )


demo = gr.ChatInterface(
    fn=chat_fn,
    chatbot=gr.Chatbot(
        label="Chat with BeFM",
        show_label=True,
        avatar_images=(None, None),  # default avatars; pass image paths for custom ones
    ),
    additional_inputs=[
        gr.Textbox(
            label="System prompt (optional)",
            placeholder=(
                "You are a Be.FM assistant. Be.FM is a family of open foundation models "
                "designed for human behavior modeling. Built on Llama 3.1 and fine-"
                "tuned on diverse behavioral datasets, Be.FM models are designed to "
                "enhance the understanding and prediction of human decision-making."
            ),
            lines=2,
        ),
        # Informational note only; chat_fn ignores it via _prompt_reference.
        gr.Markdown(
            "For system and user prompts in a variety of behavioral tasks, please refer "
            "to the appendix in our [paper](https://arxiv.org/abs/2505.23058)."
        ),
        gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="temperature"),
    ],
    title="Be.FM: Open Foundation Models for Human Behavior (8B)",
)


if __name__ == "__main__":
    demo.launch(share=True)