# app.py
import os
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig

# ---- CONFIG ----
ADAPTER_REPO = "richardprobe/opt-350-chris-adapter"  # your LoRA repo
ADAPTER_NAME = "finetune_adapter"                    # how you saved it
SYSTEM_PROMPT = "You are Richard. Be concise and casual."

# If the adapter is private on the Hub, set HF_TOKEN in the Space secrets
HF_TOKEN = os.getenv("HF_TOKEN", None)

# ------------- Loading -------------
def load_model_and_tokenizer():
    # Inspect adapter to get its base
    print("Reading adapter config...")
    peft_cfg = PeftConfig.from_pretrained(ADAPTER_REPO, token=HF_TOKEN)
    base_id = peft_cfg.base_model_name_or_path
    print(f"Base model detected: {base_id}")

    # Tokenizer from base (adapter may also carry added tokens)
    print("Loading tokenizer...")
    tok = AutoTokenizer.from_pretrained(base_id, use_fast=True, token=HF_TOKEN)
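    # If training added chat tokens (<|user|>, <|end|>, ...), the adapter repo
    # may also ship tokenizer files that include them; depending on how the
    # adapter was saved, loading from there instead would pick them up (sketch):
    #   tok = AutoTokenizer.from_pretrained(ADAPTER_REPO, use_fast=True, token=HF_TOKEN)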

    # Safety: many decoder-only models don't define a pad token
    if tok.pad_token is None and tok.eos_token is not None:
        tok.pad_token = tok.eos_token
    tok.padding_side = "right"
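    # Right padding is fine here because we generate one prompt at a time;
    # batched generation would want padding_side="left" instead.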

    # Non-quantized load so we can merge
    print("Loading base model...")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    base = AutoModelForCausalLM.from_pretrained(
        base_id, torch_dtype=dtype, device_map="auto", token=HF_TOKEN
    )
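    # Note: device_map="auto" relies on the `accelerate` package being
    # installed (assumed to be in the Space's requirements.txt).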

    print("Loading adapter and merging...")
    peft = PeftModel.from_pretrained(
        base, ADAPTER_REPO, adapter_name=ADAPTER_NAME, token=HF_TOKEN
    )
    # This bakes LoRA weights into the base weights and returns a plain model
    merged = peft.merge_and_unload()  # equivalent to merge_adapter + unload
    merged.eval()
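
    # Low-memory alternative (sketch, not used here): skip the merge and serve
    # the PeftModel directly. Merging also isn't straightforward on a
    # quantized (4/8-bit) base, so that path would use:
    #   model = PeftModel.from_pretrained(base, ADAPTER_REPO, token=HF_TOKEN)
    #   model.eval()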

    # We’ll use <|end|> as EOS if it exists
    try:
        end_id = tok.convert_tokens_to_ids("<|end|>")
        if end_id is not None and end_id != tok.unk_token_id:
            merged.config.eos_token_id = end_id
    except Exception:
        pass

    return tok, merged

tokenizer, model = load_model_and_tokenizer()

# ------------- Prompt building -------------
def build_prompt(history, user_msg):
    """
    Render your chat format using the added tokens that were used during training.
    History is a list of (user, assistant) tuples from ChatInterface.
    """
    segments = []
    if SYSTEM_PROMPT:
        # If you trained with a system token, add it here. Otherwise keep as plain text.
        segments.append(f"<|system|>{SYSTEM_PROMPT}<|end|>")

    for u, a in history or []:
        if u:
            segments.append(f"<|user|>{u}<|end|>")
        if a:
            segments.append(f"<|assistant|>{a}<|end|>")

    segments.append(f"<|user|>{user_msg}<|end|>")
    segments.append("<|assistant|>")
    return "\n".join(segments)

# ------------- Inference -------------
def chat_generate(message, history, temperature=0.7, top_p=0.95, max_new_tokens=256, repetition_penalty=1.1):
    prompt = build_prompt(history, message)

    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    do_sample = float(temperature) > 0
    gen_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        repetition_penalty=float(repetition_penalty),
        eos_token_id=getattr(model.config, "eos_token_id", tokenizer.eos_token_id),
        pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
    )
    # Only pass sampling knobs when sampling; temperature=0.0 together with
    # do_sample=False would otherwise trigger a generation-config warning.
    if do_sample:
        gen_kwargs.update(temperature=float(temperature), top_p=float(top_p))

    with torch.inference_mode():
        out = model.generate(**inputs, **gen_kwargs)

    # Return only the assistant part
    gen_tokens = out[0][inputs["input_ids"].shape[-1]:]
    text = tokenizer.decode(gen_tokens, skip_special_tokens=True)
    # If your <|end|> isn’t marked as special, strip it manually
    text = text.replace("<|end|>", "").strip()
    return text
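
# Quick sanity check outside the UI (sketch; assumes the model fits in memory):
#   print(chat_generate("hey", history=[]))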

# ------------- UI -------------
demo = gr.ChatInterface(
    fn=chat_generate,
    title="OPT-350M + LoRA (Chris style)",
    description="Loads the base model from the adapter's config, merges LoRA, and chats using your training tokens.",
    additional_inputs=[
        gr.Slider(0.0, 1.5, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.5, 1.0, value=0.95, step=0.01, label="Top-p"),
        gr.Slider(16, 512, value=256, step=16, label="Max new tokens"),
        gr.Slider(1.0, 1.5, value=1.1, step=0.05, label="Repetition penalty"),
    ],
    examples=[
        ["What are you up to?", 0.7, 0.95, 256, 1.1],
        ["You coming?",         0.7, 0.95, 256, 1.1],
        ["I'm on the can",      0.7, 0.95, 256, 1.1],
    ],
    cache_examples=False,
)

if __name__ == "__main__":
    # queue helps avoid device contention; hide API to avoid schema issues
    demo.queue(max_size=8)
    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False, show_error=True)
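
# Local run (sketch; assumed deps, pin versions as needed):
#   pip install torch transformers peft gradio accelerate
#   python app.py   # then open http://localhost:7860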