from __future__ import annotations

import json
import os
import re
from functools import lru_cache
from pathlib import Path

import gradio as gr
import pandas as pd

APP_DIR = Path(__file__).resolve().parent
RELEASE_MODEL_NAME = "PRISM-Memory 7B Adapter"
BASE_MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
MODEL_REPO_ID = os.environ.get("PRISM_MODEL_REPO", "AsadIsmail/prism-memory")

SYSTEM_PROMPT = (
    "You are a memory extraction assistant. Given a conversation turn, "
    "extract 0-5 atomic, standalone facts. Each fact must be a complete "
    "sentence about a specific person, event, preference, or property. "
    "Include dates/times when mentioned. Skip greetings, filler, and questions. "
    'Output ONLY a JSON array of strings, e.g. ["fact1", "fact2"] or [].'
)

# One transcript turn per line: an optional `[YYYY-MM-DD]` or bare date prefix,
# a speaker name of up to 40 characters, then the message after the colon.
TURN_PATTERN = re.compile(
    r"^\s*(?:\[(?P<bracket_date>[^\]]+)\]\s*)?"
    r"(?:(?P<plain_date>\d{4}-\d{2}-\d{2})\s+)?"
    r"(?P<speaker>[^:]{1,40}):\s*(?P<text>.+?)\s*$"
)

FILLER_PREFIX = re.compile(
    r"^(yeah|yep|ok|okay|well|so|honestly|basically|actually|good point|i think|we should probably)\b[:, -]*",
    re.IGNORECASE,
)

# Ordered first-person rewrites for the preview extractor; longer phrases come
# before the bare "I" so they are rewritten first.
FIRST_PERSON_PATTERNS = [
    (re.compile(r"\bI'm\b", re.IGNORECASE), "{speaker} is"),
    (re.compile(r"\bI am\b", re.IGNORECASE), "{speaker} is"),
    (re.compile(r"\bI have\b", re.IGNORECASE), "{speaker} has"),
    (re.compile(r"\bI want\b", re.IGNORECASE), "{speaker} wants"),
    (re.compile(r"\bI need\b", re.IGNORECASE), "{speaker} needs"),
    (re.compile(r"\bI started\b", re.IGNORECASE), "{speaker} started"),
    (re.compile(r"\bI bought\b", re.IGNORECASE), "{speaker} bought"),
    (re.compile(r"\bI signed up\b", re.IGNORECASE), "{speaker} signed up"),
    (re.compile(r"\bmy\b", re.IGNORECASE), "{speaker}'s"),
    (re.compile(r"\bme\b", re.IGNORECASE), "{speaker}"),
    (re.compile(r"\bI\b", re.IGNORECASE), "{speaker}"),
]

LOCOMO_CATEGORY_NAMES = {
    "1": "factual",
    "2": "temporal",
    "3": "inferential",
    "4": "multi-hop",
    "5": "adversarial",
}

LME_CATEGORY_ORDER = [
    "knowledge-update",
    "multi-session",
    "single-session-assistant",
    "single-session-preference",
    "single-session-user",
    "temporal-reasoning",
]


def _resolve_root() -> Path:
    for candidate in (APP_DIR, APP_DIR.parent):
        if (candidate / "results" / "release_summary.json").exists():
            return candidate
        if (candidate / "docs" / "release" / "extraction-skill.md").exists():
            return candidate
        if (candidate / "MEMORY_EXTRACTION_SKILL.md").exists():
            return candidate
    return APP_DIR.parent


ROOT = _resolve_root()
RESULTS_DIR = ROOT / "results"
SUMMARY_CANDIDATES = [RESULTS_DIR / "release_summary.json"]
EXAMPLE_CANDIDATES = [RESULTS_DIR / "extraction_examples.json"]
TRY_IT_CANDIDATES = [RESULTS_DIR / "try_it_sessions.json"]
SKILL_CANDIDATES = [
    ROOT / "docs" / "release" / "extraction-skill.md",
    ROOT / "MEMORY_EXTRACTION_SKILL.md",
]
DATASET_CANDIDATES = [
    ROOT / "docs" / "release" / "datasets.md",
    ROOT / "DATASETS.md",
]


def _load_json(path: Path, default):
    if not path.exists():
        return default
    return json.loads(path.read_text())


def _load_json_from_candidates(candidates: list[Path], default):
    for path in candidates:
        if path.exists():
            return _load_json(path, default)
    return default


def _clean_markdown(text: str) -> str:
    lines = text.splitlines()
    if lines and lines[0].startswith("[Back to Repo]"):
        lines = lines[1:]
    while lines and not lines[0].strip():
        lines = lines[1:]
    return "\n".join(lines).strip()


def _load_markdown(candidates: list[Path], fallback: str) -> str:
    for path in candidates:
        if path.exists():
            return _clean_markdown(path.read_text())
    return fallback


def _load_summary() -> dict:
    return _load_json_from_candidates(SUMMARY_CANDIDATES, {"results": [], "failures": []})
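# Illustrative shape of results/release_summary.json, inferred from the accessors
# in this module (field names come from the code; the values here are placeholders,
# not real benchmark numbers):
#
# {
#   "results": [
#     {
#       "model_name": "PRISM-Memory 7B Adapter",
#       "base_model": "Qwen/Qwen2.5-7B-Instruct",
#       "locomo": {"mean": 0.0, "categories": {"1": 0.0, "2": 0.0}},
#       "lme": {"mean": 0.0, "categories": {"multi-session": 0.0}},
#       "qa_cache": {"hits": 0, "misses": 0},
#       "elapsed_min": 0
#     }
#   ],
#   "failures": []
# }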
def _load_examples() -> dict:
    return _load_json_from_candidates(EXAMPLE_CANDIDATES, {"examples": []})


def _load_try_it_examples() -> dict:
    return _load_json_from_candidates(TRY_IT_CANDIDATES, {"examples": []})


def _load_skill() -> str:
    return _load_markdown(SKILL_CANDIDATES, "Skill document not found.")


def _load_datasets() -> str:
    return _load_markdown(DATASET_CANDIDATES, "Dataset summary not found.")


def _best_result() -> dict | None:
    results = _load_summary().get("results", [])
    return results[0] if results else None


def _model_name(item: dict) -> str:
    return item.get("model_name", RELEASE_MODEL_NAME)


def _base_model(item: dict) -> str:
    return item.get("base_model", BASE_MODEL_NAME)


def release_markdown() -> str:
    item = _best_result()
    if not item:
        return "## No confirmed release result yet"
    locomo = item["locomo"]["mean"]
    lme = item["lme"]["mean"]
    return "\n".join(
        [
            "# PRISM-Memory",
            "",
            "**Turn conversations into durable, searchable memory.**",
            "",
            f"Released model: `{_model_name(item)}`",
            f"Base model: `{_base_model(item)}`",
            "",
            "| Benchmark | PRISM-Memory | GPT-4.1-based PropMem reference |",
            "|---|---:|---:|",
            f"| LongMemEval | `{lme:.3f}` | `0.465` |",
            f"| LoCoMo | `{locomo:.3f}` | `0.536` |",
            "",
            "This Space shows the public release in a product-shaped way: one model, an interactive try-it flow, held-out extraction examples, the synthetic-data summary, and the canonical extraction skill.",
        ]
    )


def summary_df() -> pd.DataFrame:
    item = _best_result()
    if not item:
        return pd.DataFrame(
            columns=["model", "base_model", "locomo_mean", "lme_mean", "cache_hits", "cache_misses", "eval_minutes"]
        )
    return pd.DataFrame(
        [
            {
                "model": _model_name(item),
                "base_model": _base_model(item),
                "locomo_mean": round(item["locomo"]["mean"], 3),
                "lme_mean": round(item["lme"]["mean"], 3),
                "cache_hits": item["qa_cache"]["hits"],
                "cache_misses": item["qa_cache"]["misses"],
                "eval_minutes": item["elapsed_min"],
            }
        ]
    )


def category_df() -> pd.DataFrame:
    item = _best_result()
    if not item:
        return pd.DataFrame(columns=["benchmark", "category", "score"])
    rows = []
    for category in sorted(item["locomo"]["categories"], key=int):
        rows.append(
            {
                "benchmark": "LoCoMo",
                "category": LOCOMO_CATEGORY_NAMES.get(category, category),
                "score": round(item["locomo"]["categories"][category], 3),
            }
        )
    for category in LME_CATEGORY_ORDER:
        if category in item["lme"]["categories"]:
            rows.append(
                {
                    "benchmark": "LongMemEval",
                    "category": category,
                    "score": round(item["lme"]["categories"][category], 3),
                }
            )
    return pd.DataFrame(rows)


def _example_label(item: dict) -> str:
    return item["title"]


def example_choices() -> list[str]:
    examples = _load_examples().get("examples", [])
    return [_example_label(example) for example in examples]
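# Illustrative shape of one entry in results/extraction_examples.json, inferred
# from render_example below (field names come from the code; values are placeholders):
#
# {
#   "id": "example-1",
#   "title": "...",
#   "session_date": "2025-03-01",
#   "overlap_score": 0.0,
#   "note": "...",
#   "user_message": "...",
#   "gpt41_reference": ["..."],
#   "prism_memory": ["..."]
# }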
def render_example(choice: str) -> str:
    examples = _load_examples().get("examples", [])
    if not examples:
        return "No extraction examples available yet."
    item = next(
        (example for example in examples if _example_label(example) == choice or example["id"] == choice),
        examples[0],
    )
    body = [
        f"### {item['title']}",
        "",
        f"**Session date:** `{item['session_date']}`",
        f"**Overlap score:** `{item['overlap_score']:.3f}`",
        f"**What this example shows:** {item['note']}",
        "",
        "**Turn**",
        "",
        f"> {item['user_message']}",
        "",
        "**GPT-4.1 reference**",
    ]
    body.extend([f"- {entry}" for entry in item.get("gpt41_reference", [])])
    body.extend(["", "**PRISM-Memory**"])
    body.extend([f"- {entry}" for entry in item.get("prism_memory", [])])
    return "\n".join(body)


def _session_label(item: dict) -> str:
    return item["title"]


def try_it_choices() -> list[str]:
    sessions = _load_try_it_examples().get("examples", [])
    return [_session_label(item) for item in sessions]


def _get_session(choice: str | None) -> dict | None:
    sessions = _load_try_it_examples().get("examples", [])
    if not sessions:
        return None
    if not choice:
        return sessions[0]
    return next(
        (item for item in sessions if _session_label(item) == choice or item["id"] == choice),
        sessions[0],
    )


def load_try_it_session(choice: str):
    item = _get_session(choice)
    if not item:
        return "", "No bundled example sessions available."
    intro = "\n".join(
        [
            f"### {item['title']}",
            "",
            f"**What this session shows:** {item['note']}",
            "",
            "**Later question**",
            "",
            item["later_question"],
        ]
    )
    return item["transcript"], intro


def load_and_run_session(choice: str):
    transcript, intro = load_try_it_session(choice)
    status, table, memory_md, qa_md = run_try_it(choice, transcript)
    return transcript, intro, status, table, memory_md, qa_md


def _parse_transcript(transcript: str) -> list[dict[str, str]]:
    turns: list[dict[str, str]] = []
    for raw_line in transcript.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        match = TURN_PATTERN.match(line)
        if match:
            turns.append(
                {
                    "date": (match.group("bracket_date") or match.group("plain_date") or "").strip(),
                    "speaker": match.group("speaker").strip(),
                    "text": match.group("text").strip(),
                }
            )
            continue
        # Lines that do not match the turn pattern are treated as continuations
        # of the previous turn.
        if turns:
            turns[-1]["text"] = f"{turns[-1]['text']} {line}".strip()
    return turns


def _normalize_first_person(text: str, speaker: str) -> str:
    value = text
    for pattern, replacement in FIRST_PERSON_PATTERNS:
        value = pattern.sub(replacement.format(speaker=speaker), value)
    return value


def _clean_clause(text: str) -> str:
    text = FILLER_PREFIX.sub("", text).strip()
    text = text.strip(" -\t")
    return re.sub(r"\s+", " ", text)
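# Illustrative walk-through of the preview path (not released-model output).
# Given the placeholder turn used elsewhere in this file,
#   [2025-03-01] Dana: We have 20 concurrent jobs max on GitHub Actions right now.
# _parse_transcript yields
#   {"date": "2025-03-01", "speaker": "Dana",
#    "text": "We have 20 concurrent jobs max on GitHub Actions right now."}
# and _preview_extract_turn below keeps it as a single fact. A hypothetical
# first-person turn such as "Dana: I bought a new laptop." would be rewritten by
# _normalize_first_person to "Dana bought a new laptop."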
def _preview_extract_turn(turn: dict[str, str], _: str) -> list[str]:
    pieces = re.split(r"[.;]\s+|\n+", turn["text"])
    facts: list[str] = []
    for piece in pieces:
        clause = _clean_clause(piece)
        if not clause or clause.endswith("?") or len(clause) < 8:
            continue
        clause = _normalize_first_person(clause, turn["speaker"])
        if clause.lower().startswith(("hi ", "hello ", "thanks ", "thank you ")):
            continue
        clause = clause[0].upper() + clause[1:]
        if not clause.endswith((".", "!", "?")):
            clause += "."
        facts.append(clause)
        if len(facts) == 5:
            break
    return facts


def _parse_props(raw: str) -> list[str]:
    try:
        parsed = json.loads(raw)
        if isinstance(parsed, list):
            return [str(item).strip() for item in parsed if str(item).strip()]
    except Exception:
        pass
    match = re.search(r"\[([^\]]*)\]", raw, re.DOTALL)
    if match:
        try:
            parsed = json.loads("[" + match.group(1) + "]")
            if isinstance(parsed, list):
                return [str(item).strip() for item in parsed if str(item).strip()]
        except Exception:
            pass
    return _preview_extract_turn({"speaker": "Speaker", "text": raw}, "")


# The live path is opt-in: set PRISM_ENABLE_LIVE_MODEL=1 (and optionally
# PRISM_MODEL_REPO) to load the released adapter; otherwise run_try_it falls
# back to the preview extractor above.
@lru_cache(maxsize=1)
def _load_live_stack():
    if os.environ.get("PRISM_ENABLE_LIVE_MODEL", "").lower() not in {"1", "true", "yes"}:
        raise RuntimeError("Live model loading is disabled for this Space runtime.")
    try:
        import torch
        from peft import AutoPeftModelForCausalLM
        from transformers import AutoTokenizer
    except Exception as exc:  # pragma: no cover - dependency gate
        raise RuntimeError("Live model dependencies are not installed in this runtime.") from exc
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_ID, trust_remote_code=True)
    tokenizer.padding_side = "left"
    device_map = "auto"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoPeftModelForCausalLM.from_pretrained(
        MODEL_REPO_ID,
        trust_remote_code=True,
        device_map=device_map,
        low_cpu_mem_usage=True,
        torch_dtype=torch_dtype,
    )
    model.eval()
    return torch, tokenizer, model


def _live_extract_turn(turn: dict[str, str], context: str) -> list[str]:
    torch_mod, tokenizer, model = _load_live_stack()
    user_prompt = "\n".join(
        [
            f"Date: {turn['date']}" if turn.get("date") else "",
            f"Recent context: ...{context[-300:]}" if context else "",
            f"Speaker ({turn['speaker']}): {turn['text']}",
        ]
    ).strip()
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    encoded = tokenizer(
        [prompt],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=768,
    ).to(model.device)
    with torch_mod.inference_mode():
        output = model.generate(
            **encoded,
            max_new_tokens=128,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=False,
        )
    raw = tokenizer.decode(output[0][encoded.input_ids.shape[1] :], skip_special_tokens=True).strip()
    return _parse_props(raw)[:5]


def _dedupe_items(values: list[str]) -> list[str]:
    seen: set[str] = set()
    output: list[str] = []
    for value in values:
        key = value.casefold().strip()
        if not key or key in seen:
            continue
        seen.add(key)
        output.append(value.strip())
    return output
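# Illustrative shape of one entry in results/try_it_sessions.json, inferred from
# _get_session, load_try_it_session, run_try_it, and _question_markdown (field
# names come from the code; values are placeholders):
#
# {
#   "id": "session-1",
#   "title": "...",
#   "note": "...",
#   "later_question": "...",
#   "answer_from_memory": "...",
#   "transcript": "[2025-03-01] Dana: ...",
#   "turns": [
#     {"turn_index": 1, "date": "2025-03-01", "speaker": "Dana", "prism_memory": ["..."]}
#   ]
# }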
return "\n".join(["### Accumulated Memory", ""] + [f"- {item}" for item in items]) def _question_markdown(item: dict | None) -> str: if not item: return "" return "\n".join( [ "### Later Question", "", f"**Question:** {item['later_question']}", f"**Answer from memory:** {item['answer_from_memory']}", ] ) def run_try_it(choice: str, transcript: str): session = _get_session(choice) normalized_transcript = transcript.strip() if not normalized_transcript: return "Paste a transcript or load one of the bundled sessions.", pd.DataFrame(), "No memory records yet.", "" if session and normalized_transcript == session["transcript"].strip(): rows = [] final_memory: list[str] = [] for turn in session["turns"]: extracted = turn.get("prism_memory", []) rows.append( { "turn": turn["turn_index"], "date": turn["date"], "speaker": turn["speaker"], "memory_records": "\n".join(extracted), } ) final_memory.extend( [f"[{turn['date']}] {record}" if turn.get("date") else record for record in extracted] ) status = "\n".join( [ "### Try It", "", "**Mode:** released model output (bundled example)", "", "These per-turn memory records were precomputed with the released PRISM-Memory adapter. The model is trained to write memory turn by turn, then let retrieval use the accumulated store later.", ] ) return status, pd.DataFrame(rows), _memory_markdown(_dedupe_items(final_memory)), _question_markdown(session) turns = _parse_transcript(normalized_transcript) if not turns: return ( "Could not parse the transcript. Use one turn per line, for example `[2025-03-01] Dana: We have 20 concurrent jobs max.`", pd.DataFrame(), "No memory records yet.", "", ) extractor = _preview_extract_turn mode = "preview extractor" note = ( "This runtime is using a lightweight turn-by-turn preview that follows the same extraction contract. " "The bundled example sessions above use actual released-model outputs." ) try: _load_live_stack() extractor = _live_extract_turn mode = "released model (live)" note = "This runtime successfully loaded the released adapter and is extracting memory turn by turn." 
    except Exception:
        pass
    rows = []
    final_memory: list[str] = []
    context_lines: list[str] = []
    for index, turn in enumerate(turns, start=1):
        context = "\n".join(context_lines[-6:])
        extracted = extractor(turn, context)[:5]
        rows.append(
            {
                "turn": index,
                "date": turn.get("date", ""),
                "speaker": turn["speaker"],
                "memory_records": "\n".join(extracted),
            }
        )
        final_memory.extend(
            [f"[{turn['date']}] {record}" if turn.get("date") else record for record in extracted]
        )
        context_lines.append(f"[{turn.get('date', '')}] {turn['speaker']}: {turn['text']}")
    status = "\n".join(
        [
            "### Try It",
            "",
            f"**Mode:** {mode}",
            "",
            note,
            "",
            "Expected transcript format: one turn per line as `[YYYY-MM-DD] Speaker: message` or `Speaker: message`.",
        ]
    )
    return status, pd.DataFrame(rows), _memory_markdown(_dedupe_items(final_memory)), ""


INITIAL_TRY_IT_CHOICES = try_it_choices()
INITIAL_TRY_IT_CHOICE = INITIAL_TRY_IT_CHOICES[0] if INITIAL_TRY_IT_CHOICES else ""
INITIAL_TRY_IT_TRANSCRIPT, INITIAL_TRY_IT_INTRO = (
    load_try_it_session(INITIAL_TRY_IT_CHOICE)
    if INITIAL_TRY_IT_CHOICE
    else ("", "No bundled example sessions available.")
)
INITIAL_TRY_IT_STATUS, INITIAL_TRY_IT_DF, INITIAL_TRY_IT_MEMORY, INITIAL_TRY_IT_QA = (
    run_try_it(INITIAL_TRY_IT_CHOICE, INITIAL_TRY_IT_TRANSCRIPT)
    if INITIAL_TRY_IT_CHOICE
    else ("No bundled example sessions available.", pd.DataFrame(), "No memory records yet.", "")
)
INITIAL_EXAMPLE_CHOICES = example_choices()
INITIAL_EXAMPLE_CHOICE = INITIAL_EXAMPLE_CHOICES[0] if INITIAL_EXAMPLE_CHOICES else "pending"
INITIAL_EXAMPLE_MD = render_example(INITIAL_EXAMPLE_CHOICE) if INITIAL_EXAMPLE_CHOICES else "No extraction examples available yet."

with gr.Blocks(title="PRISM-Memory Demo") as demo:
    gr.Markdown(release_markdown())

    with gr.Tab("Metrics"):
        gr.Markdown("## Released Model")
        metrics = gr.Dataframe(value=summary_df(), interactive=False, wrap=True)
        gr.Markdown("## Category Breakdown")
        categories = gr.Dataframe(value=category_df(), interactive=False, wrap=True)
        refresh = gr.Button("Refresh Data")
        refresh.click(fn=lambda: (summary_df(), category_df()), outputs=[metrics, categories])

    with gr.Tab("Try It"):
        gr.Markdown(
            "\n".join(
                [
                    "Use one of the bundled sessions or paste your own transcript.",
                    "",
                    "PRISM-Memory is trained to write memory **turn by turn**, not to summarize a whole session in one shot.",
                ]
            )
        )
        choices = INITIAL_TRY_IT_CHOICES or ["No bundled sessions"]
        session_picker = gr.Dropdown(choices=choices, value=choices[0], label="Example Session")
        session_intro = gr.Markdown(value=INITIAL_TRY_IT_INTRO)
        transcript_box = gr.Textbox(
            label="Transcript",
            lines=10,
            value=INITIAL_TRY_IT_TRANSCRIPT,
            placeholder="[2025-03-01] Dana: We have 20 concurrent jobs max on GitHub Actions right now.",
        )
        run_button = gr.Button("Extract Memory")
        try_it_status = gr.Markdown(value=INITIAL_TRY_IT_STATUS)
        per_turn_df = gr.Dataframe(value=INITIAL_TRY_IT_DF, interactive=False, wrap=True, label="Per-Turn Memory")
        final_memory_md = gr.Markdown(value=INITIAL_TRY_IT_MEMORY)
        later_question_md = gr.Markdown(value=INITIAL_TRY_IT_QA)
        session_picker.change(
            load_and_run_session,
            inputs=session_picker,
            outputs=[transcript_box, session_intro, try_it_status, per_turn_df, final_memory_md, later_question_md],
        )
        run_button.click(
            run_try_it,
            inputs=[session_picker, transcript_box],
            outputs=[try_it_status, per_turn_df, final_memory_md, later_question_md],
        )

    with gr.Tab("Extraction Examples"):
        choices = INITIAL_EXAMPLE_CHOICES or ["pending"]
        picker = gr.Dropdown(choices=choices, value=choices[0], label="Held-Out Example")
        example_md = gr.Markdown(value=INITIAL_EXAMPLE_MD)
        picker.change(render_example, inputs=picker, outputs=example_md)

    with gr.Tab("Data"):
        gr.Markdown(_load_datasets())

    with gr.Tab("Skill"):
        gr.Markdown(_load_skill())


if __name__ == "__main__":
    demo.launch()
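# Running locally (illustrative; the file name `app.py` is an assumption):
#
#   python app.py
#       -> bundled precomputed examples plus the preview extractor.
#   PRISM_ENABLE_LIVE_MODEL=1 python app.py
#       -> additionally tries to load the released adapter from PRISM_MODEL_REPO
#          (default: AsadIsmail/prism-memory), falling back to the preview
#          extractor if the dependencies or weights are unavailable.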