from __future__ import annotations

import json
import os
import re
from functools import lru_cache
from pathlib import Path

import gradio as gr
import pandas as pd

APP_DIR = Path(__file__).resolve().parent

RELEASE_MODEL_NAME = "PRISM-Memory 7B Adapter"
BASE_MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
MODEL_REPO_ID = os.environ.get("PRISM_MODEL_REPO", "AsadIsmail/prism-memory")

SYSTEM_PROMPT = (
    "You are a memory extraction assistant. Given a conversation turn, "
    "extract 0-5 atomic, standalone facts. Each fact must be a complete "
    "sentence about a specific person, event, preference, or property. "
    "Include dates/times when mentioned. Skip greetings, filler, and questions. "
    'Output ONLY a JSON array of strings, e.g. ["fact1", "fact2"] or [].'
)

TURN_PATTERN = re.compile(
    r"^\s*(?:\[(?P<bracket_date>[^\]]+)\]\s*)?"
    r"(?:(?P<plain_date>\d{4}-\d{2}-\d{2})\s+)?"
    r"(?P<speaker>[^:]{1,40}):\s*(?P<text>.+?)\s*$"
)

FILLER_PREFIX = re.compile(
    r"^(yeah|yep|ok|okay|well|so|honestly|basically|actually|good point|i think|we should probably)\b[:, -]*",
    re.IGNORECASE,
)
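
# Illustrative note (not exercised at runtime): TURN_PATTERN accepts one turn per line
# in any of three shapes -- "[2025-03-01] Dana: message", "2025-03-01 Dana: message",
# or "Dana: message" -- leaving the date empty when none is given, while FILLER_PREFIX
# later strips leading conversational filler before a clause is treated as a fact.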

FIRST_PERSON_PATTERNS = [
    (re.compile(r"\bI'm\b", re.IGNORECASE), "{speaker} is"),
    (re.compile(r"\bI am\b", re.IGNORECASE), "{speaker} is"),
    (re.compile(r"\bI have\b", re.IGNORECASE), "{speaker} has"),
    (re.compile(r"\bI want\b", re.IGNORECASE), "{speaker} wants"),
    (re.compile(r"\bI need\b", re.IGNORECASE), "{speaker} needs"),
    (re.compile(r"\bI started\b", re.IGNORECASE), "{speaker} started"),
    (re.compile(r"\bI bought\b", re.IGNORECASE), "{speaker} bought"),
    (re.compile(r"\bI signed up\b", re.IGNORECASE), "{speaker} signed up"),
    (re.compile(r"\bmy\b", re.IGNORECASE), "{speaker}'s"),
    (re.compile(r"\bme\b", re.IGNORECASE), "{speaker}"),
    (re.compile(r"\bI\b", re.IGNORECASE), "{speaker}"),
]
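
# Illustrative rewrite under these patterns (hypothetical speaker "Dana"):
#   "I started my migration" -> "Dana started Dana's migration"
# The multi-word phrases are listed before the bare "my"/"me"/"I" fallbacks so the
# more specific rewrite consumes the pronoun first.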

LOCOMO_CATEGORY_NAMES = {
    "1": "factual",
    "2": "temporal",
    "3": "inferential",
    "4": "multi-hop",
    "5": "adversarial",
}

LME_CATEGORY_ORDER = [
    "knowledge-update",
    "multi-session",
    "single-session-assistant",
    "single-session-preference",
    "single-session-user",
    "temporal-reasoning",
]

def _resolve_root() -> Path:
    for candidate in (APP_DIR, APP_DIR.parent):
        if (candidate / "results" / "release_summary.json").exists():
            return candidate
        if (candidate / "docs" / "release" / "extraction-skill.md").exists():
            return candidate
        if (candidate / "MEMORY_EXTRACTION_SKILL.md").exists():
            return candidate
    return APP_DIR.parent

ROOT = _resolve_root()
RESULTS_DIR = ROOT / "results"

SUMMARY_CANDIDATES = [RESULTS_DIR / "release_summary.json"]
EXAMPLE_CANDIDATES = [RESULTS_DIR / "extraction_examples.json"]
TRY_IT_CANDIDATES = [RESULTS_DIR / "try_it_sessions.json"]
SKILL_CANDIDATES = [
    ROOT / "docs" / "release" / "extraction-skill.md",
    ROOT / "MEMORY_EXTRACTION_SKILL.md",
]
DATASET_CANDIDATES = [
    ROOT / "docs" / "release" / "datasets.md",
    ROOT / "DATASETS.md",
]
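
# Assumed repo layout (derived from the candidate paths above): the Space can run from
# the repo root or from a nested app directory, as long as results/*.json and the release
# markdown files exist; missing files fall back to the placeholder content in the loaders below.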

def _load_json(path: Path, default):
    if not path.exists():
        return default
    return json.loads(path.read_text())

def _load_json_from_candidates(candidates: list[Path], default):
    for path in candidates:
        if path.exists():
            return _load_json(path, default)
    return default

def _clean_markdown(text: str) -> str:
    lines = text.splitlines()
    if lines and lines[0].startswith("[Back to Repo]"):
        lines = lines[1:]
    while lines and not lines[0].strip():
        lines = lines[1:]
    return "\n".join(lines).strip()

def _load_markdown(candidates: list[Path], fallback: str) -> str:
    for path in candidates:
        if path.exists():
            return _clean_markdown(path.read_text())
    return fallback

def _load_summary() -> dict:
    return _load_json_from_candidates(SUMMARY_CANDIDATES, {"results": [], "failures": []})

def _load_examples() -> dict:
    return _load_json_from_candidates(EXAMPLE_CANDIDATES, {"examples": []})

def _load_try_it_examples() -> dict:
    return _load_json_from_candidates(TRY_IT_CANDIDATES, {"examples": []})

def _load_skill() -> str:
    return _load_markdown(SKILL_CANDIDATES, "Skill document not found.")

def _load_datasets() -> str:
    return _load_markdown(DATASET_CANDIDATES, "Dataset summary not found.")

def _best_result() -> dict | None:
    results = _load_summary().get("results", [])
    return results[0] if results else None

def _model_name(item: dict) -> str:
    return item.get("model_name", RELEASE_MODEL_NAME)

def _base_model(item: dict) -> str:
    return item.get("base_model", BASE_MODEL_NAME)

def release_markdown() -> str:
    item = _best_result()
    if not item:
        return "## No confirmed release result yet"
    locomo = item["locomo"]["mean"]
    lme = item["lme"]["mean"]
    return "\n".join(
        [
            "# PRISM-Memory",
            "",
            "**Turn conversations into durable, searchable memory.**",
            "",
            f"Released model: `{_model_name(item)}`",
            f"Base model: `{_base_model(item)}`",
            "",
            "| Benchmark | PRISM-Memory | GPT-4.1-based PropMem reference |",
            "|---|---:|---:|",
            f"| LongMemEval | `{lme:.3f}` | `0.465` |",
            f"| LoCoMo | `{locomo:.3f}` | `0.536` |",
            "",
            "This Space shows the public release in a product-shaped way: one model, an interactive try-it flow, held-out extraction examples, the synthetic-data summary, and the canonical extraction skill.",
        ]
    )

def summary_df() -> pd.DataFrame:
    item = _best_result()
    if not item:
        return pd.DataFrame(columns=["model", "base_model", "locomo_mean", "lme_mean", "cache_hits", "cache_misses", "eval_minutes"])
    return pd.DataFrame(
        [
            {
                "model": _model_name(item),
                "base_model": _base_model(item),
                "locomo_mean": round(item["locomo"]["mean"], 3),
                "lme_mean": round(item["lme"]["mean"], 3),
                "cache_hits": item["qa_cache"]["hits"],
                "cache_misses": item["qa_cache"]["misses"],
                "eval_minutes": item["elapsed_min"],
            }
        ]
    )

def category_df() -> pd.DataFrame:
    item = _best_result()
    if not item:
        return pd.DataFrame(columns=["benchmark", "category", "score"])
    rows = []
    for category in sorted(item["locomo"]["categories"], key=int):
        rows.append(
            {
                "benchmark": "LoCoMo",
                "category": LOCOMO_CATEGORY_NAMES.get(category, category),
                "score": round(item["locomo"]["categories"][category], 3),
            }
        )
    for category in LME_CATEGORY_ORDER:
        if category in item["lme"]["categories"]:
            rows.append(
                {
                    "benchmark": "LongMemEval",
                    "category": category,
                    "score": round(item["lme"]["categories"][category], 3),
                }
            )
    return pd.DataFrame(rows)

def _example_label(item: dict) -> str:
    return item["title"]

def example_choices() -> list[str]:
    examples = _load_examples().get("examples", [])
    return [_example_label(example) for example in examples]

def render_example(choice: str) -> str:
    examples = _load_examples().get("examples", [])
    if not examples:
        return "No extraction examples available yet."
    item = next(
        (example for example in examples if _example_label(example) == choice or example["id"] == choice),
        examples[0],
    )
    body = [
        f"### {item['title']}",
        "",
        f"**Session date:** `{item['session_date']}`",
        f"**Overlap score:** `{item['overlap_score']:.3f}`",
        f"**What this example shows:** {item['note']}",
        "",
        "**Turn**",
        "",
        f"> {item['user_message']}",
        "",
        "**GPT-4.1 reference**",
    ]
    body.extend([f"- {entry}" for entry in item.get("gpt41_reference", [])])
    body.extend(["", "**PRISM-Memory**"])
    body.extend([f"- {entry}" for entry in item.get("prism_memory", [])])
    return "\n".join(body)

def _session_label(item: dict) -> str:
    return item["title"]

def try_it_choices() -> list[str]:
    sessions = _load_try_it_examples().get("examples", [])
    return [_session_label(item) for item in sessions]

def _get_session(choice: str | None) -> dict | None:
    sessions = _load_try_it_examples().get("examples", [])
    if not sessions:
        return None
    if not choice:
        return sessions[0]
    return next(
        (item for item in sessions if _session_label(item) == choice or item["id"] == choice),
        sessions[0],
    )

def load_try_it_session(choice: str):
    item = _get_session(choice)
    if not item:
        return "", "No bundled example sessions available."
    intro = "\n".join(
        [
            f"### {item['title']}",
            "",
            f"**What this session shows:** {item['note']}",
            "",
            "**Later question**",
            "",
            item["later_question"],
        ]
    )
    return item["transcript"], intro

def load_and_run_session(choice: str):
    transcript, intro = load_try_it_session(choice)
    status, table, memory_md, qa_md = run_try_it(choice, transcript)
    return transcript, intro, status, table, memory_md, qa_md

def _parse_transcript(transcript: str) -> list[dict[str, str]]:
    turns: list[dict[str, str]] = []
    for raw_line in transcript.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        match = TURN_PATTERN.match(line)
        if match:
            turns.append(
                {
                    "date": (match.group("bracket_date") or match.group("plain_date") or "").strip(),
                    "speaker": match.group("speaker").strip(),
                    "text": match.group("text").strip(),
                }
            )
            continue
        if turns:
            turns[-1]["text"] = f"{turns[-1]['text']} {line}".strip()
    return turns
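
# Illustrative parse (mirrors the placeholder in the Try It tab):
#   "[2025-03-01] Dana: We have 20 concurrent jobs max."
# becomes {"date": "2025-03-01", "speaker": "Dana", "text": "We have 20 concurrent jobs max."};
# a continuation line without a "Speaker:" prefix is appended to the previous turn's text.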

def _normalize_first_person(text: str, speaker: str) -> str:
    value = text
    for pattern, replacement in FIRST_PERSON_PATTERNS:
        value = pattern.sub(replacement.format(speaker=speaker), value)
    return value

def _clean_clause(text: str) -> str:
    text = FILLER_PREFIX.sub("", text).strip()
    text = text.strip(" -\t")
    return re.sub(r"\s+", " ", text)

def _preview_extract_turn(turn: dict[str, str], _: str) -> list[str]:
    pieces = re.split(r"[.;]\s+|\n+", turn["text"])
    facts: list[str] = []
    for piece in pieces:
        clause = _clean_clause(piece)
        if not clause or clause.endswith("?") or len(clause) < 8:
            continue
        clause = _normalize_first_person(clause, turn["speaker"])
        if clause.lower().startswith(("hi ", "hello ", "thanks ", "thank you ")):
            continue
        clause = clause[0].upper() + clause[1:]
        if not clause.endswith((".", "!", "?")):
            clause += "."
        facts.append(clause)
        if len(facts) == 5:
            break
    return facts
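
# Illustrative preview extraction (hypothetical turn, speaker "Dana"):
#   text = "Yeah, I started the migration yesterday. Did you see the logs?"
# yields ["Dana started the migration yesterday."]; the filler prefix is stripped,
# first-person wording is rewritten to the speaker, and the trailing question is dropped.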

def _parse_props(raw: str) -> list[str]:
    try:
        parsed = json.loads(raw)
        if isinstance(parsed, list):
            return [str(item).strip() for item in parsed if str(item).strip()]
    except Exception:
        pass
    match = re.search(r"\[([^\]]*)\]", raw, re.DOTALL)
    if match:
        try:
            parsed = json.loads("[" + match.group(1) + "]")
            if isinstance(parsed, list):
                return [str(item).strip() for item in parsed if str(item).strip()]
        except Exception:
            pass
    return _preview_extract_turn({"speaker": "Speaker", "text": raw}, "")
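
# Illustrative parses (hypothetical model outputs):
#   '["Dana started the migration."]'             -> ["Dana started the migration."]
#   'Sure! ["Dana started the migration."] done.' -> ["Dana started the migration."]  (bracket fallback)
# Output that still fails to parse falls through to the heuristic preview extractor above.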

# Cache the expensive load so the probe in run_try_it and the per-turn calls in
# _live_extract_turn reuse a single tokenizer/model instance (uses the lru_cache import above).
@lru_cache(maxsize=1)
def _load_live_stack():
    if os.environ.get("PRISM_ENABLE_LIVE_MODEL", "").lower() not in {"1", "true", "yes"}:
        raise RuntimeError("Live model loading is disabled for this Space runtime.")
    try:
        import torch
        from peft import AutoPeftModelForCausalLM
        from transformers import AutoTokenizer
    except Exception as exc:  # pragma: no cover - dependency gate
        raise RuntimeError("Live model dependencies are not installed in this runtime.") from exc
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_ID, trust_remote_code=True)
    tokenizer.padding_side = "left"
    device_map = "auto"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoPeftModelForCausalLM.from_pretrained(
        MODEL_REPO_ID,
        trust_remote_code=True,
        device_map=device_map,
        low_cpu_mem_usage=True,
        torch_dtype=torch_dtype,
    )
    model.eval()
    return torch, tokenizer, model
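
# Live extraction is opt-in: it runs only when PRISM_ENABLE_LIVE_MODEL is set to
# 1/true/yes and torch, peft, and transformers are installed in the Space runtime;
# otherwise run_try_it silently falls back to the preview extractor.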

def _live_extract_turn(turn: dict[str, str], context: str) -> list[str]:
    torch_mod, tokenizer, model = _load_live_stack()
    user_prompt = "\n".join(
        [
            f"Date: {turn['date']}" if turn.get("date") else "",
            f"Recent context: ...{context[-300:]}" if context else "",
            f"Speaker ({turn['speaker']}): {turn['text']}",
        ]
    ).strip()
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    encoded = tokenizer(
        [prompt],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=768,
    ).to(model.device)
    with torch_mod.inference_mode():
        output = model.generate(
            **encoded,
            max_new_tokens=128,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=False,
        )
    raw = tokenizer.decode(output[0][encoded.input_ids.shape[1] :], skip_special_tokens=True).strip()
    return _parse_props(raw)[:5]

def _dedupe_items(values: list[str]) -> list[str]:
    seen: set[str] = set()
    output: list[str] = []
    for value in values:
        key = value.casefold().strip()
        if not key or key in seen:
            continue
        seen.add(key)
        output.append(value.strip())
    return output
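
# Illustrative dedupe: ["Dana uses GitHub Actions.", "dana uses github actions. "] keeps
# only the first entry, since comparison is case-insensitive and whitespace-trimmed.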

def _memory_markdown(items: list[str]) -> str:
    if not items:
        return "No memory records written yet."
    return "\n".join(["### Accumulated Memory", ""] + [f"- {item}" for item in items])

def _question_markdown(item: dict | None) -> str:
    if not item:
        return ""
    return "\n".join(
        [
            "### Later Question",
            "",
            f"**Question:** {item['later_question']}",
            f"**Answer from memory:** {item['answer_from_memory']}",
        ]
    )

def run_try_it(choice: str, transcript: str):
    session = _get_session(choice)
    normalized_transcript = transcript.strip()
    if not normalized_transcript:
        return "Paste a transcript or load one of the bundled sessions.", pd.DataFrame(), "No memory records yet.", ""
    if session and normalized_transcript == session["transcript"].strip():
        rows = []
        final_memory: list[str] = []
        for turn in session["turns"]:
            extracted = turn.get("prism_memory", [])
            rows.append(
                {
                    "turn": turn["turn_index"],
                    "date": turn["date"],
                    "speaker": turn["speaker"],
                    "memory_records": "\n".join(extracted),
                }
            )
            final_memory.extend(
                [f"[{turn['date']}] {record}" if turn.get("date") else record for record in extracted]
            )
        status = "\n".join(
            [
                "### Try It",
                "",
                "**Mode:** released model output (bundled example)",
                "",
                "These per-turn memory records were precomputed with the released PRISM-Memory adapter. The model is trained to write memory turn by turn, then let retrieval use the accumulated store later.",
            ]
        )
        return status, pd.DataFrame(rows), _memory_markdown(_dedupe_items(final_memory)), _question_markdown(session)
    turns = _parse_transcript(normalized_transcript)
    if not turns:
        return (
            "Could not parse the transcript. Use one turn per line, for example `[2025-03-01] Dana: We have 20 concurrent jobs max.`",
            pd.DataFrame(),
            "No memory records yet.",
            "",
        )
    extractor = _preview_extract_turn
    mode = "preview extractor"
    note = (
        "This runtime is using a lightweight turn-by-turn preview that follows the same extraction contract. "
        "The bundled example sessions above use actual released-model outputs."
    )
    try:
        _load_live_stack()
        extractor = _live_extract_turn
        mode = "released model (live)"
        note = "This runtime successfully loaded the released adapter and is extracting memory turn by turn."
    except Exception:
        pass
    rows = []
    final_memory: list[str] = []
    context_lines: list[str] = []
    for index, turn in enumerate(turns, start=1):
        context = "\n".join(context_lines[-6:])
        extracted = extractor(turn, context)[:5]
        rows.append(
            {
                "turn": index,
                "date": turn.get("date", ""),
                "speaker": turn["speaker"],
                "memory_records": "\n".join(extracted),
            }
        )
        final_memory.extend(
            [f"[{turn['date']}] {record}" if turn.get("date") else record for record in extracted]
        )
        context_lines.append(f"[{turn.get('date', '')}] {turn['speaker']}: {turn['text']}")
    status = "\n".join(
        [
            "### Try It",
            "",
            f"**Mode:** {mode}",
            "",
            note,
            "",
            "Expected transcript format: one turn per line as `[YYYY-MM-DD] Speaker: message` or `Speaker: message`.",
        ]
    )
    return status, pd.DataFrame(rows), _memory_markdown(_dedupe_items(final_memory)), ""
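
# run_try_it has three paths: a bundled session replays the precomputed released-model
# records, a custom transcript uses the live adapter when _load_live_stack succeeds, and
# otherwise the heuristic preview extractor stands in under the same extraction contract.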

INITIAL_TRY_IT_CHOICES = try_it_choices()
INITIAL_TRY_IT_CHOICE = INITIAL_TRY_IT_CHOICES[0] if INITIAL_TRY_IT_CHOICES else ""
INITIAL_TRY_IT_TRANSCRIPT, INITIAL_TRY_IT_INTRO = (
    load_try_it_session(INITIAL_TRY_IT_CHOICE)
    if INITIAL_TRY_IT_CHOICE
    else ("", "No bundled example sessions available.")
)
INITIAL_TRY_IT_STATUS, INITIAL_TRY_IT_DF, INITIAL_TRY_IT_MEMORY, INITIAL_TRY_IT_QA = (
    run_try_it(INITIAL_TRY_IT_CHOICE, INITIAL_TRY_IT_TRANSCRIPT)
    if INITIAL_TRY_IT_CHOICE
    else ("No bundled example sessions available.", pd.DataFrame(), "No memory records yet.", "")
)

INITIAL_EXAMPLE_CHOICES = example_choices()
INITIAL_EXAMPLE_CHOICE = INITIAL_EXAMPLE_CHOICES[0] if INITIAL_EXAMPLE_CHOICES else "pending"
INITIAL_EXAMPLE_MD = render_example(INITIAL_EXAMPLE_CHOICE) if INITIAL_EXAMPLE_CHOICES else "No extraction examples available yet."

with gr.Blocks(title="PRISM-Memory Demo") as demo:
    gr.Markdown(release_markdown())

    with gr.Tab("Metrics"):
        gr.Markdown("## Released Model")
        metrics = gr.Dataframe(value=summary_df(), interactive=False, wrap=True)
        gr.Markdown("## Category Breakdown")
        categories = gr.Dataframe(value=category_df(), interactive=False, wrap=True)
        refresh = gr.Button("Refresh Data")
        refresh.click(fn=lambda: (summary_df(), category_df()), outputs=[metrics, categories])

    with gr.Tab("Try It"):
        gr.Markdown(
            "\n".join(
                [
                    "Use one of the bundled sessions or paste your own transcript.",
                    "",
                    "PRISM-Memory is trained to write memory **turn by turn**, not to summarize a whole session in one shot.",
                ]
            )
        )
        choices = INITIAL_TRY_IT_CHOICES or ["No bundled sessions"]
        session_picker = gr.Dropdown(choices=choices, value=choices[0], label="Example Session")
        session_intro = gr.Markdown(value=INITIAL_TRY_IT_INTRO)
        transcript_box = gr.Textbox(
            label="Transcript",
            lines=10,
            value=INITIAL_TRY_IT_TRANSCRIPT,
            placeholder="[2025-03-01] Dana: We have 20 concurrent jobs max on GitHub Actions right now.",
        )
        run_button = gr.Button("Extract Memory")
        try_it_status = gr.Markdown(value=INITIAL_TRY_IT_STATUS)
        per_turn_df = gr.Dataframe(value=INITIAL_TRY_IT_DF, interactive=False, wrap=True, label="Per-Turn Memory")
        final_memory_md = gr.Markdown(value=INITIAL_TRY_IT_MEMORY)
        later_question_md = gr.Markdown(value=INITIAL_TRY_IT_QA)
        session_picker.change(
            load_and_run_session,
            inputs=session_picker,
            outputs=[transcript_box, session_intro, try_it_status, per_turn_df, final_memory_md, later_question_md],
        )
        run_button.click(
            run_try_it,
            inputs=[session_picker, transcript_box],
            outputs=[try_it_status, per_turn_df, final_memory_md, later_question_md],
        )

    with gr.Tab("Extraction Examples"):
        choices = INITIAL_EXAMPLE_CHOICES or ["pending"]
        picker = gr.Dropdown(choices=choices, value=choices[0], label="Held-Out Example")
        example_md = gr.Markdown(value=INITIAL_EXAMPLE_MD)
        picker.change(render_example, inputs=picker, outputs=example_md)

    with gr.Tab("Data"):
        gr.Markdown(_load_datasets())

    with gr.Tab("Skill"):
        gr.Markdown(_load_skill())

if __name__ == "__main__":
    demo.launch()
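
# Local usage sketch: running this module directly calls demo.launch(), which serves the
# Gradio UI on a local URL; live extraction additionally requires PRISM_ENABLE_LIVE_MODEL
# plus the torch/peft/transformers stack in the environment.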