Spaces:
Sleeping
Sleeping
| """Gradio Space app for Gemma 4 text chat.""" | |
| from __future__ import annotations | |
| import os | |
| import sys | |
| from collections.abc import Iterator | |
| from pathlib import Path | |
| from threading import Thread | |
| import gradio as gr | |
| import torch | |
# Directory containing this app file; prepended to sys.path (once) so that
# sibling modules such as search_backend can be imported by plain name.
SPACE_APP_DIR = Path(__file__).resolve().parent
if not any(entry == str(SPACE_APP_DIR) for entry in sys.path):
    sys.path.insert(0, str(SPACE_APP_DIR))
| from search_backend import ( | |
| format_search_grounding, | |
| format_search_markdown, | |
| search_project_notes, | |
| serialize_hits, | |
| ) | |
try:
    import spaces
except ImportError:  # pragma: no cover
    class _SpacesCompat:
        """No-op stand-in for the Hugging Face ``spaces`` package.

        Used when ``spaces`` is not installed (e.g. local runs). Only ``GPU``
        is provided, and it always hands back the decorated function unchanged.
        """

        # staticmethod: GPU is reached through an instance (`spaces.GPU`), and a
        # plain method would bind that instance to the first parameter.
        @staticmethod
        def GPU(fn=None, *, duration=None, **_kwargs):
            """Mimic ``spaces.GPU`` in both decorator forms.

            Bare ``@spaces.GPU`` passes the function as ``fn`` and gets it back.
            Called forms ``@spaces.GPU()`` / ``@spaces.GPU(duration=...)`` leave
            ``fn`` unset and receive an identity decorator. ``duration`` and any
            other keyword options are accepted and ignored.
            """

            def _decorator(func):
                return func

            if fn is not None and callable(fn):
                return fn
            return _decorator

    spaces = _SpacesCompat()
# Hugging Face Hub checkpoint loaded by _ensure_model_loaded().
MODEL_ID: str = "google/gemma-4-e4b-it"
# Reject prompts whose chat-templated length exceeds this many tokens;
# overridable through the MAX_INPUT_TOKENS environment variable.
MAX_INPUT_TOKENS: int = int(os.environ.get("MAX_INPUT_TOKENS", "10000"))
DEFAULT_SYSTEM_PROMPT: str = " ".join(
    (
        "You are a precise assistant helping evaluate a local memory harness project.",
        "Answer clearly and briefly.",
    )
)

# Lazily-populated model state: both stay None until a load succeeds.
processor = model = None
# Human-readable reason the most recent load attempt failed (str), else None.
LOAD_ERROR = None
def _ensure_model_loaded() -> None:
    """Load the Gemma processor and model into the module globals on first use.

    Returns immediately once both ``processor`` and ``model`` are set. When the
    ``SPACE_DISABLE_MODEL_LOAD`` environment variable is ``"1"`` nothing is
    loaded and ``LOAD_ERROR`` records why.

    Any exception raised while importing ``transformers`` or loading the
    checkpoint is caught rather than propagated. Its type and message are
    stored in ``LOAD_ERROR`` and its traceback is printed to stderr. Callers
    must therefore check ``processor``/``model`` for ``None`` afterwards.
    A failed load is retried on the next call.
    """
    global processor, model, LOAD_ERROR
    if processor is not None and model is not None:
        return
    if os.getenv("SPACE_DISABLE_MODEL_LOAD") == "1":
        LOAD_ERROR = "Model loading disabled for local test mode."
        return
    try:
        from transformers import AutoProcessor

        # Prefer AutoModelForMultimodalLM when this transformers version
        # provides it; otherwise fall back to AutoModelForImageTextToText.
        try:
            from transformers import AutoModelForMultimodalLM as ModelLoader
        except ImportError:
            from transformers import AutoModelForImageTextToText as ModelLoader

        loaded_processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=False)
        loaded_model = ModelLoader.from_pretrained(
            MODEL_ID,
            device_map="auto",
            dtype=torch.bfloat16,
        )
    except Exception as exc:  # noqa: BLE001 - reported in the UI instead of crashing the app
        import traceback

        traceback.print_exc()  # keep the full stack in the Space logs
        LOAD_ERROR = f"{type(exc).__name__}: {exc}"
        return
    # Publish both objects together so a failed model load never leaves a
    # half-initialised state (processor set, model None) behind.
    processor, model = loaded_processor, loaded_model
    LOAD_ERROR = None
