Jn-Huang committed on
Commit
f6fde6f
·
1 Parent(s): 8e51924

Add Be.FM-8B chat interface with PEFT adapter

Browse files
Files changed (2) hide show
  1. app.py +105 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# app.py
#
# Gradio chat Space: Meta-Llama-3.1-8B-Instruct base model with the
# befm/Be.FM-8B PEFT adapter applied on top.
import os
import torch
import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# Auth token for the gated base model; either Space secret name is accepted.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")

BASE_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # gated base model
PEFT_MODEL_ID = "befm/Be.FM-8B"                          # adapter weights

# Degrade gracefully to the bare base model when 'peft' is not installed.
USE_PEFT = True
try:
    from peft import PeftModel, PeftConfig  # noqa
except Exception:
    USE_PEFT = False
    print("[WARN] 'peft' not installed; running base model only.")
def load_model_and_tokenizer():
    """Load the base model + tokenizer and, when possible, the PEFT adapter.

    Returns:
        (model, tokenizer): the adapter-wrapped model when 'peft' is installed
        and the adapter loads successfully; otherwise the plain base model.

    Raises:
        RuntimeError: if no Hugging Face token is configured (the base model
        is gated and cannot be downloaded anonymously).
    """
    if HF_TOKEN is None:
        raise RuntimeError(
            "HF_TOKEN is not set. Add it in Space → Settings → Secrets. "
            "Also ensure your account has access to the gated base model."
        )
    # fp16 on GPU, fp32 on CPU (half precision is unusably slow on CPU).
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    tok = AutoTokenizer.from_pretrained(BASE_MODEL_ID, token=HF_TOKEN)
    if tok.pad_token is None:
        # Llama tokenizers ship without a pad token; reuse EOS for padding.
        tok.pad_token = tok.eos_token

    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        device_map="auto" if torch.cuda.is_available() else None,
        torch_dtype=dtype,
        token=HF_TOKEN,
    )

    if USE_PEFT:
        try:
            # Validate the adapter config first so a bad repo fails fast.
            # NOTE: use 'token=' (as the transformers calls above do);
            # 'use_auth_token=' is deprecated and removed in newer hub libs.
            PeftConfig.from_pretrained(PEFT_MODEL_ID, token=HF_TOKEN)
            model = PeftModel.from_pretrained(base, PEFT_MODEL_ID, token=HF_TOKEN)
            print(f"[INFO] Loaded PEFT adapter: {PEFT_MODEL_ID}")
            return model, tok
        except Exception as e:
            # Best-effort: fall back to the base model rather than crash.
            print(f"[WARN] Failed to load PEFT adapter: {e}")
            return base, tok
    return base, tok
48
+
# Load once at import time so the model is shared across all requests.
model, tokenizer = load_model_and_tokenizer()
# Device the model (or its first shard) landed on; inputs are moved here.
DEVICE = model.device
51
+
@spaces.GPU
@torch.inference_mode()
def generate_response(prompt: str, max_new_tokens=512, temperature=0.7, top_p=0.9) -> str:
    """Sample a completion for *prompt* and return the full decoded text.

    Note: the decoded string includes the prompt itself; callers are
    expected to strip the echoed prefix.
    """
    batch = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    inputs = {name: tensor.to(DEVICE) for name, tensor in batch.items()}
    generated = model.generate(
        **inputs,
        do_sample=True,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(generated[0], skip_special_tokens=True)
66
+
def chat_fn(message, history, system_prompt, max_new_tokens, temperature, top_p):
    """Build a plain-text conversation prompt, generate, and return the reply.

    Accepts history either as gradio's legacy tuple pairs
    [(user, assistant), ...] or as openai-style message dicts
    [{"role": ..., "content": ...}, ...], so the app keeps working across
    both ChatInterface history formats.
    """
    conv = []
    if system_prompt:
        conv.append(f"system: {system_prompt}")
    for turn in (history or []):
        if isinstance(turn, dict):
            # messages format: {"role": "user"/"assistant", "content": ...}
            role = turn.get("role")
            content = turn.get("content")
            if content:
                conv.append(f"{role}: {content}")
        else:
            # tuples format: (user_text, assistant_text)
            u, a = turn
            if u:
                conv.append(f"user: {u}")
            if a:
                conv.append(f"assistant: {a}")
    if message:
        conv.append(f"user: {message}")
    prompt = "\n".join(conv) + "\nassistant:"
    reply = generate_response(
        prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    # generate_response echoes the prompt; keep only the final assistant turn.
    if "assistant:" in reply:
        reply = reply.split("assistant:")[-1].strip()
    return reply
90
+
# ChatInterface calls fn(message, history, *additional_inputs) positionally,
# so chat_fn can be passed directly — the previous lambda wrapper that merely
# forwarded all six arguments was redundant.
demo = gr.ChatInterface(
    fn=chat_fn,
    additional_inputs=[
        gr.Textbox(label="System prompt (optional)", placeholder="You are Be.FM assistant...", lines=2),
        gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="temperature"),
        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p"),
    ],
    title="Be.FM-8B (PEFT) on Meta-Llama-3.1-8B-Instruct",
    description="Chat interface using Meta-Llama-3.1-8B-Instruct with PEFT adapter befm/Be.FM-8B."
)

if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
torch>=2.0.0
transformers>=4.30.0
peft>=0.4.0
spaces
accelerate
# app.py imports gradio; it was missing from the requirements
gradio