# Page header captured together with the file (not part of the program):
#   topabaem's picture
#   Tune default search grounding depth for Space chat
#   0ba6762 verified
"""Gradio Space app for Gemma 4 text chat."""
from __future__ import annotations
import os
import sys
from collections.abc import Iterator
from pathlib import Path
from threading import Thread
import gradio as gr
import torch
# Put this file's directory at the front of sys.path (once) so modules stored
# next to it can be imported.
SPACE_APP_DIR = Path(__file__).resolve().parent
if sys.path.count(str(SPACE_APP_DIR)) == 0:
    sys.path.insert(0, str(SPACE_APP_DIR))
from search_backend import (
format_search_grounding,
format_search_markdown,
search_project_notes,
serialize_hits,
)
try:
    import spaces
except ImportError:  # pragma: no cover
    class _SpacesCompat:
        """No-op stand-in used when the ``spaces`` package cannot be imported."""

        @staticmethod
        def GPU(duration: int | None = None):
            """Leave the decorated function untouched.

            Supports both spellings of the decorator:

            * ``@spaces.GPU(duration=120)`` -- returns a pass-through decorator.
            * ``@spaces.GPU`` (bare) -- the function arrives in the ``duration``
              slot and is returned as-is.
            """
            if callable(duration):
                # Bare usage: the first argument is the function being decorated.
                return duration

            def _decorator(fn):
                return fn

            return _decorator

    spaces = _SpacesCompat()
# Hub identifier handed to the transformers loaders.
MODEL_ID = "google/gemma-4-e4b-it"

# Largest accepted prompt, in tokens; override through the MAX_INPUT_TOKENS
# environment variable.
MAX_INPUT_TOKENS = int(os.environ.get("MAX_INPUT_TOKENS", "10000"))

DEFAULT_SYSTEM_PROMPT = (
    "You are a precise assistant helping evaluate a local memory harness "
    "project. Answer clearly and briefly."
)

# Filled in lazily by _ensure_model_loaded(); LOAD_ERROR keeps the reason when
# loading was skipped or failed.
processor = model = LOAD_ERROR = None
def _ensure_model_loaded() -> None:
    """Populate the module-level ``processor`` and ``model`` on first use.

    Returns immediately when both are already set. When the
    ``SPACE_DISABLE_MODEL_LOAD`` environment variable equals ``"1"`` nothing is
    loaded and ``LOAD_ERROR`` records why. Any exception raised while loading
    is stored as text in ``LOAD_ERROR`` instead of propagating.
    """
    global processor, model, LOAD_ERROR

    already_loaded = processor is not None and model is not None
    if already_loaded:
        return

    loading_disabled = os.getenv("SPACE_DISABLE_MODEL_LOAD") == "1"
    if loading_disabled:
        LOAD_ERROR = "Model loading disabled for local test mode."
        return

    try:
        from transformers import AutoProcessor

        # Try AutoModelForMultimodalLM first; fall back when the installed
        # transformers does not export it.
        try:
            from transformers import AutoModelForMultimodalLM as ModelLoader
        except ImportError:
            from transformers import AutoModelForImageTextToText as ModelLoader

        processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=False)
        load_options = {"device_map": "auto", "dtype": torch.bfloat16}
        model = ModelLoader.from_pretrained(MODEL_ID, **load_options)
    except Exception as exc:  # noqa: BLE001
        LOAD_ERROR = str(exc)
