Georgefifth's picture
Upload folder using huggingface_hub
a2a263b verified
Raw
History Blame Contribute Delete
4.43 kB
import re, time, json, os, shutil, torch, gradio as gr
import tempfile
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from huggingface_hub import snapshot_download
# Hub identifiers for the base checkpoint and the LoRA adapter layered on it.
BASE_ID = "openbmb/MiniCPM5-1B"
ADAPTER_ID = "Georgefifth/tiny-browser-planner-reason"

# Config written into the cleaned adapter directory in place of the published
# adapter_config.json. Values are fixed here rather than derived from the
# download. NOTE(review): presumably this avoids config keys the installed
# peft version rejects -- confirm, and keep in sync with the adapter.
MINIMAL_ADAPTER_CONFIG = {
    "r": 16,
    "lora_alpha": 16,
    "lora_dropout": 0,
    "target_modules": [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    "bias": "none",
    "task_type": "CAUSAL_LM",
    "peft_type": "LORA",
    "inference_mode": True,
}

print("Loading model (this may take a minute)...")
start = time.time()

# Download the adapter repository into a temp directory.
adapter_dir = os.path.join(tempfile.gettempdir(), "adapter")
snapshot_download(repo_id=ADAPTER_ID, local_dir=adapter_dir)

# Read the published config once; the context manager closes the handle.
with open(os.path.join(adapter_dir, "adapter_config.json"), encoding="utf-8") as f:
    raw_cfg = json.load(f)

# Keys of the published config considered relevant for loading.
KEEP = {
    "r",
    "lora_alpha",
    "lora_dropout",
    "target_modules",
    "bias",
    "task_type",
    "peft_type",
    "inference_mode",
}
# Published config restricted to KEEP. It is not what gets written below
# (MINIMAL_ADAPTER_CONFIG is); it is retained as a module-level value.
clean_cfg = {k: v for k, v in raw_cfg.items() if k in KEEP}

# Mirror the adapter's regular files into a second directory, swapping in the
# minimal config. Sub-directories of the download are skipped.
clean_dir = os.path.join(tempfile.gettempdir(), "clean_adapter")
os.makedirs(clean_dir, exist_ok=True)
for fname in os.listdir(adapter_dir):
    src = os.path.join(adapter_dir, fname)
    if not os.path.isfile(src):
        continue
    dst = os.path.join(clean_dir, fname)
    if fname == "adapter_config.json":
        with open(dst, "w", encoding="utf-8") as f:
            json.dump(MINIMAL_ADAPTER_CONFIG, f)
    else:
        shutil.copy2(src, dst)

# Load the base model in float16, attach the adapter, and load the tokenizer.
model = AutoModelForCausalLM.from_pretrained(
    BASE_ID, torch_dtype=torch.float16, trust_remote_code=True
)
model = PeftModel.from_pretrained(model, clean_dir)
tokenizer = AutoTokenizer.from_pretrained(BASE_ID, trust_remote_code=True)
if tokenizer.pad_token is None:
    # Fall back to EOS so a pad token is always defined.
    tokenizer.pad_token = tokenizer.eos_token
print(f"Ready! ({time.time()-start:.0f}s)")

# Action names the planner is expected to emit.
ACTIONS = ["search", "open_page", "extract", "refine_search", "back", "finish"]
def _parse_plan(output, allowed_actions):
    """Extract ``(reason, action)`` from the model's decoded text.

    Args:
        output: Decoded generation, expected to contain ``Reason: ...`` and
            ``Action: ...`` fragments.
        allowed_actions: Collection of recognised action names.

    Returns:
        A ``(reason, action)`` tuple of strings. ``reason`` is the rest of the
        line after ``Reason:`` or ``"?"`` when absent. ``action`` is the
        lower-cased token after ``Action:`` (or ``"?"``), suffixed with
        ``" (unknown)"`` when it is not in ``allowed_actions``.
    """
    reason_m = re.search(r"Reason:\s*(.+?)(?:\n|$)", output)
    action_m = re.search(r"Action:\s*(\S+)", output)

    reason = reason_m.group(1).strip() if reason_m else "?"

    if action_m:
        token = action_m.group(1).strip().lower()
        # Drop surrounding punctuation/markup (e.g. "search.", "[search]",
        # "`search`") so a valid action is not reported as unknown. Keep the
        # raw token if trimming would leave nothing.
        action = re.sub(r"^\W+|\W+$", "", token) or token
    else:
        action = "?"

    if action not in allowed_actions:
        action = f"{action} (unknown)"
    return reason, action


def predict(task, history_text):
    """Ask the model for the next browser action for ``task``.

    Args:
        task: Free-text task description; must be non-blank.
        history_text: Previous steps, one per line. Blank lines are ignored;
            ``None`` is treated as an empty history.

    Returns:
        A ``(reason, action)`` tuple of strings. When ``task`` is blank the
        tuple is ``("Error: task is empty", "")``.
    """
    if not task or not task.strip():
        return "Error: task is empty", ""

    # Tolerate a missing history value instead of failing on None.strip().
    raw_lines = (history_text or "").split("\n")
    history = [line.strip() for line in raw_lines if line.strip()]
    hist_str = "\n".join(history)

    msgs = [
        {"role": "system", "content": "You are a browser planner. First reason about the situation, then output the next action."},
        {"role": "user", "content": f"Task: {task}\n\nHistory:\n{hist_str}\n\nWhat is the next action?"},
    ]
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt")
    # Remember the prompt length so only newly generated tokens are decoded.
    input_len = inputs["input_ids"].shape[1]
    inputs.pop("token_type_ids", None)

    # Greedy decoding: do_sample=False, so no sampling temperature is passed.
    outs = model.generate(
        **inputs,
        max_new_tokens=64,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    output = tokenizer.decode(outs[0][input_len:], skip_special_tokens=True).strip()

    return _parse_plan(output, ACTIONS)
# Markdown shown at the top of the page.
header_md = """
# Tiny Browser Planner — Reason-First
MiniCPM5-1B + LoRA | Actions: `search`, `open_page`, `extract`, `refine_search`, `back`, `finish`
"""

soft_theme = gr.themes.Soft()

# Two-column layout: inputs and the trigger button on the left (wider),
# read-only model outputs on the right.
with gr.Blocks(title="Tiny Browser Planner", theme=soft_theme) as demo:
    gr.Markdown(header_md)

    with gr.Row():
        with gr.Column(scale=2):
            task_box = gr.Textbox(
                label="Task",
                placeholder="e.g. Find Apple stock price",
            )
            history_box = gr.Textbox(
                label="History (one action per line)",
                lines=5,
                placeholder="[search] Search completed.\n[open_page] Page content here...",
            )
            run_button = gr.Button("Predict", variant="primary")

        with gr.Column(scale=1):
            reason_box = gr.Textbox(label="Reason", lines=3, interactive=False)
            action_box = gr.Textbox(label="Next Action", lines=1, interactive=False)

    # Wire the button to predict(task, history) -> (reason, action).
    run_button.click(
        fn=predict,
        inputs=[task_box, history_box],
        outputs=[reason_box, action_box],
    )

# Start the app only when run as a script.
if __name__ == "__main__":
    demo.launch()