# MemoriaLM/backend/modules/artifacts.py
# Author: Max Saavedra
# Change: Flash cards artifact added.
# Commit: 6697007
import audioop
import json
import re
import wave
import io
from datetime import datetime, timezone
from io import BytesIO
from pathlib import Path
from typing import Dict, List, Optional
from backend.models.schemas import (
ArtifactFileOut,
ArtifactGenerateOut,
ArtifactListOut,
PodcastArtifactOut,
)
from backend.services.llm import llm_service
from backend.services.storage import NotebookStore
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Optional heavy dependencies used only for podcast text-to-speech. Each import
# is guarded so the other artifact features (reports, quizzes, flashcards) keep
# working when these are not installed; _get_tts_model() raises RuntimeError
# if any of them is missing when TTS is actually requested.
try:
    import soundfile as sf
except ImportError:
    sf = None
try:
    import torch
except ImportError:
    torch = None
try:
    from transformers import AutoTokenizer, VitsModel
except ImportError:
    AutoTokenizer = None
    VitsModel = None
# Process-wide singletons for the VITS model and tokenizer, lazily populated by
# _get_tts_model() on first use so module import stays cheap.
_vits_model = None
_tokenizer = None
def _get_tts_model():
    """Return a cached (VitsModel, tokenizer) pair, loading them on first use.

    Raises:
        RuntimeError: when any optional TTS dependency failed to import.
    """
    global _vits_model, _tokenizer
    if _vits_model is None or _tokenizer is None:
        if any(dep is None for dep in (torch, sf, AutoTokenizer, VitsModel)):
            raise RuntimeError("TTS dependencies are not installed")
        _vits_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
        _tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
    return _vits_model, _tokenizer
def _now() -> str:
return datetime.now(timezone.utc).replace(microsecond=0).isoformat()
def _artifact_dirs(store: NotebookStore, user_id: str, notebook_id: str) -> Dict[str, Path]:
    """Map artifact type -> directory under the notebook's artifacts root.

    Every directory is created (parents included) if it does not exist yet.
    """
    root = store.require_notebook_dir(user_id, notebook_id) / "artifacts"
    layout = (
        ("report", "reports"),
        ("quiz", "quizzes"),
        ("flashcards", "flashcards"),
        ("podcast", "podcasts"),
    )
    dirs: Dict[str, Path] = {}
    for kind, folder in layout:
        target = root / folder
        target.mkdir(parents=True, exist_ok=True)
        dirs[kind] = target
    return dirs
def _next_index(artifact_dir: Path, prefix: str) -> int:
highest = 0
pattern = re.compile(rf"^{re.escape(prefix)}_(\d+)\.(?:md|mp3)$")
for path in artifact_dir.glob(f"{prefix}_*.*"):
match = pattern.match(path.name)
if match:
highest = max(highest, int(match.group(1)))
return highest + 1
def _artifact_file_out(path: Path) -> ArtifactFileOut:
    """Build an ArtifactFileOut for a file; created_at is its mtime in UTC (whole seconds)."""
    mtime = path.stat().st_mtime
    stamp = datetime.fromtimestamp(mtime, tz=timezone.utc).replace(microsecond=0)
    return ArtifactFileOut(name=path.name, path=str(path.as_posix()), created_at=stamp.isoformat())
def _collect_source_texts(store: NotebookStore, user_id: str, notebook_id: str, max_chars: int = 14000) -> List[Dict[str, str]]:
    """Gather extracted source texts for a notebook, capped at ``max_chars`` total.

    Each entry is ``{"source_name": ..., "text": ...}``. The display name comes
    from the sibling ``<id>.meta.json`` file when readable, otherwise the text
    filename. Empty text files are skipped; the last included text may be
    truncated to honor the character budget.
    """
    extracted_dir = store.files_extracted_dir(user_id, notebook_id)
    collected: List[Dict[str, str]] = []
    budget = max_chars
    for text_path in sorted(extracted_dir.glob("*.txt")):
        if budget <= 0:
            break
        body = text_path.read_text(encoding="utf-8", errors="ignore").strip()
        if not body:
            continue
        display_name = text_path.name
        meta_path = extracted_dir / f"{text_path.stem}.meta.json"
        if meta_path.exists():
            try:
                meta = json.loads(meta_path.read_text(encoding="utf-8"))
                display_name = str(meta.get("source_name") or display_name)
            except Exception:
                pass  # best effort: fall back to the filename
        excerpt = body[:budget]
        budget -= len(excerpt)
        collected.append({"source_name": display_name, "text": excerpt})
    return collected
