# app.py
import os
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig
# ---- CONFIG ----
ADAPTER_REPO = "richardprobe/opt-350-chris-adapter" # your LoRA repo
ADAPTER_NAME = "finetune_adapter" # how you saved it
SYSTEM_PROMPT = "You are Richard. Be concise and casual."
# If the adapter is private on the Hub, set HF_TOKEN in the Space secrets
HF_TOKEN = os.getenv("HF_TOKEN", None)
# ------------- Loading -------------
def load_model_and_tokenizer():
    # Inspect the adapter config to find the base model it was trained on
    print("Reading adapter config...")
    peft_cfg = PeftConfig.from_pretrained(ADAPTER_REPO, token=HF_TOKEN)
    base_id = peft_cfg.base_model_name_or_path
    print(f"Base model detected: {base_id}")

    # Tokenizer from the base model; if training added custom tokens (e.g. <|end|>),
    # the adapter repo may carry an updated tokenizer worth loading instead
    print("Loading tokenizer...")
    tok = AutoTokenizer.from_pretrained(base_id, use_fast=True, token=HF_TOKEN)
    # Safety: many decoder-only models don't define a pad token
    if tok.pad_token is None and tok.eos_token is not None:
        tok.pad_token = tok.eos_token
    tok.padding_side = "right"

    # Non-quantized load so we can merge
    print("Loading base model...")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    base = AutoModelForCausalLM.from_pretrained(
        base_id, torch_dtype=dtype, device_map="auto", token=HF_TOKEN
    )

    print("Loading adapter and merging...")
    peft = PeftModel.from_pretrained(
        base, ADAPTER_REPO, adapter_name=ADAPTER_NAME, token=HF_TOKEN
    )
    # This bakes the LoRA weights into the base weights and returns a plain model
    merged = peft.merge_and_unload()  # equivalent to merge_adapter + unload
    merged.eval()

    # Use <|end|> as EOS if the tokenizer actually knows that token
    try:
        end_id = tok.convert_tokens_to_ids("<|end|>")
        if end_id is not None and end_id != tok.unk_token_id:
            merged.config.eos_token_id = end_id
    except Exception:
        pass

    return tok, merged
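
# Load once at import time so the first chat request doesn't pay the model-load cost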
tokenizer, model = load_model_and_tokenizer()
# ------------- Prompt building -------------
def build_prompt(history, user_msg):
    """
    Render your chat format using the added tokens that were used during training.
    History is a list of (user, assistant) tuples from ChatInterface.
    """
    segments = []
    if SYSTEM_PROMPT:
        # If you trained with a system token, add it here. Otherwise keep as plain text.
        segments.append(f"<|system|>{SYSTEM_PROMPT}<|end|>")
    for u, a in history or []:
        if u:
            segments.append(f"<|user|>{u}<|end|>")
        if a:
            segments.append(f"<|assistant|>{a}<|end|>")
    segments.append(f"<|user|>{user_msg}<|end|>")
    segments.append("<|assistant|>")
    return "\n".join(segments)
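
# For illustration: with the default SYSTEM_PROMPT and one prior exchange, the
# rendered prompt looks roughly like this (the turn texts are made-up examples):
#   <|system|>You are Richard. Be concise and casual.<|end|>
#   <|user|>hey<|end|>
#   <|assistant|>yo<|end|>
#   <|user|>what's up?<|end|>
#   <|assistant|>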
# ------------- Inference -------------
def chat_generate(message, history, temperature=0.7, top_p=0.95, max_new_tokens=256, repetition_penalty=1.1):
    prompt = build_prompt(history, message)
    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    gen_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        do_sample=float(temperature) > 0,
        repetition_penalty=float(repetition_penalty),
        eos_token_id=getattr(model.config, "eos_token_id", tokenizer.eos_token_id),
        pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
    )

    with torch.inference_mode():
        out = model.generate(**inputs, **gen_kwargs)

    # Return only the newly generated assistant part
    gen_tokens = out[0][inputs["input_ids"].shape[-1]:]
    text = tokenizer.decode(gen_tokens, skip_special_tokens=True)
    # If your <|end|> isn't marked as special, strip it manually
    text = text.replace("<|end|>", "").strip()
    return text
# ------------- UI -------------
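# gr.ChatInterface calls fn(message, history, *additional_inputs) and renders the
# returned string as the assistant reply, which is what chat_generate produces.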
demo = gr.ChatInterface(
    fn=chat_generate,
    title="OPT-350M + LoRA (Chris style)",
    description="Loads the base model from the adapter's config, merges LoRA, and chats using your training tokens.",
    additional_inputs=[
        gr.Slider(0.0, 1.5, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.5, 1.0, value=0.95, step=0.01, label="Top-p"),
        gr.Slider(16, 512, value=256, step=16, label="Max new tokens"),
        gr.Slider(1.0, 1.5, value=1.1, step=0.05, label="Repetition penalty"),
    ],
    examples=[
        ["What are you up to?", 0.7, 0.95, 256, 1.1],
        ["You coming?", 0.7, 0.95, 256, 1.1],
        ["I'm on the can", 0.7, 0.95, 256, 1.1],
    ],
    cache_examples=False,
)
if __name__ == "__main__":
    # queue helps avoid device contention; hide API to avoid schema issues
    demo.queue(max_size=8)
    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False, show_error=True)
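
# Assumed environment: torch, transformers, and peft installed via the Space's
# requirements.txt; gradio itself is provided by the Gradio SDK runtime.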