# app.py
import os

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig

# ---- CONFIG ----
ADAPTER_REPO = "richardprobe/opt-350-chris-adapter"   # your LoRA repo
ADAPTER_NAME = "finetune_adapter"                     # how you saved it
SYSTEM_PROMPT = "You are Richard. Be concise and casual."

# If the adapter is private on the Hub, set HF_TOKEN in the Space secrets
HF_TOKEN = os.getenv("HF_TOKEN", None)


# ------------- Loading -------------
def load_model_and_tokenizer():
    # Inspect adapter to get its base
    print("Reading adapter config...")
    peft_cfg = PeftConfig.from_pretrained(ADAPTER_REPO, token=HF_TOKEN)
    base_id = peft_cfg.base_model_name_or_path
    print(f"Base model detected: {base_id}")

    # Tokenizer from base (adapter may also carry added tokens)
    print("Loading tokenizer...")
    tok = AutoTokenizer.from_pretrained(base_id, use_fast=True, token=HF_TOKEN)
    # Safety: many decoder-only models don't define a pad token
    if tok.pad_token is None and tok.eos_token is not None:
        tok.pad_token = tok.eos_token
    tok.padding_side = "right"

    # Non-quantized load so we can merge
    print("Loading base model...")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    base = AutoModelForCausalLM.from_pretrained(
        base_id, torch_dtype=dtype, device_map="auto", token=HF_TOKEN
    )

    print("Loading adapter and merging...")
    peft = PeftModel.from_pretrained(
        base, ADAPTER_REPO, adapter_name=ADAPTER_NAME, token=HF_TOKEN
    )
    # This bakes LoRA weights into the base weights and returns a plain model
    merged = peft.merge_and_unload()  # equivalent to merge_adapter + unload
    merged.eval()

    # We'll use <|end|> as EOS if it exists
    try:
        end_id = tok.convert_tokens_to_ids("<|end|>")
        if end_id is not None and end_id != tok.unk_token_id:
            merged.config.eos_token_id = end_id
    except Exception:
        pass

    return tok, merged


tokenizer, model = load_model_and_tokenizer()


# ------------- Prompt building -------------
def build_prompt(history, user_msg):
    """
    Render your chat format using the added tokens that were used during training.
    History is a list of (user, assistant) tuples from ChatInterface.
    """
    segments = []
    if SYSTEM_PROMPT:
        # If you trained with a system token, add it here. Otherwise keep as plain text.
        segments.append(f"<|system|>{SYSTEM_PROMPT}<|end|>")

    for u, a in history or []:
        if u:
            segments.append(f"<|user|>{u}<|end|>")
        if a:
            segments.append(f"<|assistant|>{a}<|end|>")

    segments.append(f"<|user|>{user_msg}<|end|>")
    segments.append("<|assistant|>")
    return "\n".join(segments)


# ------------- Inference -------------
def chat_generate(message, history, temperature=0.7, top_p=0.95, max_new_tokens=256, repetition_penalty=1.1):
    prompt = build_prompt(history, message)
    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    gen_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        do_sample=float(temperature) > 0,
        repetition_penalty=float(repetition_penalty),
        eos_token_id=getattr(model.config, "eos_token_id", tokenizer.eos_token_id),
        pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
    )

    with torch.inference_mode():
        out = model.generate(**inputs, **gen_kwargs)

    # Return only the assistant part
    gen_tokens = out[0][inputs["input_ids"].shape[-1]:]
    text = tokenizer.decode(gen_tokens, skip_special_tokens=True, errors="ignore")
    # If your <|end|> isn't marked as special, strip it manually
    text = text.replace("<|end|>", "").strip()
    return text


# ------------- UI -------------
demo = gr.ChatInterface(
    fn=chat_generate,
    title="OPT-350M + LoRA (Chris style)",
    description="Loads the base model from the adapter's config, merges LoRA, and chats using your training tokens.",
    additional_inputs=[
        gr.Slider(0.0, 1.5, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.5, 1.0, value=0.95, step=0.01, label="Top-p"),
        gr.Slider(16, 512, value=256, step=16, label="Max new tokens"),
        gr.Slider(1.0, 1.5, value=1.1, step=0.05, label="Repetition penalty"),
    ],
    examples=[
        ["What are you up to?", 0.7, 0.95, 256, 1.1],
        ["You coming?", 0.7, 0.95, 256, 1.1],
        ["I'm on the can", 0.7, 0.95, 256, 1.1],
    ],
    cache_examples=False,
)

if __name__ == "__main__":
    # queue helps avoid device contention; hide API to avoid schema issues
    demo.queue(max_size=8)
    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False, show_error=True)
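
# ------------- Dependencies (assumed, not pinned) -------------
# A minimal requirements.txt for this Space would likely look like the list below;
# exact versions are an assumption and should match whatever the adapter was trained with:
#   torch
#   transformers
#   peft
#   gradio
#   accelerate   # transformers requires accelerate when device_map="auto" is used above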