File size: 6,990 Bytes
74f1bed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
#!/usr/bin/env python3
"""
Grogu Science MoE - Simple Inference Example

This script demonstrates how to use the Grogu debate system
for graduate-level science questions.
"""

import json
import re
from typing import Dict, List, Optional

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer


class GroguDebateSystem:
    """
    A Mixture-of-Experts debate system for science questions.

    The system uses 4 agents:
    - Grogu: General reasoning agent (Nemotron-Qwen-1.5B + LoRA)
    - Physics: Domain specialist
    - Chemistry: Domain specialist
    - Biology: Domain specialist

    They engage in a 2-round collaborative debate with synthesis.
    """

    def __init__(
        self,
        grogu_path: str = "zenith-global/grogu-science-moe/grogu-lora",
        base_model: str = "nvidia/nemotron-qwen-1.5b",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        dtype: torch.dtype = torch.float16
    ):
        """
        Initialize the debate system with model paths.

        Args:
            grogu_path: Hub path or local directory of the Grogu LoRA adapter.
            base_model: Hub id of the base causal LM the adapter wraps.
            device: Device the tokenized prompts are moved to before generation.
            dtype: Torch dtype used to load the base model weights.
        """
        self.device = device
        self.dtype = dtype

        print("Loading Grogu model...")
        self.tokenizer = AutoTokenizer.from_pretrained(base_model)

        # Load base model; device_map="auto" lets accelerate choose placement.
        self.base_model = AutoModelForCausalLM.from_pretrained(
            base_model,
            torch_dtype=dtype,
            device_map="auto"
        )

        # Apply LoRA weights on top of the frozen base model.
        self.grogu = PeftModel.from_pretrained(
            self.base_model,
            grogu_path
        )
        self.grogu.eval()  # inference mode: disables dropout etc.
        print("Grogu loaded!")

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.1,
        top_p: float = 0.95
    ) -> str:
        """
        Generate a response from Grogu.

        Args:
            prompt: Full text prompt fed to the model.
            max_new_tokens: Generation budget beyond the prompt length.
            temperature: Softmax temperature; sampling is enabled iff > 0.
            top_p: Nucleus-sampling probability mass.

        Returns:
            The decoded completion (prompt stripped), whitespace-trimmed.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.grogu.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=temperature > 0,
                pad_token_id=self.tokenizer.eos_token_id
            )

        # Slice off the prompt tokens so only the new completion is decoded.
        response = self.tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1]:],
            skip_special_tokens=True
        )
        return response.strip()

    def extract_answer(self, response: str) -> str:
        """
        Extract the final answer (A, B, C, or D) from a response.

        Looks for a standalone option letter immediately following an
        answer-announcing phrase; falls back to the first standalone
        letter anywhere in the text, then to "A" as a last resort.
        """
        # Phrases that typically precede the final answer letter.
        patterns = [
            "The answer is",
            "Answer:",
            "Final answer:",
            "Therefore,",
            "So the answer is"
        ]

        response_upper = response.upper()

        for pattern in patterns:
            key = pattern.upper()
            idx = response_upper.find(key)
            if idx == -1:
                continue
            # Search only AFTER the phrase: the phrase itself contains
            # letters like the 'A' in "ANSWER", which must not be
            # mistaken for the option letter.
            start = idx + len(key)
            window = response_upper[start:start + 50]
            m = re.search(r"\b([ABCD])\b", window)
            if m:
                return m.group(1)

        # Fallback: first standalone A/B/C/D anywhere in the response.
        m = re.search(r"\b([ABCD])\b", response_upper)
        if m:
            return m.group(1)

        return "A"  # Default when no option letter is found

    def debate_single(
        self,
        question: str,
        options: Optional[Dict[str, str]] = None
    ) -> Dict:
        """
        Run a simplified single-agent debate (Grogu only).

        For the full 4-agent debate, see run_debate.py

        Args:
            question: The science question text.
            options: Optional mapping of option letter -> option text.

        Returns:
            Dict with both rounds' answers/reasoning, the final answer,
            and whether the agent changed its mind between rounds.
        """
        # Format the question (multiple-choice vs open-ended).
        if options:
            formatted_options = "\n".join([
                f"{k}) {v}" for k, v in options.items()
            ])
            prompt = f"""You are an expert scientist. Answer this question step by step.

Question: {question}

Options:
{formatted_options}

Think through this carefully and provide your answer. End with "The answer is [A/B/C/D]"."""
        else:
            prompt = f"""You are an expert scientist. Answer this question step by step.

Question: {question}

Think through this carefully and provide your reasoning."""

        # Round 1: Initial reasoning
        response_r1 = self.generate(prompt)
        answer_r1 = self.extract_answer(response_r1)

        # Round 2: Self-reflection — the agent reviews its own (truncated)
        # round-1 reasoning and may revise its answer.
        reflection_prompt = f"""{prompt}

Your initial answer was: {answer_r1}
Your reasoning was: {response_r1[:500]}...

Reconsider your answer. Are you confident? If not, what might be wrong?
End with "The answer is [A/B/C/D]"."""

        response_r2 = self.generate(reflection_prompt)
        answer_r2 = self.extract_answer(response_r2)

        return {
            "question": question,
            "round1_answer": answer_r1,
            "round1_reasoning": response_r1,
            "round2_answer": answer_r2,
            "round2_reasoning": response_r2,
            "final_answer": answer_r2,  # round 2 is authoritative
            "changed_mind": answer_r1 != answer_r2
        }

    @classmethod
    def from_pretrained(cls, model_path: str) -> "GroguDebateSystem":
        """Load the debate system from a HuggingFace model path."""
        return cls(grogu_path=f"{model_path}/grogu-lora")


def main():
    """Example usage of the Grogu debate system."""

    # A GPQA-style physics question about resolving two energy levels.
    question = """
    Two quantum states with energies E1 and E2 have a lifetime of 10^-9 sec
    and 10^-8 sec, respectively. We want to clearly distinguish these two
    energy levels. Which one of the following options could be their energy
    difference so that they can be clearly resolved?
    """

    options = {
        "A": "10^-4 eV",
        "B": "10^-11 eV",
        "C": "10^-8 eV",
        "D": "10^-9 eV"
    }

    banner = "=" * 60
    print(banner)
    print("GROGU SCIENCE MoE - Inference Demo")
    print(banner)

    print("\nInitializing Grogu...")

    # Fall back to a demo message if the model cannot be loaded.
    try:
        system = GroguDebateSystem()

        print("\nRunning debate on physics question...")
        result = system.debate_single(question, options)

        print(f"\nQuestion: {question[:100]}...")
        print(f"\nRound 1 Answer: {result['round1_answer']}")
        print(f"Round 2 Answer: {result['round2_answer']}")
        print(f"Changed Mind: {result['changed_mind']}")
        print(f"\nFinal Answer: {result['final_answer']}")

        # Correct answer is A (10^-4 eV)
        correct = "A"
        verdict = "Correct!" if result["final_answer"] == correct else "Incorrect"
        print(f"Correct Answer: {correct}")
        print(f"Result: {verdict}")

    except Exception as err:
        print("\nNote: Running in demo mode (model not loaded)")
        print(f"Error: {err}")
        print("\nTo run inference, ensure you have:")
        for hint in (
            "1. The Grogu LoRA weights downloaded",
            "2. Sufficient GPU memory (~4GB for Grogu alone)",
            "3. The base model (nvidia/nemotron-qwen-1.5b)",
        ):
            print(hint)


# Run the demo only when executed as a script (not on import).
if __name__ == "__main__":
    main()