def _build_messages(message: str, history: list, system_prompt: str) -> list[dict]:
messages: list[dict] = []
if system_prompt.strip():
messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt.strip()}]})
for item in history:
if isinstance(item, dict):
role = item["role"]
content = item["content"]
if isinstance(content, list):
text = " ".join(part.get("text", "") for part in content if isinstance(part, dict))
else:
text = str(content)
else:
user_text, assistant_text = item
messages.append({"role": "user", "content": [{"type": "text", "text": str(user_text)}]})
if assistant_text:
messages.append({"role": "assistant", "content": [{"type": "text", "text": str(assistant_text)}]})
continue
if role == "user":
messages.append({"role": "user", "content": [{"type": "text", "text": text}]})
else:
messages.append({"role": "assistant", "content": [{"type": "text", "text": text}]})
messages.append({"role": "user", "content": [{"type": "text", "text": message}]})
return messages
def _inject_search_context(message: str, system_prompt: str, enable_search: bool, search_top_k: int) -> tuple[str, list]:
if not enable_search:
return system_prompt, []
hits = search_project_notes(message, top_k=search_top_k)
grounding = format_search_grounding(hits)
if not grounding:
return system_prompt, hits
if system_prompt.strip():
return f"{system_prompt.strip()}\n\n{grounding}", hits
return grounding, hits
@spaces.GPU(duration=120)
@torch.inference_mode()
def _stream_generate(messages: list[dict], max_new_tokens: int, temperature: float) -> Iterator[str]:
    """Stream the model's reply for ``messages``.

    Generation runs in a worker thread that feeds a ``TextIteratorStreamer``;
    this generator consumes the streamer and yields the full text produced so
    far after every new chunk.

    Args:
        messages: Chat-template messages for ``processor.apply_chat_template``.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Sampling temperature; sampling is enabled only when it is
            greater than zero.

    Yields:
        The accumulated reply text after each streamed chunk.

    Raises:
        gr.Error: If the model is not loaded, the prompt exceeds
            ``MAX_INPUT_TOKENS``, or ``model.generate`` raised.
    """
    _ensure_model_loaded()
    if processor is None or model is None:
        raise gr.Error(f"Gemma 4 model is not loaded. {LOAD_ERROR or 'Unknown load failure.'}")

    from transformers.generation.streamers import TextIteratorStreamer

    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        add_generation_prompt=True,
    )
    n_tokens = inputs["input_ids"].shape[1]
    if n_tokens > MAX_INPUT_TOKENS:
        raise gr.Error(f"Input too long ({n_tokens} tokens). Maximum is {MAX_INPUT_TOKENS}.")
    inputs = inputs.to(device=model.device, dtype=torch.bfloat16)

    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    # Keep the temperature strictly positive even when the caller passes 0;
    # in that case do_sample is False as well.
    generation_temperature = max(float(temperature), 1e-5)
    kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "temperature": generation_temperature,
        "do_sample": temperature > 0,
        "disable_compile": True,
    }

    errors: list[Exception] = []

    def _run() -> None:
        try:
            model.generate(**kwargs)
        except Exception as exc:  # noqa: BLE001
            errors.append(exc)
            # generate() aborted before signalling end-of-stream. Signal it
            # here so the consumer loop below finishes right away and the
            # recorded error is reported, instead of the loop blocking until
            # the streamer timeout and failing with queue.Empty.
            streamer.end()

    thread = Thread(target=_run)
    thread.start()

    chunks: list[str] = []
    for chunk in streamer:
        chunks.append(chunk)
        yield "".join(chunks)

    thread.join()
    if errors:
        raise gr.Error(f"Generation failed: {errors[0]}")
