|
|
|
|
|
import os |
|
|
import torch |
|
|
import spaces |
|
|
import gradio as gr |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
# Hugging Face access token; checked under both commonly-used secret names.
# Required because the base model repo is gated.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")

# Gated base model and the PEFT adapter layered on top of it.
BASE_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
PEFT_MODEL_ID = "befm/Be.FM-8B"

# Best-effort import of peft: when it is missing or broken, run the base
# model alone instead of crashing the Space at startup.
USE_PEFT = True
try:
    from peft import PeftModel, PeftConfig
except Exception as e:  # broad on purpose: any peft import failure means "no adapter"
    USE_PEFT = False
    # Include the actual error so a broken (not just missing) install is diagnosable.
    print(f"[WARN] 'peft' not installed; running base model only. ({e})")
|
|
|
|
|
def load_model_and_tokenizer():
    """Load the base causal LM and tokenizer, attaching the PEFT adapter when possible.

    Returns:
        tuple: ``(model, tokenizer)`` — the PEFT-wrapped model when the adapter
        loads successfully, otherwise the plain base model.

    Raises:
        RuntimeError: when no Hugging Face token is configured.
    """
    if HF_TOKEN is None:
        raise RuntimeError(
            "HF_TOKEN is not set. Add it in Space → Settings → Secrets. "
            "Also ensure your account has access to the gated base model."
        )

    has_cuda = torch.cuda.is_available()
    # Half precision only on GPU; CPU inference stays in float32.
    compute_dtype = torch.float16 if has_cuda else torch.float32

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, token=HF_TOKEN)
    if tokenizer.pad_token is None:
        # Llama tokenizers ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        device_map="auto" if has_cuda else None,
        torch_dtype=compute_dtype,
        token=HF_TOKEN,
    )

    if not USE_PEFT:
        return base_model, tokenizer

    try:
        # Validate the adapter repo first, then wrap the base model with it.
        _ = PeftConfig.from_pretrained(PEFT_MODEL_ID, token=HF_TOKEN)
        adapted = PeftModel.from_pretrained(base_model, PEFT_MODEL_ID, token=HF_TOKEN)
        print(f"[INFO] Loaded PEFT adapter: {PEFT_MODEL_ID}")
        return adapted, tokenizer
    except Exception as e:
        # Best effort: fall back to the bare base model if the adapter fails.
        print(f"[WARN] Failed to load PEFT adapter: {e}")
        return base_model, tokenizer
|
|
|
|
|
# Load once at import time so the model is ready before the UI starts serving.
model, tokenizer = load_model_and_tokenizer()


# Device the model ended up on; inputs are moved here before generation.
# NOTE(review): with device_map="auto" the model may be sharded — presumably
# .device reports the first shard's device; confirm if multi-GPU is used.
DEVICE = model.device
|
|
|
|
|
@spaces.GPU
@torch.inference_mode()
def generate_response(messages, max_new_tokens=512, temperature=0.7, top_p=0.9) -> str:
    """Render the chat template, sample a completion, and return only the new text.

    Args:
        messages: list of {"role", "content"} dicts in chat order.
        max_new_tokens: generation budget for the reply.
        temperature: sampling temperature (sampling is always enabled).
        top_p: nucleus-sampling cutoff.

    Returns:
        The decoded reply with special tokens stripped.
    """
    chat_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    batch = tokenizer(chat_text, return_tensors="pt", padding=True, truncation=True)
    batch = {name: tensor.to(DEVICE) for name, tensor in batch.items()}

    prompt_len = batch["input_ids"].shape[1]
    generated = model.generate(
        **batch,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Slice off the prompt tokens so only the model's reply is decoded.
    completion_ids = generated[0][prompt_len:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)
|
|
|
|
|
def chat_fn(message, history, system_prompt, max_new_tokens, temperature, top_p):
    """Convert Gradio chat state into role-tagged messages and generate a reply.

    Args:
        message: the user's newest input.
        history: prior turns as (user, assistant) pairs, or None.
        system_prompt: optional system message prepended to the conversation.
        max_new_tokens / temperature / top_p: forwarded to generate_response.

    Returns:
        The assistant's reply string.
    """
    convo = []

    if system_prompt:
        convo.append({"role": "system", "content": system_prompt})

    # Replay the history; either half of a pair may be empty and is skipped.
    for user_turn, bot_turn in (history or []):
        if user_turn:
            convo.append({"role": "user", "content": user_turn})
        if bot_turn:
            convo.append({"role": "assistant", "content": bot_turn})

    if message:
        convo.append({"role": "user", "content": message})

    return generate_response(
        convo,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )
|
|
|
|
|
# Gradio chat UI. The extra controls map 1:1 onto chat_fn's trailing parameters.
demo = gr.ChatInterface(
    fn=chat_fn,  # direct reference; the previous lambda pass-through added nothing
    additional_inputs=[
        gr.Textbox(label="System prompt (optional)", placeholder="You are Be.FM assistant...", lines=2),
        gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="temperature"),
        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p"),
    ],
    title="Be.FM-8B (PEFT) on Meta-Llama-3.1-8B-Instruct",
    description="Chat interface using Meta-Llama-3.1-8B-Instruct with PEFT adapter befm/Be.FM-8B.",
)
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio server when this file is executed directly.
    demo.launch()
|
|
|