File size: 4,628 Bytes
54b21c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""Run DecomposeRL-7B on a (claim, evidence_doc) pair and pretty-print the trace.

Usage:
    python example.py
"""

import re

from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "dipta007/decomposeRL-7b"

PROMPT_TEMPLATE = """You are tasked with systematically verifying the accuracy of a claim. You will be provided with a claim to verify and an evidence document to consult.

Here is the evidence document you should consult:

<evidence_document>
{evidence_doc}
</evidence_document>

Here is the claim you need to verify:

<claim>
{claim}
</claim>

Your task is to verify whether this claim is Supported or Refuted through an iterative process of asking questions and gathering information.

# Verification Process

Begin by analyzing the claim in <think> tags, then enter an iterative cycle of <question>/<answer> pairs answered ONLY from the evidence document. When every sub-claim is addressed, output your final label inside <verification> tags. The label must be exactly one of: Supported, Refuted.

Stop immediately after the closing </verification> tag.

Begin your verification process now."""

TAG_RE = re.compile(r"<(think|question|answer|verification)>(.*?)</\1>", re.DOTALL)


def build_prompt(claim: str, evidence_doc: str) -> str:
    """Return PROMPT_TEMPLATE with its ``{evidence_doc}`` and ``{claim}`` slots filled in.

    Only the template itself is parsed for replacement fields, so braces inside
    *claim* or *evidence_doc* are inserted verbatim.
    """
    slots = {"evidence_doc": evidence_doc, "claim": claim}
    return PROMPT_TEMPLATE.format_map(slots)


def verify(
    model,
    tokenizer,
    claim: str,
    evidence_doc: str,
    max_new_tokens: int = 4500,
    temperature: float = 0.7,
) -> str:
    """Run the model end-to-end on a (claim, evidence_doc) pair and return the raw trace.

    Args:
        model: causal LM exposing ``.device`` and ``.generate`` (as loaded in main()).
        tokenizer: the matching tokenizer; its chat template wraps the prompt.
        claim: the statement to verify.
        evidence_doc: the text the model is told to consult.
        max_new_tokens: upper bound on the number of generated tokens.
        temperature: sampling temperature. A value ``<= 0`` selects greedy
            decoding instead of sampling, because transformers rejects a
            non-positive temperature when ``do_sample=True``.

    Returns:
        Only the newly generated text (prompt tokens removed, special tokens skipped).
    """
    messages = [{"role": "user", "content": build_prompt(claim, evidence_doc)}]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([text], return_tensors="pt").to(model.device)

    if temperature is not None and temperature <= 0:
        # Greedy decoding; omit temperature so generate() does not warn about an unused value.
        decoding = {"do_sample": False}
    else:
        decoding = {"do_sample": True, "temperature": temperature}

    out = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        **decoding,
    )
    # generate() returns prompt + completion; slice off the prompt tokens.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)


def parse_trace(text: str):
    """Extract every recognised trace tag from *text*.

    Returns a list of ``(tag_name, body)`` tuples in document order, where
    ``body`` has surrounding whitespace stripped.
    """
    return [(m.group(1), m.group(2).strip()) for m in TAG_RE.finditer(text)]


def _print_qa(idx: int, question, answer) -> None:
    """Print one numbered question/answer cycle, with a placeholder for a missing half."""
    shown_q = question if question is not None else "(no question)"
    shown_a = answer if answer is not None else "(no answer)"
    print(f"🔸  Q{idx}: {shown_q}")
    print(f"💬  A{idx}: {shown_a}")
    print()


def pretty_print(text: str) -> None:
    """Print the trace as a readable conversation. Falls back to raw output if degenerate.

    The trace is parsed with ``parse_trace``. If it contains no ``<verification>``
    tag, the raw text is printed between rules instead. Otherwise each tag is
    rendered in order:

    - ``<think>`` bodies are printed in full.
    - ``<question>``/``<answer>`` pairs become numbered Q/A cycles.
    - ``<verification>`` is shown as a banner.

    Malformed pairings are still shown rather than lost or mislabelled:

    - A question that is never answered is printed with "(no answer)".
    - An answer with no preceding question gets its own cycle number and
      "(no question)".

    Args:
        text: raw model output, e.g. the string returned by ``verify``.
    """
    parsed = parse_trace(text)
    if not any(tag == "verification" for tag, _ in parsed):
        print("⚠️  Could not parse output into the expected think/question/answer/verification structure.")
        print("Raw output:")
        print("─" * 78)
        print(text)
        print("─" * 78)
        return

    cycle_idx = 0
    pending_q = None  # question text still waiting for its <answer>
    for tag, body in parsed:
        if tag == "answer":
            if pending_q is None:
                # Orphan answer: start a new cycle instead of reusing the last number.
                cycle_idx += 1
            _print_qa(cycle_idx, pending_q, body)
            pending_q = None
            continue

        if pending_q is not None:
            # Any non-answer tag closes the open question; show it instead of dropping it.
            _print_qa(cycle_idx, pending_q, None)
            pending_q = None

        if tag == "think":
            print("─" * 78)
            print("🧠  THINK")
            print("─" * 78)
            print(body)
            print()
        elif tag == "question":
            cycle_idx += 1
            pending_q = body
        elif tag == "verification":
            print("=" * 78)
            print(f"✅  VERIFICATION: {body}")
            print("=" * 78)

    if pending_q is not None:
        # A trailing question (e.g. emitted after <verification>) is still shown.
        _print_qa(cycle_idx, pending_q, None)


def extract_label(text: str):
    """Return 'Supported', 'Refuted', or None.

    Searches *text* for a ``<verification>`` tag whose body is one of the two
    labels, and uses the first such tag found.

    Matching tolerates the following:

    - any letter case, in both the tag and the label;
    - surrounding whitespace;
    - a single trailing period.

    Sampled outputs sometimes drift from the exact casing the prompt requests.
    The returned label is always in canonical capitalisation. None means no
    usable label was found.
    """
    match = re.search(
        r"<verification>\s*(supported|refuted)\s*\.?\s*</verification>",
        text,
        flags=re.IGNORECASE,
    )
    return match.group(1).capitalize() if match else None


def main():
    """Load the model, verify one built-in sample claim, and print the trace and label."""
    # Sample pair: the evidence dates construction to 1887-1889, while the
    # claim says the tower was completed in 1887.
    sample_claim = "The Eiffel Tower was completed in 1887 and stands 330 metres tall."
    sample_evidence = (
        "The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, "
        "France. It is named after the engineer Gustave Eiffel, whose company designed and "
        "built the tower from 1887 to 1889. Locally nicknamed 'La dame de fer', it was "
        "constructed as the centerpiece of the 1889 World's Fair. The tower is 330 metres "
        "(1,083 ft) tall."
    )

    print(f"Loading {MODEL_NAME} ...")
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    lm = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype="auto", device_map="auto")

    print(f"\nClaim: {sample_claim}\n")
    trace = verify(lm, tok, sample_claim, sample_evidence)

    pretty_print(trace)
    final = extract_label(trace)
    print(f"\nFinal label: {final}")


if __name__ == "__main__":
    main()