def _sources_block(sources: List[Dict[str, str]]) -> str:
blocks = []
for source in sources:
blocks.append(f"[Source: {source['source_name']}]\n{source['text']}")
return "\n\n".join(blocks)
def _llm_or_fallback(prompt: str, fallback_text: str) -> str:
    """Return the LLM's response for ``prompt``, or ``fallback_text`` on failure.

    A failure is either an exception from the LLM service or an empty/blank
    response. Previously exceptions were swallowed silently; they are now
    logged so degraded generations are diagnosable, while the best-effort
    fallback behavior is preserved.
    """
    try:
        result = llm_service.generate(prompt).strip()
        if result:
            return result
        logger.warning("LLM returned an empty response; using fallback content")
    except Exception:
        logger.exception("LLM generation failed; using fallback content")
    return fallback_text
def _report_fallback(sources: List[Dict[str, str]], extra_prompt: Optional[str]) -> str:
lines = [
"# Study Report",
"",
"## Overview",
"This report was generated from your notebook sources.",
]
if extra_prompt:
lines.extend(["", "## Focus", extra_prompt.strip()])
lines.extend(["", "## Key Source Notes"])
for src in sources[:5]:
snippet = src["text"].replace("\n", " ")[:280].strip()
lines.append(f"- **{src['source_name']}**: {snippet}")
lines.extend(["", "## Conclusion", "Use the chat tab to ask follow-up questions with citations."])
return "\n".join(lines).strip() + "\n"
def _quiz_fallback(sources: List[Dict[str, str]], num_questions: int) -> str:
questions = []
answers = []
for i in range(1, num_questions + 1):
src = sources[(i - 1) % len(sources)]
snippet = src["text"].replace("\n", " ")[:220].strip()
questions.append(f"{i}. Which source contains this idea: \"{snippet}\"?")
answers.append(f"{i}. {src['source_name']}")
return "\n".join(
[
"# Quiz",
"",
"## Questions",
*questions,
"",
"## Answer Key",
*answers,
"",
]
)
def _podcast_transcript_fallback(sources: List[Dict[str, str]], extra_prompt: Optional[str]) -> str:
focus = extra_prompt.strip() if extra_prompt else "core concepts from the notebook"
lines = [
"# Podcast Transcript",
"",
f"Topic focus: {focus}",
"",
"**Host:** Welcome back. Today we cover the key ideas from your notebook.",
]
for idx, src in enumerate(sources[:6], start=1):
snippet = src["text"].replace("\n", " ")[:180].strip()
lines.append(f"**Co-Host:** From {src['source_name']}, point {idx}: {snippet}")
lines.append(f"**Host:** Great, and why does that matter in practice?")
lines.append("**Co-Host:** That wraps the study summary. Review the report and quiz next.")
lines.append("")
return "\n".join(lines)
def _flashcards_fallback(sources: List[Dict[str, str]], num_cards: int) -> str:
cards = max(3, min(20, int(num_cards)))
lines = ["# Flashcards", ""]
for i in range(1, cards + 1):
src = sources[(i - 1) % len(sources)]
snippet = src["text"].replace("\n", " ")[:200].strip()
lines.append(f"## Card {i}")
lines.append(f"Q: What key point appears in {src['source_name']}?")
lines.append(f"A: {snippet}")
lines.append("")
return "\n".join(lines).strip() + "\n"
def _encode_pcm_to_mp3(pcm_16le: bytes, sample_rate: int, channels: int) -> bytes:
    """Encode raw 16-bit little-endian PCM into MP3 bytes via lameenc.

    lameenc is imported lazily so the module loads without it; callers only
    hit an ImportError when audio encoding is actually requested.
    """
    import lameenc
    encoder = lameenc.Encoder()
    encoder.set_bit_rate(96)  # 96 kbps — presumably sufficient for speech; TODO confirm
    encoder.set_in_sample_rate(sample_rate)
    encoder.set_channels(channels)
    encoder.set_quality(2)  # LAME quality scale: lower is slower/better
    # flush() drains the encoder's internal buffer for the final frames.
    return encoder.encode(pcm_16le) + encoder.flush()