def chat(message: str, history: list[dict], system_prompt: str, max_new_tokens: int, temperature: float):
    """Stream a reply to ``message`` without search grounding.

    Builds the conversation from ``system_prompt``, ``history`` and ``message``
    and yields everything ``_stream_generate`` yields for it.

    Raises:
        gr.Error: If ``message`` is empty or whitespace only.
    """
    has_text = bool(message.strip())
    if not has_text:
        raise gr.Error("Please enter a message.")

    conversation = _build_messages(message, history, system_prompt)
    yield from _stream_generate(
        conversation,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
def grounded_chat(
    message: str,
    history: list[dict],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    enable_search: bool,
    search_top_k: int,
):
    """Stream a reply to ``message``, optionally grounded on search results.

    The system prompt is first passed through ``_inject_search_context``; the
    hits it returns are discarded here and only the resulting prompt is used
    to build the conversation handed to ``_stream_generate``.

    Raises:
        gr.Error: If ``message`` is empty or whitespace only.
    """
    has_text = bool(message.strip())
    if not has_text:
        raise gr.Error("Please enter a message.")

    prompt_with_grounding = _inject_search_context(
        message, system_prompt, enable_search, search_top_k
    )[0]
    conversation = _build_messages(message, history, prompt_with_grounding)
    yield from _stream_generate(
        conversation,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
def run_search(query: str, top_k: int) -> tuple[str, str]:
    """Search the project notes and return the hits in two renderings.

    Returns:
        A pair: the output of ``format_search_markdown`` followed by the
        output of ``serialize_hits``, both computed from the same hits.
    """
    hits = search_project_notes(query, top_k=top_k)
    rendered = format_search_markdown(hits)
    raw = serialize_hits(hits)
    return rendered, raw
# Load the model at import time unless SPACE_DISABLE_MODEL_LOAD=1 opts out.
if os.environ.get("SPACE_DISABLE_MODEL_LOAD") != "1":
    _ensure_model_loaded()
with gr.Blocks() as demo:
    # Intro text rendered above both tabs.
    gr.Markdown(
        """
# Memory Harness Gemma 4
Text-only Hugging Face Space using `google/gemma-4-E4B-it`.
This Space is meant to validate that the repository can be paired with a hosted Hugging Face Gemma 4 demo UI.
It now includes lightweight project-note search so the model can ground answers before responding.
"""
    )

    with gr.Tab("Chat"):
        # Controls whose values are passed to grounded_chat as additional inputs.
        with gr.Row():
            system_prompt = gr.Textbox(label="System prompt", value=DEFAULT_SYSTEM_PROMPT, lines=3)
        with gr.Row():
            max_new_tokens = gr.Slider(label="Max new tokens", minimum=64, maximum=1024, step=32, value=256)
            temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.2, step=0.1, value=0.2)
        with gr.Row():
            enable_search = gr.Checkbox(label="Enable project search grounding", value=True)
            search_top_k = gr.Slider(label="Search top-k", minimum=1, maximum=5, step=1, value=1)

        gr.ChatInterface(
            fn=grounded_chat,
            additional_inputs=[system_prompt, max_new_tokens, temperature, enable_search, search_top_k],
            title=None,
            description=None,
            cache_examples=False,
            run_examples_on_click=False,
            # Each example row is the prompt followed by one value per
            # additional input, in the same order as additional_inputs.
            examples=[
                [prompt, DEFAULT_SYSTEM_PROMPT, 256, 0.2, True, 1]
                for prompt in (
                    "Summarize why raw archive retrieval matters for exact-date memory questions.",
                    "Explain the difference between summary memory and fact memory.",
                    "Give me a concise evaluation rubric for a memory harness.",
                )
            ],
        )

    with gr.Tab("Search"):
        gr.Markdown("Search the bundled project notes directly and inspect what the model can use for grounding.")
        search_query = gr.Textbox(
            label="Search query",
            lines=2,
            placeholder="Ask about raw archive, summary memory, evaluation, or architecture.",
        )
        search_results_top_k = gr.Slider(label="Top-k results", minimum=1, maximum=5, step=1, value=3)
        search_button = gr.Button("Run search", variant="primary")
        search_markdown = gr.Markdown(value="No search results yet.")
        search_raw = gr.Textbox(label="Raw result JSON", lines=12)
        # The button fills both outputs from run_search and exposes it under
        # the API name "search".
        search_button.click(
            fn=run_search,
            inputs=[search_query, search_results_top_k],
            outputs=[search_markdown, search_raw],
            api_name="search",
        )

if __name__ == "__main__":
    demo.launch()