# GameAI / generator_engine.py
# Source: Hugging Face repo (commit 00b3a52, "Update generator_engine.py", 4.25 kB)
from __future__ import annotations
from typing import List, Optional
try:
from transformers import pipeline
except Exception:
pipeline = None
from models import RetrievedChunk
class GeneratorEngine:
    """Answer generator backed by an optional HF text2text pipeline.

    Tries to load a ``transformers`` text2text-generation pipeline at
    construction time. If the library is missing or the model fails to
    load, every call degrades gracefully to deterministic, intent-keyed
    template text, so the engine never raises for a caller.
    """

    def __init__(self, model_name: str = "google/flan-t5-small"):
        self.model_name = model_name
        # Best-effort load: any failure leaves self.pipe as None and
        # generate() serves template fallbacks instead.
        self.pipe = None
        if pipeline is None:
            return
        try:
            self.pipe = pipeline("text2text-generation", model=model_name)
        except Exception:
            self.pipe = None

    def available(self) -> bool:
        """Return True when the generation pipeline loaded successfully."""
        return self.pipe is not None

    def _notes_block(self, retrieval_context: List[RetrievedChunk]) -> str:
        """Render up to three retrieved chunks as a bulleted notes block.

        Newlines inside a chunk are flattened to spaces; snippets longer
        than 220 characters are clamped to 217 plus an ellipsis.
        Returns "" when there is no context.
        """
        if not retrieval_context:
            return ""

        def bullet(chunk: RetrievedChunk) -> str:
            snippet = (chunk.text or "").strip().replace("\n", " ")
            if len(snippet) > 220:
                snippet = snippet[:217].rstrip() + "…"
            return f"- {chunk.topic}: {snippet}"

        return "\n".join(bullet(chunk) for chunk in retrieval_context[:3])

    def _template_fallback(
        self,
        user_text: str,
        question_text: Optional[str],
        topic: str,
        intent: str,
        retrieval_context: Optional[List[RetrievedChunk]] = None,
    ) -> str:
        """Canned guidance used when no model output is available.

        ``user_text``/``question_text``/``topic`` are accepted to mirror
        _build_prompt's signature but only ``intent`` and the retrieval
        context influence the output.
        """
        # Intent groups mapped to their guidance line; first match wins.
        guidance = (
            ({"hint"},
             "Start by identifying the exact relationship between the quantities before doing any arithmetic."),
            ({"instruction", "method"},
             "Translate the wording into an equation, ratio, or percent relationship, then solve one step at a time."),
            ({"walkthrough", "step_by_step", "explain", "concept"},
             "First identify what the question is asking, then map the values into the correct quantitative structure, and only then compute."),
        )
        base = "This does not match a strong solver rule yet, so begin by identifying the target quantity and the relationship connecting the numbers."
        for intents, message in guidance:
            if intent in intents:
                base = message
                break
        notes = self._notes_block(retrieval_context or [])
        if not notes:
            return base
        return f"{base}\n\nRelevant notes:\n{notes}"

    def _build_prompt(
        self,
        user_text: str,
        question_text: Optional[str],
        topic: str,
        intent: str,
        retrieval_context: Optional[List[RetrievedChunk]] = None,
    ) -> str:
        """Assemble the tutor prompt fed to the pipeline.

        Prefers ``question_text`` over ``user_text`` for the question line
        and splices in the retrieval notes block when non-empty.
        """
        question = (question_text or user_text or "").strip()
        parts = [
            "You are a concise GMAT tutor.",
            f"Topic: {topic or 'general'}",
            f"Intent: {intent or 'answer'}",
            "",
            f"Question: {question}",
        ]
        notes = self._notes_block(retrieval_context or [])
        if notes:
            parts += ["", "Relevant teaching notes:", notes]
        parts += [
            "",
            "Respond briefly and clearly.",
            "If the problem is not fully solvable from the parse, give the next best method step.",
            "Do not invent facts.",
        ]
        return "\n".join(parts)

    def generate(
        self,
        user_text: str,
        question_text: Optional[str] = None,
        topic: str = "",
        intent: str = "answer",
        retrieval_context: Optional[List[RetrievedChunk]] = None,
        chat_history=None,
        max_new_tokens: int = 96,
        **kwargs,
    ) -> Optional[str]:
        """Produce a tutoring answer; never raises.

        Runs the pipeline greedily when available; any pipeline error or
        empty output falls through to the template fallback.
        ``chat_history`` and extra kwargs are accepted for interface
        compatibility but unused here.
        """
        prompt = self._build_prompt(
            user_text=user_text,
            question_text=question_text,
            topic=topic,
            intent=intent,
            retrieval_context=retrieval_context or [],
        )
        if self.pipe is not None:
            try:
                raw = self.pipe(prompt, max_new_tokens=max_new_tokens, do_sample=False)
                if raw and isinstance(raw, list):
                    candidate = str(raw[0].get("generated_text", "")).strip()
                    if candidate:
                        return candidate
            except Exception:
                # Model hiccups are non-fatal; serve the template instead.
                pass
        return self._template_fallback(
            user_text=user_text,
            question_text=question_text,
            topic=topic,
            intent=intent,
            retrieval_context=retrieval_context or [],
        )