File size: 4,427 Bytes
a9f1b52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import re, time, json, os, shutil, torch, gradio as gr
import tempfile
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from huggingface_hub import snapshot_download

# Hugging Face repo IDs: the base causal LM and the LoRA adapter applied on top.
BASE_ID = "openbmb/MiniCPM5-1B"
ADAPTER_ID = "Georgefifth/tiny-browser-planner-reason"

print("Loading model (this may take a minute)...")
start = time.time()

# Download the adapter into a temp dir so its config can be replaced before
# PEFT loads it.
adapter_dir = os.path.join(tempfile.gettempdir(), "adapter")
snapshot_download(repo_id=ADAPTER_ID, local_dir=adapter_dir)
with open(os.path.join(adapter_dir, "adapter_config.json"), encoding="utf-8") as f:
    raw_cfg = json.load(f)

# LoRA config keys retained from the downloaded config; all other keys are dropped.
KEEP = {"r", "lora_alpha", "lora_dropout", "target_modules", "bias",
        "task_type", "peft_type", "inference_mode"}
# Filtered view of the published config. Kept for inspection only; the file
# written below uses LORA_CONFIG instead.
clean_cfg = {k: v for k, v in raw_cfg.items() if k in KEEP}

# Config written in place of the downloaded adapter_config.json.
# NOTE(review): these hard-coded values take precedence over clean_cfg; confirm
# r / lora_alpha / target_modules match how the adapter was actually trained.
LORA_CONFIG = {
    "r": 16,
    "lora_alpha": 16,
    "lora_dropout": 0,
    "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj",
                       "gate_proj", "up_proj", "down_proj"],
    "bias": "none",
    "task_type": "CAUSAL_LM",
    "peft_type": "LORA",
    "inference_mode": True,
}

# Mirror the adapter's top-level files into clean_dir, substituting the config.
clean_dir = os.path.join(tempfile.gettempdir(), "clean_adapter")
os.makedirs(clean_dir, exist_ok=True)
for fname in os.listdir(adapter_dir):
    src = os.path.join(adapter_dir, fname)
    if not os.path.isfile(src):
        continue  # only regular files are mirrored; subdirectories are skipped
    dst = os.path.join(clean_dir, fname)
    if fname == "adapter_config.json":
        with open(dst, "w", encoding="utf-8") as f:
            json.dump(LORA_CONFIG, f)
    else:
        shutil.copy2(src, dst)

# Base model in fp16 with the LoRA adapter from clean_dir applied on top.
model = AutoModelForCausalLM.from_pretrained(BASE_ID, torch_dtype=torch.float16, trust_remote_code=True)
model = PeftModel.from_pretrained(model, clean_dir)
tokenizer = AutoTokenizer.from_pretrained(BASE_ID, trust_remote_code=True)
if tokenizer.pad_token is None:
    # Fall back to EOS so padding-aware code paths have a pad token.
    tokenizer.pad_token = tokenizer.eos_token
print(f"Ready! ({time.time()-start:.0f}s)")

ACTIONS = ["search", "open_page", "extract", "refine_search", "back", "finish"]

def predict(task, history_text):
    if not task or not task.strip():
        return "Error: task is empty", ""
    history = [l.strip() for l in history_text.strip().split("\n") if l.strip()]
    hist_str = "\n".join(history)
    msgs = [
        {"role": "system", "content": "You are a browser planner. First reason about the situation, then output the next action."},
        {"role": "user", "content": f"Task: {task}\n\nHistory:\n{hist_str}\n\nWhat is the next action?"},
    ]
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_len = inputs["input_ids"].shape[1]
    inputs.pop("token_type_ids", None)
    outs = model.generate(**inputs, max_new_tokens=64, temperature=0.01, do_sample=False,
                          pad_token_id=tokenizer.eos_token_id)
    output = tokenizer.decode(outs[0][input_len:], skip_special_tokens=True).strip()
    reason_m = re.search(r"Reason:\s*(.+?)(?:\n|$)", output)
    action_m = re.search(r"Action:\s*(\S+)", output)
    reason = reason_m.group(1).strip() if reason_m else "?"
    action = action_m.group(1).strip().lower() if action_m else "?"
    if action not in ["search", "open_page", "extract", "refine_search", "back", "finish"]:
        action = f"{action} (unknown)"
    return reason, action

# Markdown header rendered at the top of the page.
_HEADER_MD = (
    "\n"
    "    # Tiny Browser Planner — Reason-First\n"
    "    MiniCPM5-1B + LoRA  |  Actions: `search`, `open_page`, `extract`, `refine_search`, `back`, `finish`\n"
    "    "
)
# Example history shown in the empty History box.
_HISTORY_PLACEHOLDER = "[search] Search completed.\n[open_page] Page content here..."

with gr.Blocks(theme=gr.themes.Soft(), title="Tiny Browser Planner") as demo:
    gr.Markdown(_HEADER_MD)
    with gr.Row():
        # Left (wider) column: user inputs and the trigger button.
        with gr.Column(scale=2):
            task = gr.Textbox(placeholder="e.g. Find Apple stock price", label="Task")
            history = gr.Textbox(
                label="History (one action per line)",
                placeholder=_HISTORY_PLACEHOLDER,
                lines=5,
            )
            btn = gr.Button("Predict", variant="primary")
        # Right column: read-only model outputs.
        with gr.Column(scale=1):
            reason = gr.Textbox(interactive=False, label="Reason", lines=3)
            action = gr.Textbox(interactive=False, label="Next Action", lines=1)
    # Clicking runs predict(task, history) and fills (reason, action).
    btn.click(predict, inputs=[task, history], outputs=[reason, action])

if __name__ == "__main__":
    demo.launch()