# app.py
import os
import torch
import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")

BASE_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
PEFT_MODEL_ID = "befm/Be.FM-8B"

# Use /data for persistent storage (mounted when the Space has persistent
# storage enabled) to avoid re-downloading models on every restart
CACHE_DIR = "/data" if os.path.exists("/data") else None

USE_PEFT = True
try:
    from peft import PeftModel, PeftConfig  # noqa: F401
except Exception:
    USE_PEFT = False
    print("[WARN] 'peft' not installed; running base model only.")
def load_model_and_tokenizer():
    if HF_TOKEN is None:
        raise RuntimeError(
            "HF_TOKEN is not set. Add it in Space → Settings → Secrets. "
            "Also ensure your account has access to the gated base model."
        )

    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    tok = AutoTokenizer.from_pretrained(
        BASE_MODEL_ID,
        token=HF_TOKEN,
        cache_dir=CACHE_DIR,  # use persistent storage when available
    )
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        device_map="auto" if torch.cuda.is_available() else None,
        torch_dtype=dtype,
        token=HF_TOKEN,
        cache_dir=CACHE_DIR,
    )
    print(f"[INFO] Using cache directory: {CACHE_DIR}")

    if USE_PEFT:
        try:
            # Check that the adapter config is reachable before loading weights
            _ = PeftConfig.from_pretrained(
                PEFT_MODEL_ID,
                token=HF_TOKEN,
                cache_dir=CACHE_DIR,
            )
            model = PeftModel.from_pretrained(
                base,
                PEFT_MODEL_ID,
                token=HF_TOKEN,
                cache_dir=CACHE_DIR,
            )
            print(f"[INFO] Loaded PEFT adapter: {PEFT_MODEL_ID}")
            return model, tok
        except Exception as e:
            print(f"[WARN] Failed to load PEFT adapter: {e}")
            return base, tok

    return base, tok
# Lazy-load the model and tokenizer on first request so the Space starts quickly
_model = None
_tokenizer = None

def get_model_and_tokenizer():
    global _model, _tokenizer
    if _model is None:
        _model, _tokenizer = load_model_and_tokenizer()
    return _model, _tokenizer
@spaces.GPU  # request a GPU slice on ZeroGPU hardware for the duration of this call
def generate_response(messages, max_new_tokens=512, temperature=0.7) -> str:
    model, tokenizer = get_model_and_tokenizer()
    device = model.device

    # Apply the Llama 3.1 chat template
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    enc = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    enc = {k: v.to(device) for k, v in enc.items()}
    input_length = enc["input_ids"].shape[1]

    out = model.generate(
        **enc,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens
    generated_text = tokenizer.decode(out[0][input_length:], skip_special_tokens=True)
    return generated_text.strip()
def chat_fn(message, history, system_prompt, _prompt_reference, max_new_tokens, temperature):
    # Build the conversation in Llama 3.1 chat format
    messages = []

    # Add the system prompt (fall back to a default if none is provided)
    if not system_prompt:
        system_prompt = (
            "You are a Be.FM assistant. Be.FM is a family of open foundation models "
            "designed for human behavior modeling. Built on Llama 3.1 and fine-tuned on "
            "diverse behavioral datasets, Be.FM models are designed to enhance the "
            "understanding and prediction of human decision-making."
        )
    messages.append({"role": "system", "content": system_prompt})

    # Handle the Gradio 6.0 history format:
    # [{"role": "user", "content": [{"type": "text", "text": "..."}]}, ...]
    for msg in (history or []):
        role = msg.get("role", "user")
        content = msg.get("content", "")
        # Extract text from structured content
        if isinstance(content, list):
            # Gradio 6.0 format: content is a list of dicts
            text_parts = [c.get("text", "") for c in content if c.get("type") == "text"]
            content = " ".join(text_parts)
        if content:
            messages.append({"role": role, "content": content})

    if message:
        # The incoming message can be a string or a dict in Gradio 6.0
        text = message.get("text", "") if isinstance(message, dict) else message
        if text:
            messages.append({"role": "user", "content": text})

    reply = generate_response(
        messages,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    return reply
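# Example of the messages list chat_fn hands to generate_response after one
# prior exchange (values are illustrative):
#   [{"role": "system", "content": "You are a Be.FM assistant. ..."},
#    {"role": "user", "content": "Hi"},
#    {"role": "assistant", "content": "Hello! How can I help?"},
#    {"role": "user", "content": "Predict the modal offer in an ultimatum game."}]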
demo = gr.ChatInterface(
    fn=chat_fn,
    chatbot=gr.Chatbot(
        label="Chat with Be.FM",
        show_label=True,
        avatar_images=(None, None),  # default avatars; replace with image paths to customize
    ),
    additional_inputs=[
        gr.Textbox(
            label="System prompt (optional)",
            placeholder=(
                "You are a Be.FM assistant. Be.FM is a family of open foundation models "
                "designed for human behavior modeling. Built on Llama 3.1 and fine-"
                "tuned on diverse behavioral datasets, Be.FM models are designed to "
                "enhance the understanding and prediction of human decision-making."
            ),
            lines=2,
        ),
        gr.Markdown(
            "For system and user prompts in a variety of behavioral tasks, please refer "
            "to the appendix in our [paper](https://arxiv.org/abs/2505.23058)."
        ),
        gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="temperature"),
    ],
    title="Be.FM: Open Foundation Models for Human Behavior (8B)",
)
if __name__ == "__main__":
    demo.launch(share=True)
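For a quick smoke test outside the Gradio UI, the generation helper can be driven directly from a Python session. This is a minimal sketch: it assumes the file above is saved as app.py, that HF_TOKEN grants access to the gated base model, and that enough memory is available for the 8B weights; the prompt text is illustrative.

from app import generate_response

messages = [
    {"role": "system", "content": "You are a Be.FM assistant."},
    {"role": "user", "content": "In a one-shot ultimatum game over $10, what offer would a typical proposer make?"},
]
print(generate_response(messages, max_new_tokens=128, temperature=0.7))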