# File size: 4,427 Bytes
import json
import os
import re
import shutil
import tempfile
import time

import gradio as gr
import torch
from huggingface_hub import snapshot_download
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
BASE_ID = "openbmb/MiniCPM5-1B"
ADAPTER_ID = "Georgefifth/tiny-browser-planner-reason"

print("Loading model (this may take a minute)...")
start = time.time()

# Download the LoRA adapter, then stage a copy in clean_dir whose
# adapter_config.json keeps only the core LoRA fields listed in KEEP.
# NOTE(review): presumably this strips config keys that the installed peft
# version does not recognise -- confirm against the pinned peft version.
adapter_dir = os.path.join(tempfile.gettempdir(), "adapter")
snapshot_download(repo_id=ADAPTER_ID, local_dir=adapter_dir)
with open(os.path.join(adapter_dir, "adapter_config.json"), encoding="utf-8") as f:
    raw_cfg = json.load(f)

KEEP = {
    "r",
    "lora_alpha",
    "lora_dropout",
    "target_modules",
    "bias",
    "task_type",
    "peft_type",
    "inference_mode",
}

# Fallback values, used only for KEEP fields the downloaded config omits.
# The adapter's own r / lora_alpha / target_modules take precedence so the
# staged config describes the weights that are actually loaded.
LORA_DEFAULTS = {
    "r": 16,
    "lora_alpha": 16,
    "lora_dropout": 0,
    "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    "bias": "none",
    "task_type": "CAUSAL_LM",
    "peft_type": "LORA",
    "inference_mode": True,
}
clean_cfg = {**LORA_DEFAULTS, **{k: v for k, v in raw_cfg.items() if k in KEEP}}
clean_cfg["inference_mode"] = True  # this app only ever runs inference

clean_dir = os.path.join(tempfile.gettempdir(), "clean_adapter")
os.makedirs(clean_dir, exist_ok=True)
for fname in os.listdir(adapter_dir):
    src = os.path.join(adapter_dir, fname)
    if not os.path.isfile(src):
        continue  # only regular files are staged; sub-directories are skipped
    dst = os.path.join(clean_dir, fname)
    if fname == "adapter_config.json":
        with open(dst, "w", encoding="utf-8") as f:
            json.dump(clean_cfg, f)
    else:
        shutil.copy2(src, dst)
# Load the base model, attach the cleaned LoRA adapter staged in clean_dir,
# and load the matching tokenizer.
# NOTE(review): no device placement happens here, so the float16 weights
# presumably stay on CPU, where half precision can be slow or unsupported for
# some ops -- confirm the target hardware.
model = AutoModelForCausalLM.from_pretrained(BASE_ID, torch_dtype=torch.float16, trust_remote_code=True)
model = PeftModel.from_pretrained(model, clean_dir)
tokenizer = AutoTokenizer.from_pretrained(BASE_ID, trust_remote_code=True)
# Reuse EOS as the padding token when the tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print(f"Ready! ({time.time() - start:.0f}s)")
ACTIONS = ["search", "open_page", "extract", "refine_search", "back", "finish"]

# "Reason:" -> rest of that line; "Action:" -> first word token, skipping any
# wrapping punctuation such as "[", "`" or "**" (so "finish." -> "finish").
_REASON_RE = re.compile(r"Reason:\s*(.+?)(?:\n|$)")
_ACTION_RE = re.compile(r"Action:\s*\W*(\w+)")


def _parse_output(output):
    """Split the model's generated text into ``(reason, action)`` strings.

    A missing field becomes ``"?"``; an action that is not in ACTIONS is
    returned as ``"<action> (unknown)"``.
    """
    reason_m = _REASON_RE.search(output)
    action_m = _ACTION_RE.search(output)
    reason = reason_m.group(1).strip() if reason_m else "?"
    action = action_m.group(1).lower() if action_m else "?"
    if action not in ACTIONS:
        action = f"{action} (unknown)"
    return reason, action


def predict(task, history_text):
    """Ask the planner model for its reasoning and the next browser action.

    Args:
        task: Natural-language description of what the user wants done.
        history_text: Previous steps, one per line. Blank lines are ignored
            and ``None`` is treated as an empty history.

    Returns:
        A ``(reason, action)`` pair of strings for the two output boxes, or
        ``("Error: task is empty", "")`` when *task* is empty or blank.
    """
    if not task or not task.strip():
        return "Error: task is empty", ""

    history_lines = [line.strip() for line in (history_text or "").splitlines() if line.strip()]
    hist_str = "\n".join(history_lines)
    msgs = [
        {"role": "system", "content": "You are a browser planner. First reason about the situation, then output the next action."},
        {"role": "user", "content": f"Task: {task}\n\nHistory:\n{hist_str}\n\nWhat is the next action?"},
    ]
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)

    inputs = tokenizer(prompt, return_tensors="pt")
    # Dropped before generation; presumably the model's forward() rejects it.
    inputs.pop("token_type_ids", None)
    # Keep the input tensors on whatever device holds the model weights.
    inputs = inputs.to(model.device)
    input_len = inputs["input_ids"].shape[1]

    # Greedy decoding. `temperature` is ignored when do_sample=False, so it is
    # not passed at all.
    outs = model.generate(
        **inputs,
        max_new_tokens=64,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    output = tokenizer.decode(outs[0][input_len:], skip_special_tokens=True).strip()
    return _parse_output(output)
# Render the action vocabulary from ACTIONS so the header can never drift from
# the list that predict() validates against.
actions_md = ", ".join(f"`{name}`" for name in ACTIONS)

with gr.Blocks(title="Tiny Browser Planner", theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"""
# Tiny Browser Planner — Reason-First
MiniCPM5-1B + LoRA | Actions: {actions_md}
""")
    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column(scale=2):
            task = gr.Textbox(label="Task", placeholder="e.g. Find Apple stock price")
            history = gr.Textbox(
                label="History (one action per line)",
                lines=5,
                placeholder="[search] Search completed.\n[open_page] Page content here...",
            )
            btn = gr.Button("Predict", variant="primary")
        # Right column: read-only model outputs.
        with gr.Column(scale=1):
            reason = gr.Textbox(label="Reason", lines=3, interactive=False)
            action = gr.Textbox(label="Next Action", lines=1, interactive=False)
    btn.click(fn=predict, inputs=[task, history], outputs=[reason, action])

if __name__ == "__main__":
    demo.launch()