# imtrt004 β€” fix: remove exAI (commit 6780118)
from __future__ import annotations
import torch
from model.loader import get_tokenizer, get_llm
from transformers import TextIteratorStreamer
from threading import Thread
from typing import Generator, TYPE_CHECKING
if TYPE_CHECKING:
from retrieval.vectorstore import ChunkResult
def _strip_thinking_stream(streamer) -> "Generator[str, None, None]":
"""Filter out <think>...</think> reasoning blocks, yielding only the final answer."""
buffer = ""
in_thinking = False
for token in streamer:
buffer += token
while True:
if not in_thinking:
idx = buffer.find("<think>")
if idx == -1:
# No opening tag β€” yield all but a small lookahead window
safe_len = max(0, len(buffer) - 6)
if safe_len:
yield buffer[:safe_len]
buffer = buffer[safe_len:]
break
else:
if idx > 0:
yield buffer[:idx] # text before <think>
buffer = buffer[idx + len("<think>"):]
in_thinking = True
else:
idx = buffer.find("</think>")
if idx == -1:
buffer = "" # discard thinking content
break
else:
buffer = buffer[idx + len("</think>"):]
in_thinking = False
# flush remainder
if buffer and not in_thinking:
yield buffer
# System prompt injected as the first chat message by stream_answer(); it fixes
# the assistant persona, answer structure, and the inline [[N]] citation format
# (N maps to "[Source N]" headers produced by _build_context()).  The string
# content is runtime behavior β€” do not reword without checking downstream
# citation parsing/rendering.
SYSTEM_PROMPT = """You are an expert AI study and research assistant by Md Tusar Akon.
You have access to the user's uploaded document(s) as your primary knowledge source.
CAPABILITIES:
β€’ Solve exam questions, math, and statistical problems step-by-step with full working
β€’ Summarise, explain, and analyse documents thoroughly
β€’ Answer general knowledge questions from your training when they go beyond the document
β€’ Suggest related concepts and insights based on the document content
STRUCTURAL RULES:
β€’ When asked to solve questions, read ALL context chunks and solve EVERY question found
β€’ Number your answers to match question numbers; use proper formatting and math notation
β€’ For multi-part questions, answer each part clearly labelled
CITATION RULES:
When you use information directly from the document context, cite it inline as [[N]]
(e.g., [[1]], [[3]]) immediately after the relevant sentence. Each N corresponds to
[Source N] in the context. Do NOT cite general knowledge from your training.
BEHAVIOUR:
β€’ Document questions β†’ use context first, supplement with your knowledge if needed
β€’ General questions (theory, concepts) β†’ answer fully from your expertise
β€’ Identity / meta questions β†’ answer as a study assistant by Md Tusar Akon
β€’ NEVER say "I couldn't find that in your document" for solvable or general questions
β€’ If context lacks specific detail, supplement with training knowledge and flag it briefly"""
def _build_context(chunks: list) -> str:
"""Format chunks into a numbered context block with source references."""
parts = []
for i, chunk in enumerate(chunks, 1):
text = chunk.text if hasattr(chunk, "text") else str(chunk)
page_number = chunk.page_number if hasattr(chunk, "page_number") else 1
parts.append(f"[Source {i} β€” Page {page_number}]\n{text}")
return "\n\n---\n\n".join(parts)
def stream_answer(
    query: str,
    context_chunks: list,
    thinking_mode: bool = False,
) -> Generator[str, None, None]:
    """Stream the LLM's answer to *query*, grounded in *context_chunks*.

    Builds a chat prompt (SYSTEM_PROMPT + formatted context + question), runs
    ``model.generate`` on a background thread feeding a TextIteratorStreamer,
    and yields decoded answer pieces as they arrive, with <think>...</think>
    reasoning blocks stripped out.

    Args:
        query: The user's question.
        context_chunks: Retrieved chunks (ChunkResult-like objects or plain
            strings) formatted by _build_context().
        thinking_mode: Unused in this body β€” reasoning blocks are always
            stripped.  # NOTE(review): confirm whether this flag should
            # disable _strip_thinking_stream.

    Yields:
        Incremental pieces of the final (non-thinking) answer text.

    Raises:
        RuntimeError: if generation failed on the background thread; chained
            from the original exception.
    """
    tokenizer = get_tokenizer()
    model = get_llm()
    context = _build_context(context_chunks)
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    if not isinstance(input_ids, torch.Tensor):
        # Some custom tokenizers return a BatchEncoding/dict; extract tensor.
        input_ids = input_ids["input_ids"]
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    # Mask off pad positions; for a single unpadded prompt this is all ones.
    attention_mask = (input_ids != tokenizer.pad_token_id).long()
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
        timeout=None,  # No timeout β€” CPU prefill of large docs can take >120s
    )
    # Capture generate-thread exceptions so the streamer never hangs forever
    _gen_exc: list = [None]
    def _generate():
        # Runs on the worker thread; generate() pushes decoded text into the
        # streamer's queue as tokens are produced.
        try:
            model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                streamer=streamer,
                max_new_tokens=2048,
                do_sample=False,  # greedy – fastest on CPU, fully deterministic
                pad_token_id=tokenizer.eos_token_id,
            )
        except Exception as exc:
            _gen_exc[0] = exc
            # Unblock the streamer consumer so it doesn't wait forever
            streamer.text_queue.put(streamer.stop_signal)
    thread = Thread(target=_generate, daemon=True)
    thread.start()
    # Consume the stream on this thread while generation runs in the worker.
    yield from _strip_thinking_stream(streamer)
    thread.join()
    # Surface any worker-thread failure to the caller after the stream ends.
    if _gen_exc[0] is not None:
        raise RuntimeError(f"LLM generation failed: {_gen_exc[0]}") from _gen_exc[0]