def _wav_bytes_to_mp3(wav_bytes: bytes) -> bytes:
    """Convert in-memory WAV bytes to MP3 bytes.

    Samples are normalized to 16-bit width for the MP3 encoder, and unusual
    channel counts fall back to mono.

    NOTE(review): ``audioop`` was removed from the stdlib in Python 3.13
    (PEP 594) — this path needs a replacement before upgrading.
    """
    with wave.open(BytesIO(wav_bytes), "rb") as wav_file:
        channels = wav_file.getnchannels()
        sample_rate = wav_file.getframerate()
        sample_width = wav_file.getsampwidth()
        frames = wav_file.readframes(wav_file.getnframes())
    if sample_width != 2:
        # Convert samples to 2-byte (16-bit) width expected by the encoder.
        frames = audioop.lin2lin(frames, sample_width, 2)
    if channels not in (1, 2):
        channels = 1
    return _encode_pcm_to_mp3(frames, sample_rate, channels)
def _is_mp3(audio_bytes: bytes) -> bool:
if not audio_bytes:
return False
return audio_bytes.startswith(b"ID3") or audio_bytes[:2] in {b"\xff\xfb", b"\xff\xf3", b"\xff\xf2"}
def clean_transcript_for_tts(transcript: str) -> str:
    """Flatten a markdown transcript into plain ASCII text for the TTS engine.

    Speaker labels and markdown punctuation are blanked out, runs of non-ASCII
    characters collapse to single spaces, and the surviving non-empty lines
    are joined with spaces.
    """
    cleaned: List[str] = []
    for raw in transcript.splitlines():
        stripped = raw.strip()
        if not stripped:
            continue
        stripped = re.sub(r"\*\*Host:\*\*", " ", stripped, flags=re.IGNORECASE)
        stripped = re.sub(r"\*\*Co-Host:\*\*", " ", stripped, flags=re.IGNORECASE)
        stripped = re.sub(r"[*_#>`]", " ", stripped)
        stripped = re.sub(r'[^\x00-\x7F]+', ' ', stripped)
        cleaned.append(stripped)
    return " ".join(cleaned)
def _synthesize_podcast_mp3(transcript_text: str) -> bytes:
    """Synthesize MP3 audio from a transcript via the VITS TTS model.

    The transcript is flattened to plain ASCII and capped at 1800 characters
    to keep single-pass inference bounded.

    Raises:
        RuntimeError: via _get_tts_model when TTS dependencies are missing.
    """
    tts_text = clean_transcript_for_tts(transcript_text)[:1800]
    model, tokenizer = _get_tts_model()
    inputs = tokenizer(tts_text, return_tensors="pt")
    with torch.no_grad():
        waveform = model(**inputs).waveform.squeeze().cpu().numpy()
    # Use the model's own output rate rather than assuming 16 kHz; the config
    # of facebook/mms-tts-eng reports 16000, so this is backward compatible,
    # and it keeps the WAV correct if the model checkpoint ever changes.
    sample_rate = int(getattr(model.config, "sampling_rate", 16000))
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, waveform, sample_rate, format="WAV")
    return _wav_bytes_to_mp3(wav_buffer.getvalue())
def _write_markdown_artifact(artifact_dir: Path, prefix: str, content: str) -> Path:
    """Persist ``content`` as the next numbered ``<prefix>_<N>.md`` file; return its path."""
    target = artifact_dir / f"{prefix}_{_next_index(artifact_dir, prefix)}.md"
    target.write_text(content.strip() + "\n", encoding="utf-8")
    return target
def generate_report(
    store: NotebookStore,
    *,
    user_id: str,
    notebook_id: str,
    prompt: Optional[str] = None,
) -> ArtifactGenerateOut:
    """Generate a markdown study-report artifact from the notebook's sources.

    Raises:
        ValueError: if the notebook has no ingested sources.
    """
    dirs = _artifact_dirs(store, user_id, notebook_id)
    sources = _collect_source_texts(store, user_id, notebook_id)
    if not sources:
        raise ValueError("No ingested sources available. Upload or ingest sources first.")
    if prompt:
        focus = prompt.strip()
    else:
        focus = "Produce a complete study report with clear sections."
    llm_prompt = (
        "Create a markdown study report grounded only in SOURCES.\n"
        "Include: title, summary, core concepts, examples, and a short conclusion.\n"
        "Cite source names inline when useful.\n\n"
        f"FOCUS:\n{focus}\n\n"
        f"SOURCES:\n{_sources_block(sources)}\n"
    )
    report_md = _llm_or_fallback(llm_prompt, _report_fallback(sources, prompt))
    report_path = _write_markdown_artifact(dirs["report"], "report", report_md)
    return ArtifactGenerateOut(
        artifact_type="report",
        message=f"Generated {report_path.name}",
        markdown_path=str(report_path.as_posix()),
        audio_path=None,
        created_at=_now(),
    )
