"""Gradio Space app for Gemma 4 text chat.""" from __future__ import annotations import os import sys from collections.abc import Iterator from pathlib import Path from threading import Thread import gradio as gr import torch SPACE_APP_DIR = Path(__file__).resolve().parent if str(SPACE_APP_DIR) not in sys.path: sys.path.insert(0, str(SPACE_APP_DIR)) from search_backend import ( format_search_grounding, format_search_markdown, search_project_notes, serialize_hits, ) try: import spaces except ImportError: # pragma: no cover class _SpacesCompat: @staticmethod def GPU(duration: int | None = None): def _decorator(fn): return fn return _decorator spaces = _SpacesCompat() MODEL_ID = "google/gemma-4-e4b-it" MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "10000")) DEFAULT_SYSTEM_PROMPT = ( "You are a precise assistant helping evaluate a local memory harness project. " "Answer clearly and briefly." ) processor = None model = None LOAD_ERROR = None def _ensure_model_loaded() -> None: global processor, model, LOAD_ERROR if processor is not None and model is not None: return if os.getenv("SPACE_DISABLE_MODEL_LOAD") == "1": LOAD_ERROR = "Model loading disabled for local test mode." return try: from transformers import AutoProcessor try: from transformers import AutoModelForMultimodalLM as ModelLoader except ImportError: from transformers import AutoModelForImageTextToText as ModelLoader processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=False) model = ModelLoader.from_pretrained( MODEL_ID, device_map="auto", dtype=torch.bfloat16, ) except Exception as exc: # noqa: BLE001 LOAD_ERROR = str(exc) def _build_messages(message: str, history: list, system_prompt: str) -> list[dict]: messages: list[dict] = [] if system_prompt.strip(): messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt.strip()}]}) for item in history: if isinstance(item, dict): role = item["role"] content = item["content"] if isinstance(content, list): text = " ".join(part.get("text", "") for part in content if isinstance(part, dict)) else: text = str(content) else: user_text, assistant_text = item messages.append({"role": "user", "content": [{"type": "text", "text": str(user_text)}]}) if assistant_text: messages.append({"role": "assistant", "content": [{"type": "text", "text": str(assistant_text)}]}) continue if role == "user": messages.append({"role": "user", "content": [{"type": "text", "text": text}]}) else: messages.append({"role": "assistant", "content": [{"type": "text", "text": text}]}) messages.append({"role": "user", "content": [{"type": "text", "text": message}]}) return messages def _inject_search_context(message: str, system_prompt: str, enable_search: bool, search_top_k: int) -> tuple[str, list]: if not enable_search: return system_prompt, [] hits = search_project_notes(message, top_k=search_top_k) grounding = format_search_grounding(hits) if not grounding: return system_prompt, hits if system_prompt.strip(): return f"{system_prompt.strip()}\n\n{grounding}", hits return grounding, hits @spaces.GPU(duration=120) @torch.inference_mode() def _stream_generate(messages: list[dict], max_new_tokens: int, temperature: float) -> Iterator[str]: _ensure_model_loaded() if processor is None or model is None: raise gr.Error(f"Gemma 4 model is not loaded. {LOAD_ERROR or 'Unknown load failure.'}") from transformers.generation.streamers import TextIteratorStreamer inputs = processor.apply_chat_template( messages, tokenize=True, return_dict=True, return_tensors="pt", add_generation_prompt=True, ) n_tokens = inputs["input_ids"].shape[1] if n_tokens > MAX_INPUT_TOKENS: raise gr.Error(f"Input too long ({n_tokens} tokens). Maximum is {MAX_INPUT_TOKENS}.") inputs = inputs.to(device=model.device, dtype=torch.bfloat16) streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True) generation_temperature = max(float(temperature), 1e-5) kwargs = { **inputs, "streamer": streamer, "max_new_tokens": max_new_tokens, "temperature": generation_temperature, "do_sample": temperature > 0, "disable_compile": True, } errors: list[Exception] = [] def _run() -> None: try: model.generate(**kwargs) except Exception as exc: # noqa: BLE001 errors.append(exc) thread = Thread(target=_run) thread.start() chunks: list[str] = [] for chunk in streamer: chunks.append(chunk) yield "".join(chunks) thread.join() if errors: raise gr.Error(f"Generation failed: {errors[0]}") def chat(message: str, history: list[dict], system_prompt: str, max_new_tokens: int, temperature: float): if not message.strip(): raise gr.Error("Please enter a message.") messages = _build_messages(message, history, system_prompt) yield from _stream_generate(messages, max_new_tokens=max_new_tokens, temperature=temperature) def grounded_chat( message: str, history: list[dict], system_prompt: str, max_new_tokens: int, temperature: float, enable_search: bool, search_top_k: int, ): if not message.strip(): raise gr.Error("Please enter a message.") grounded_prompt, _ = _inject_search_context(message, system_prompt, enable_search, search_top_k) messages = _build_messages(message, history, grounded_prompt) yield from _stream_generate(messages, max_new_tokens=max_new_tokens, temperature=temperature) def run_search(query: str, top_k: int) -> tuple[str, str]: hits = search_project_notes(query, top_k=top_k) return format_search_markdown(hits), serialize_hits(hits) if os.getenv("SPACE_DISABLE_MODEL_LOAD") != "1": _ensure_model_loaded() with gr.Blocks() as demo: gr.Markdown( """ # Memory Harness Gemma 4 Text-only Hugging Face Space using `google/gemma-4-E4B-it`. This Space is meant to validate that the repository can be paired with a hosted Hugging Face Gemma 4 demo UI. It now includes lightweight project-note search so the model can ground answers before responding. """ ) with gr.Tab("Chat"): with gr.Row(): system_prompt = gr.Textbox( value=DEFAULT_SYSTEM_PROMPT, label="System prompt", lines=3, ) with gr.Row(): max_new_tokens = gr.Slider(minimum=64, maximum=1024, value=256, step=32, label="Max new tokens") temperature = gr.Slider(minimum=0.0, maximum=1.2, value=0.2, step=0.1, label="Temperature") with gr.Row(): enable_search = gr.Checkbox(value=True, label="Enable project search grounding") search_top_k = gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Search top-k") gr.ChatInterface( fn=grounded_chat, additional_inputs=[system_prompt, max_new_tokens, temperature, enable_search, search_top_k], title=None, description=None, cache_examples=False, run_examples_on_click=False, examples=[ [ "Summarize why raw archive retrieval matters for exact-date memory questions.", DEFAULT_SYSTEM_PROMPT, 256, 0.2, True, 1, ], [ "Explain the difference between summary memory and fact memory.", DEFAULT_SYSTEM_PROMPT, 256, 0.2, True, 1, ], [ "Give me a concise evaluation rubric for a memory harness.", DEFAULT_SYSTEM_PROMPT, 256, 0.2, True, 1, ], ], ) with gr.Tab("Search"): gr.Markdown("Search the bundled project notes directly and inspect what the model can use for grounding.") search_query = gr.Textbox(label="Search query", lines=2, placeholder="Ask about raw archive, summary memory, evaluation, or architecture.") search_results_top_k = gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Top-k results") search_button = gr.Button("Run search", variant="primary") search_markdown = gr.Markdown(value="No search results yet.") search_raw = gr.Textbox(label="Raw result JSON", lines=12) search_button.click(run_search, inputs=[search_query, search_results_top_k], outputs=[search_markdown, search_raw], api_name="search") if __name__ == "__main__": demo.launch()