File size: 4,251 Bytes
e9462cd
 
00b3a52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from __future__ import annotations

from typing import List, Optional

try:
    from transformers import pipeline
except Exception:
    pipeline = None

from models import RetrievedChunk


class GeneratorEngine:
    """LLM answer generator with a deterministic template fallback.

    Wraps a Hugging Face ``text2text-generation`` pipeline (default model
    ``google/flan-t5-small``). The engine degrades gracefully: if
    ``transformers`` is not installed, or the model fails to load, or
    inference raises, ``generate`` still returns a templated answer built
    from the routed intent and any retrieval notes — it never raises and
    never returns ``None``.
    """

    def __init__(self, model_name: str = "google/flan-t5-small"):
        """Attempt to load the generation pipeline; never raise on failure.

        Args:
            model_name: Hugging Face model id passed to ``pipeline``.
        """
        self.model_name = model_name
        self.pipe = None

        if pipeline is not None:
            try:
                self.pipe = pipeline("text2text-generation", model=model_name)
            except Exception:
                # Best effort: model download/load can fail (offline, bad
                # model id). Fall back to templates instead of crashing.
                self.pipe = None

    def available(self) -> bool:
        """Return ``True`` when the LLM pipeline loaded successfully."""
        return self.pipe is not None

    def _notes_block(
        self,
        retrieval_context: List[RetrievedChunk],
        max_chunks: int = 3,
        max_chars: int = 220,
    ) -> str:
        """Format up to ``max_chunks`` retrieved chunks as bullet lines.

        Each chunk's text is flattened to a single line and truncated to
        at most ``max_chars`` characters (ellipsis included).

        Args:
            retrieval_context: Retrieved chunks with ``.topic`` and ``.text``.
            max_chunks: How many leading chunks to include.
            max_chars: Per-line character budget for the chunk text.

        Returns:
            Newline-joined ``- topic: text`` lines, or ``""`` when empty.
        """
        if not retrieval_context:
            return ""
        lines = []
        for chunk in retrieval_context[:max_chunks]:
            text = (chunk.text or "").strip().replace("\n", " ")
            if len(text) > max_chars:
                # Reserve 3 chars of the budget for the trailing ellipsis
                # (kept as max_chars - 3 to preserve the original 217 cut).
                text = text[: max_chars - 3].rstrip() + "…"
            lines.append(f"- {chunk.topic}: {text}")
        return "\n".join(lines)

    def _template_fallback(
        self,
        user_text: str,
        question_text: Optional[str],
        topic: str,
        intent: str,
        retrieval_context: Optional[List[RetrievedChunk]] = None,
    ) -> str:
        """Return a canned method hint keyed on ``intent``.

        ``user_text``, ``question_text`` and ``topic`` are accepted for
        signature symmetry with ``_build_prompt`` but are not used here.
        Appends formatted retrieval notes when any are available.
        """
        notes = self._notes_block(retrieval_context or [])

        if intent == "hint":
            base = "Start by identifying the exact relationship between the quantities before doing any arithmetic."
        elif intent in {"instruction", "method"}:
            base = "Translate the wording into an equation, ratio, or percent relationship, then solve one step at a time."
        elif intent in {"walkthrough", "step_by_step", "explain", "concept"}:
            base = "First identify what the question is asking, then map the values into the correct quantitative structure, and only then compute."
        else:
            base = "This does not match a strong solver rule yet, so begin by identifying the target quantity and the relationship connecting the numbers."

        if notes:
            return f"{base}\n\nRelevant notes:\n{notes}"
        return base

    def _build_prompt(
        self,
        user_text: str,
        question_text: Optional[str],
        topic: str,
        intent: str,
        retrieval_context: Optional[List[RetrievedChunk]] = None,
    ) -> str:
        """Assemble the tutoring prompt sent to the LLM.

        Prefers ``question_text`` over the raw ``user_text`` as the
        question body; interleaves retrieval notes when present.
        """
        question = (question_text or user_text or "").strip()
        notes = self._notes_block(retrieval_context or [])

        prompt = [
            "You are a concise GMAT tutor.",
            f"Topic: {topic or 'general'}",
            f"Intent: {intent or 'answer'}",
            "",
            f"Question: {question}",
        ]

        if notes:
            prompt.extend(["", "Relevant teaching notes:", notes])

        prompt.extend(
            [
                "",
                "Respond briefly and clearly.",
                "If the problem is not fully solvable from the parse, give the next best method step.",
                "Do not invent facts.",
            ]
        )

        return "\n".join(prompt)

    def generate(
        self,
        user_text: str,
        question_text: Optional[str] = None,
        topic: str = "",
        intent: str = "answer",
        retrieval_context: Optional[List[RetrievedChunk]] = None,
        chat_history=None,
        max_new_tokens: int = 96,
        **kwargs,
    ) -> str:
        """Generate an answer, falling back to a template on any failure.

        Runs the LLM pipeline when it is loaded; if inference raises or
        yields empty text, returns ``_template_fallback`` instead. Always
        returns a non-empty ``str`` (the old ``Optional[str]`` annotation
        was inaccurate — no path returns ``None``).

        Args:
            user_text: Raw user message.
            question_text: Parsed question text, preferred when present.
            topic: Routed topic label (may be "").
            intent: Routed intent label (defaults to "answer").
            retrieval_context: Optional retrieved chunks for grounding.
            chat_history: Accepted for interface compatibility; unused here.
            max_new_tokens: Generation length cap passed to the pipeline.
            **kwargs: Accepted and ignored for forward compatibility.
        """
        prompt = self._build_prompt(
            user_text=user_text,
            question_text=question_text,
            topic=topic,
            intent=intent,
            retrieval_context=retrieval_context or [],
        )

        if self.pipe is not None:
            try:
                out = self.pipe(prompt, max_new_tokens=max_new_tokens, do_sample=False)
                if out and isinstance(out, list):
                    text = str(out[0].get("generated_text", "")).strip()
                    if text:
                        return text
            except Exception:
                # Inference failure is non-fatal; drop to the template path.
                pass

        return self._template_fallback(
            user_text=user_text,
            question_text=question_text,
            topic=topic,
            intent=intent,
            retrieval_context=retrieval_context or [],
        )