def generate_quiz(
    store: NotebookStore,
    *,
    user_id: str,
    notebook_id: str,
    prompt: Optional[str] = None,
    num_questions: int = 8,
) -> ArtifactGenerateOut:
    """Generate a markdown quiz artifact (3-15 questions) from the notebook's sources.

    Raises:
        ValueError: if the notebook has no ingested sources.
    """
    dirs = _artifact_dirs(store, user_id, notebook_id)
    sources = _collect_source_texts(store, user_id, notebook_id)
    if not sources:
        raise ValueError("No ingested sources available. Upload or ingest sources first.")
    question_count = min(15, max(3, int(num_questions)))  # clamp to [3, 15]
    focus = prompt.strip() if prompt else "Create a mixed-difficulty study quiz."
    llm_prompt = (
        "Create a markdown quiz from SOURCES.\n"
        f"Include exactly {question_count} questions and a final 'Answer Key' section.\n"
        "Ground each question in source content.\n\n"
        f"FOCUS:\n{focus}\n\n"
        f"SOURCES:\n{_sources_block(sources)}\n"
    )
    quiz_md = _llm_or_fallback(llm_prompt, _quiz_fallback(sources, question_count))
    quiz_path = _write_markdown_artifact(dirs["quiz"], "quiz", quiz_md)
    return ArtifactGenerateOut(
        artifact_type="quiz",
        message=f"Generated {quiz_path.name}",
        markdown_path=str(quiz_path.as_posix()),
        audio_path=None,
        created_at=_now(),
    )
def generate_flashcards(
    store: NotebookStore,
    *,
    user_id: str,
    notebook_id: str,
    prompt: Optional[str] = None,
    num_questions: int = 8,
) -> ArtifactGenerateOut:
    """Generate a markdown flashcard artifact (3-20 cards) from the notebook's sources.

    Raises:
        ValueError: if the notebook has no ingested sources.
    """
    dirs = _artifact_dirs(store, user_id, notebook_id)
    sources = _collect_source_texts(store, user_id, notebook_id)
    if not sources:
        raise ValueError("No ingested sources available. Upload or ingest sources first.")
    card_count = min(20, max(3, int(num_questions)))  # clamp to [3, 20]
    focus = prompt.strip() if prompt else "Create concise study flashcards."
    llm_prompt = (
        "Create markdown flashcards from SOURCES.\n"
        f"Include exactly {card_count} cards in format:\n"
        "## Card N\nQ: ...\nA: ...\n"
        "Keep answers concise and grounded in source content.\n\n"
        f"FOCUS:\n{focus}\n\n"
        f"SOURCES:\n{_sources_block(sources)}\n"
    )
    cards_md = _llm_or_fallback(llm_prompt, _flashcards_fallback(sources, card_count))
    cards_path = _write_markdown_artifact(dirs["flashcards"], "flashcards", cards_md)
    return ArtifactGenerateOut(
        artifact_type="flashcards",
        message=f"Generated {cards_path.name}",
        markdown_path=str(cards_path.as_posix()),
        audio_path=None,
        created_at=_now(),
    )