def _build_messages(message: str, history: list, system_prompt: str) -> list[dict]:
    """Convert Gradio chat history plus the new user turn into chat-template messages.

    Args:
        message: The new user message, always placed in the final user turn.
        history: Prior turns, in one of two shapes:
            - Gradio "messages" format: dicts with ``role`` and ``content``,
              where ``content`` is a plain value or a list of
              ``{"type": "text", "text": ...}`` parts.
            - Legacy "tuples" format: ``(user_text, assistant_text)`` pairs.
        system_prompt: Optional system prompt. It is stripped, and omitted
            when blank.

    Returns:
        A list of ``{"role": ..., "content": [{"type": "text", "text": ...}]}``
        dicts, shaped as follows:
            - Any history role other than ``"user"`` becomes ``"assistant"``.
            - Empty history turns are dropped.
            - Consecutive turns with the same role are merged, separated by a
              blank line. This keeps user and assistant turns strictly
              alternating, e.g. when an earlier generation failed and left a
              user turn without a reply.
    """
    turns: list[tuple[str, str]] = []

    def _append(role: str, text: str) -> None:
        # Merge with the previous turn when the role repeats; else start a new turn.
        if turns and turns[-1][0] == role:
            turns[-1] = (role, f"{turns[-1][1]}\n\n{text}")
        else:
            turns.append((role, text))

    for item in history:
        if isinstance(item, dict):
            content = item.get("content")
            if isinstance(content, list):
                text = " ".join(
                    str(part.get("text") or "") for part in content if isinstance(part, dict)
                )
            else:
                text = "" if content is None else str(content)
            if text.strip():
                _append("user" if item.get("role") == "user" else "assistant", text)
        else:
            user_text, assistant_text = item
            if user_text is not None and str(user_text).strip():
                _append("user", str(user_text))
            if assistant_text:
                _append("assistant", str(assistant_text))
    _append("user", message)

    messages: list[dict] = []
    if system_prompt.strip():
        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt.strip()}]})
    messages.extend({"role": role, "content": [{"type": "text", "text": text}]} for role, text in turns)
    return messages
def _inject_search_context(message: str, system_prompt: str, enable_search: bool, search_top_k: int) -> tuple[str, list]:
    """Optionally ground the system prompt with project-note search results.

    Returns ``(prompt, hits)``:
        - Search disabled: the prompt is returned untouched with an empty hit
          list.
        - Search enabled: ``search_project_notes`` is queried with ``message``
          and the hits are formatted via ``format_search_grounding``.
          - If that text is empty, the untouched prompt is returned with the
            hits.
          - Otherwise the text is appended to the stripped system prompt after
            a blank line, or used on its own when the system prompt is blank.
    """
    if not enable_search:
        return system_prompt, []
    found = search_project_notes(message, top_k=search_top_k)
    context_block = format_search_grounding(found)
    if not context_block:
        return system_prompt, found
    base_prompt = system_prompt.strip()
    combined = "\n\n".join((base_prompt, context_block)) if base_prompt else context_block
    return combined, found
def _stream_generate(messages: list[dict], max_new_tokens: int, temperature: float) -> Iterator[str]:
    """Generate a reply for ``messages`` on a worker thread and stream it back.

    Args:
        messages: Chat-template messages (role + list of text parts).
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature. Values ``<= 0`` select greedy
            decoding.

    Yields:
        The cumulative response text generated so far. Each value supersedes
        the previous one.

    Raises:
        gr.Error: In any of these cases:
            - the model could not be loaded;
            - the chat-templated prompt exceeds ``MAX_INPUT_TOKENS``;
            - ``model.generate`` raised on the worker thread.
    """
    import queue

    _ensure_model_loaded()
    if processor is None or model is None:
        raise gr.Error(f"Gemma 4 model is not loaded. {LOAD_ERROR or 'Unknown load failure.'}")
    from transformers.generation.streamers import TextIteratorStreamer

    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        add_generation_prompt=True,
    )
    n_tokens = inputs["input_ids"].shape[1]
    if n_tokens > MAX_INPUT_TOKENS:
        raise gr.Error(f"Input too long ({n_tokens} tokens). Maximum is {MAX_INPUT_TOKENS}.")
    inputs = inputs.to(device=model.device, dtype=torch.bfloat16)
    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)

    do_sample = temperature > 0
    generate_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": int(max_new_tokens),
        "do_sample": do_sample,
        "disable_compile": True,
    }
    if do_sample:
        # Temperature only matters when sampling; floor tiny values at 1e-5.
        generate_kwargs["temperature"] = max(float(temperature), 1e-5)

    errors: list[Exception] = []

    def _run() -> None:
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:  # noqa: BLE001 - reported to the UI below
            errors.append(exc)
            # If generate() raises, the streamer may never get its end-of-stream
            # signal. Send it here so the consumer stops immediately instead of
            # waiting out the streamer timeout.
            streamer.end()

    # Daemon thread: an abandoned stream (client disconnect) must not block shutdown.
    thread = Thread(target=_run, daemon=True)
    thread.start()

    chunks: list[str] = []
    stream = iter(streamer)
    while True:
        try:
            chunk = next(stream)
        except StopIteration:
            break
        except queue.Empty:
            # No text within the streamer timeout. Keep waiting while generation
            # is still running (e.g. a long prefill); stop once the worker exits.
            if thread.is_alive():
                continue
            break
        chunks.append(chunk)
        yield "".join(chunks)
    thread.join()
    if errors:
        raise gr.Error(f"Generation failed: {errors[0]}")
