# app.py
import os
import torch
import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")

BASE_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
PEFT_MODEL_ID = "befm/Be.FM-8B"

USE_PEFT = True
try:
    from peft import PeftModel, PeftConfig  # noqa
except Exception:
    USE_PEFT = False
    print("[WARN] 'peft' not installed; running base model only.")

def load_model_and_tokenizer():
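    """Load the gated Llama 3.1 base model and tokenizer, then try to stack
    the Be.FM PEFT adapter on top, falling back to the bare base model if the
    adapter cannot be loaded."""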
    if HF_TOKEN is None:
        raise RuntimeError(
            "HF_TOKEN is not set. Add it in Space → Settings → Secrets. "
            "Also ensure your account has access to the gated base model."
        )
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    tok = AutoTokenizer.from_pretrained(BASE_MODEL_ID, token=HF_TOKEN)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        device_map="auto" if torch.cuda.is_available() else None,
        torch_dtype=dtype,
        token=HF_TOKEN,
    )

    if USE_PEFT:
        try:
            _ = PeftConfig.from_pretrained(PEFT_MODEL_ID, token=HF_TOKEN)  # probe that the adapter repo is reachable
            model = PeftModel.from_pretrained(base, PEFT_MODEL_ID, token=HF_TOKEN)
            print(f"[INFO] Loaded PEFT adapter: {PEFT_MODEL_ID}")
            return model, tok
        except Exception as e:
            print(f"[WARN] Failed to load PEFT adapter: {e}")
            return base, tok
    return base, tok

model, tokenizer = load_model_and_tokenizer()

@spaces.GPU
@torch.inference_mode()
def generate_response(messages, max_new_tokens=512, temperature=0.7, top_p=0.9) -> str:
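    """Generate one assistant reply for a role/content message list."""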
    # Apply Llama 3.1 chat template
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
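    # For reference, the rendered prompt looks roughly like this (header tokens
    # shown; the 3.1 template may also inject a default system block):
    #   <|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n
    #   {message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n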
    enc = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    # Read the device at call time instead of caching it at import; this stays
    # correct if the model is (re)placed after loading (e.g. by device_map or
    # a Spaces GPU attach).
    enc = {k: v.to(model.device) for k, v in enc.items()}

    input_length = enc["input_ids"].shape[1]
    out = model.generate(
        **enc,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens
    return tokenizer.decode(out[0][input_length:], skip_special_tokens=True)
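
# Illustrative call (hypothetical prompt; messages use the role/content schema
# that apply_chat_template expects):
#   generate_response([{"role": "user", "content": "One-line summary of PEFT?"}])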

def chat_fn(message, history, system_prompt, max_new_tokens, temperature, top_p):
    # Build the conversation in the role/content format the chat template
    # expects; assumes Gradio's "tuples" history format of
    # (user_message, assistant_message) pairs.
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    for user_msg, assistant_msg in (history or []):
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    if message:
        messages.append({"role": "user", "content": message})

    reply = generate_response(
        messages,
        max_new_tokens=int(max_new_tokens),  # sliders may deliver floats
        temperature=temperature,
        top_p=top_p,
    )
    return reply

demo = gr.ChatInterface(
    fn=chat_fn,  # (message, history, *additional_inputs) matches chat_fn's signature
    additional_inputs=[
        gr.Textbox(label="System prompt (optional)", placeholder="You are Be.FM assistant...", lines=2),
        gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="temperature"),
        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p"),
    ],
    title="Be.FM-8B (PEFT) on Meta-Llama-3.1-8B-Instruct",
    description="Chat interface using Meta-Llama-3.1-8B-Instruct with PEFT adapter befm/Be.FM-8B."
)

if __name__ == "__main__":
    demo.launch()
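
# To run outside Spaces (assumes HF_TOKEN is exported and the account has been
# granted access to the gated meta-llama repo):
#   HF_TOKEN=... python app.py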