def generate_podcast(
    store: NotebookStore,
    *,
    user_id: str,
    notebook_id: str,
    prompt: Optional[str] = None,
) -> ArtifactGenerateOut:
    """Generate a podcast artifact: a markdown transcript plus a synthesized MP3.

    The transcript is written to disk before audio synthesis, so it survives
    even if TTS fails; list_artifacts tolerates a transcript without audio.

    Raises:
        ValueError: if the notebook has no ingested sources.
    """
    dirs = _artifact_dirs(store, user_id, notebook_id)
    sources = _collect_source_texts(store, user_id, notebook_id)
    if not sources:
        raise ValueError("No ingested sources available. Upload or ingest sources first.")
    focus = prompt.strip() if prompt else "Generate a conversational podcast between two hosts."
    llm_prompt = (
        "Write a markdown podcast transcript with a two-person conversation.\n"
        "Use **Host:** and **Co-Host:** labels.\n"
        "Keep it factual and grounded in SOURCES.\n\n"
        f"FOCUS:\n{focus}\n\n"
        f"SOURCES:\n{_sources_block(sources)}\n"
    )
    transcript = _llm_or_fallback(llm_prompt, _podcast_transcript_fallback(sources, prompt))
    episode = _next_index(dirs["podcast"], "podcast")
    transcript_path = dirs["podcast"] / f"podcast_{episode}.md"
    audio_file = dirs["podcast"] / f"podcast_{episode}.mp3"
    transcript_path.write_text(transcript.strip() + "\n", encoding="utf-8")
    audio_file.write_bytes(_synthesize_podcast_mp3(transcript))
    return ArtifactGenerateOut(
        artifact_type="podcast",
        message=f"Generated podcast_{episode}.md and podcast_{episode}.mp3",
        markdown_path=str(transcript_path.as_posix()),
        audio_path=str(audio_file.as_posix()),
        created_at=_now(),
    )
def list_artifacts(store: NotebookStore, *, user_id: str, notebook_id: str) -> ArtifactListOut:
    """List every stored artifact for a notebook, grouped by type.

    Podcasts are paired by episode index so a transcript without audio (or
    vice versa) still appears, with the missing half reported as None.
    """
    dirs = _artifact_dirs(store, user_id, notebook_id)

    def _collect(kind: str, pattern: str) -> List[ArtifactFileOut]:
        # One sorted listing per markdown artifact family.
        return [_artifact_file_out(p) for p in sorted(dirs[kind].glob(pattern))]

    episode_numbers: set[int] = set()
    for candidate in dirs["podcast"].glob("podcast_*.*"):
        hit = re.match(r"podcast_(\d+)\.(?:md|mp3)$", candidate.name)
        if hit:
            episode_numbers.add(int(hit.group(1)))
    podcasts: List[PodcastArtifactOut] = []
    for episode in sorted(episode_numbers):
        md = dirs["podcast"] / f"podcast_{episode}.md"
        mp3 = dirs["podcast"] / f"podcast_{episode}.mp3"
        podcasts.append(
            PodcastArtifactOut(
                transcript=_artifact_file_out(md) if md.exists() else None,
                audio=_artifact_file_out(mp3) if mp3.exists() else None,
            )
        )
    return ArtifactListOut(
        reports=_collect("report", "report_*.md"),
        quizzes=_collect("quiz", "quiz_*.md"),
        flashcards=_collect("flashcards", "flashcards_*.md"),
        podcasts=podcasts,
    )
def resolve_artifact_path(
    store: NotebookStore,
    *,
    user_id: str,
    notebook_id: str,
    artifact_type: str,
    filename: str,
) -> Path:
    """Resolve a user-supplied artifact filename to a file inside the artifact dir.

    Args:
        artifact_type: "report", "quiz", "flashcards", or "podcast"
            (case-insensitive, surrounding whitespace ignored).
        filename: bare filename within that artifact directory.

    Raises:
        ValueError: for path-traversal attempts or an unknown artifact type.
        FileNotFoundError: if the artifact file does not exist.
    """
    # Reject anything that is not a plain filename. The original check missed
    # the "." and ".." dot segments; both are rejected explicitly now.
    if (
        not filename
        or "/" in filename
        or "\\" in filename
        or filename in {".", ".."}
    ):
        raise ValueError("Invalid filename")
    dirs = _artifact_dirs(store, user_id, notebook_id)
    kind = artifact_type.strip().lower()
    if kind not in dirs:
        raise ValueError("Unsupported artifact type")
    candidate = dirs[kind] / filename
    # Defense in depth: the resolved path must live directly inside the
    # artifact directory (guards against symlink escapes and any separator
    # form the string checks above might have missed on this platform).
    if candidate.resolve().parent != dirs[kind].resolve():
        raise ValueError("Invalid filename")
    if not candidate.exists() or not candidate.is_file():
        raise FileNotFoundError("Artifact file not found")
    return candidate