def chat(message: str, history: list[dict], system_prompt: str, max_new_tokens: int, temperature: float):
    """Stream a reply to ``message`` without any project-note search grounding.

    Yields the cumulative response text produced by ``_stream_generate``.
    Raises ``gr.Error`` when the message is blank or whitespace-only.
    """
    if message.strip():
        conversation = _build_messages(message, history, system_prompt)
        yield from _stream_generate(conversation, max_new_tokens=max_new_tokens, temperature=temperature)
    else:
        raise gr.Error("Please enter a message.")
def grounded_chat(
    message: str,
    history: list[dict],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    enable_search: bool,
    search_top_k: int,
):
    """Chat handler that can ground the system prompt with project-note search.

    When ``enable_search`` is true, up to ``search_top_k`` search hits for
    ``message`` are folded into the system prompt before generation.
    Yields the cumulative response text. Raises ``gr.Error`` for a blank
    message.
    """
    if not message.strip():
        raise gr.Error("Please enter a message.")
    prompt_with_context = _inject_search_context(message, system_prompt, enable_search, search_top_k)[0]
    conversation = _build_messages(message, history, prompt_with_context)
    yield from _stream_generate(
        conversation,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
def run_search(query: str, top_k: int) -> tuple[str, str]:
    """Search-tab handler.

    Returns the hits for ``query`` twice: rendered as Markdown, and as the
    serialized raw results.
    """
    results = search_project_notes(query, top_k=top_k)
    rendered = format_search_markdown(results)
    raw = serialize_hits(results)
    return rendered, raw
# Load the model at import time unless local test mode is switched on.
# Later calls to _ensure_model_loaded() then return early.
if os.environ.get("SPACE_DISABLE_MODEL_LOAD") != "1":
    _ensure_model_loaded()
# Two-tab UI: a grounded chat interface and a direct project-note search tool.
with gr.Blocks() as demo:
    # Derive the displayed model id from MODEL_ID so the header cannot drift
    # from the checkpoint that is actually loaded.
    gr.Markdown(
        "# Memory Harness Gemma 4\n"
        f"Text-only Hugging Face Space using `{MODEL_ID}`.\n"
        "This Space is meant to validate that the repository can be paired with a hosted Hugging Face Gemma 4 demo UI.\n"
        "It now includes lightweight project-note search so the model can ground answers before responding.\n"
    )
    with gr.Tab("Chat"):
        with gr.Row():
            system_prompt = gr.Textbox(
                value=DEFAULT_SYSTEM_PROMPT,
                label="System prompt",
                lines=3,
            )
        with gr.Row():
            max_new_tokens = gr.Slider(minimum=64, maximum=1024, value=256, step=32, label="Max new tokens")
            temperature = gr.Slider(minimum=0.0, maximum=1.2, value=0.2, step=0.1, label="Temperature")
        with gr.Row():
            enable_search = gr.Checkbox(value=True, label="Enable project search grounding")
            search_top_k = gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Search top-k")
        gr.ChatInterface(
            fn=grounded_chat,
            additional_inputs=[system_prompt, max_new_tokens, temperature, enable_search, search_top_k],
            title=None,
            description=None,
            cache_examples=False,
            run_examples_on_click=False,
            # Each row: the message, then one value per additional input, in
            # the same order as additional_inputs above.
            examples=[
                [question, DEFAULT_SYSTEM_PROMPT, 256, 0.2, True, 1]
                for question in (
                    "Summarize why raw archive retrieval matters for exact-date memory questions.",
                    "Explain the difference between summary memory and fact memory.",
                    "Give me a concise evaluation rubric for a memory harness.",
                )
            ],
        )
    with gr.Tab("Search"):
        gr.Markdown("Search the bundled project notes directly and inspect what the model can use for grounding.")
        search_query = gr.Textbox(
            label="Search query",
            lines=2,
            placeholder="Ask about raw archive, summary memory, evaluation, or architecture.",
        )
        search_results_top_k = gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Top-k results")
        search_button = gr.Button("Run search", variant="primary")
        search_markdown = gr.Markdown(value="No search results yet.")
        search_raw = gr.Textbox(label="Raw result JSON", lines=12)
        search_button.click(
            run_search,
            inputs=[search_query, search_results_top_k],
            outputs=[search_markdown, search_raw],
            api_name="search",
        )

if __name__ == "__main__":
    demo.launch()