File size: 7,653 Bytes
aacd162
 
 
 
 
 
46e5b37
aacd162
 
 
 
 
f09ce8f
aacd162
 
 
 
 
f09ce8f
 
 
 
 
 
 
 
 
 
 
aacd162
 
f09ce8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aacd162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f09ce8f
 
aacd162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f09ce8f
 
 
 
 
 
 
 
aacd162
 
f09ce8f
aacd162
 
 
 
f09ce8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aacd162
 
46e5b37
 
aacd162
 
46e5b37
aacd162
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
"""
Report generator using RAG context from ingested notebook content.
"""
from __future__ import annotations

import os
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

from dotenv import load_dotenv
from openai import OpenAI
import requests

from src.ingestion.vectorstore import ChromaAdapter

load_dotenv()

SUPPORTED_REPORT_LLM_PROVIDERS = {"openai", "groq", "ollama"}
DEFAULT_REPORT_MODELS = {
    "openai": "gpt-4o-mini",
    "groq": "llama-3.1-8b-instant",
    "ollama": "qwen2.5:3b",
}
REPORT_SYSTEM_PROMPT = (
    "You write high quality reports grounded only in provided source context. "
    "Do not invent facts."
)


class ReportGenerator:
    def __init__(
        self,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        llm_provider: Optional[str] = None,
    ):
        provider_default = (
            llm_provider
            or os.getenv("REPORT_LLM_PROVIDER", "").strip()
            or os.getenv("QUIZ_LLM_PROVIDER", "").strip()
            or os.getenv("TRANSCRIPT_LLM_PROVIDER", "").strip()
            or "openai"
        )
        self.llm_provider = provider_default.strip().lower()
        if self.llm_provider not in SUPPORTED_REPORT_LLM_PROVIDERS:
            raise ValueError(
                f"Unsupported REPORT_LLM_PROVIDER='{self.llm_provider}'. "
                f"Choose from: {sorted(SUPPORTED_REPORT_LLM_PROVIDERS)}"
            )

        self.model = self._resolve_model_name(model)
        self._openai_client: OpenAI | None = None
        self._groq_client = None
        self._ollama_base_url = os.getenv("OLLAMA_BASE_URL", "http://127.0.0.1:11434").rstrip("/")

        if self.llm_provider == "openai":
            self.api_key = api_key or os.getenv("OPENAI_API_KEY")
            self._openai_client = OpenAI(api_key=self.api_key)
        elif self.llm_provider == "groq":
            from groq import Groq

            groq_api_key = os.getenv("GROQ_API_KEY")
            if not groq_api_key:
                raise ValueError("GROQ_API_KEY is required when REPORT_LLM_PROVIDER=groq")
            self._groq_client = Groq(api_key=groq_api_key)
        else:
            self.api_key = None

    def _resolve_model_name(self, explicit_model: Optional[str]) -> str:
        if explicit_model and explicit_model.strip():
            return explicit_model.strip()
        configured = os.getenv("REPORT_LLM_MODEL", "").strip()
        if configured:
            return configured
        legacy = os.getenv("LLM_MODEL", "").strip()
        if legacy:
            return legacy
        return DEFAULT_REPORT_MODELS.get(self.llm_provider, "gpt-4o-mini")

    def generate_report(
        self,
        user_id: str,
        notebook_id: str,
        detail_level: str = "medium",
        topic_focus: str | None = None,
    ) -> dict[str, str]:
        context = self._get_report_context(user_id, notebook_id, topic_focus)
        if not context:
            return {"error": "No content found in notebook. Please ingest documents first."}

        report_markdown = self._generate_markdown(context=context, detail_level=detail_level, topic_focus=topic_focus)
        if not report_markdown.strip():
            return {"error": "Failed to generate report content."}

        return {
            "content": report_markdown,
            "detail_level": detail_level,
            "llm_provider": self.llm_provider,
            "llm_model": self.model,
        }

    def _get_report_context(self, user_id: str, notebook_id: str, topic_focus: str | None) -> str:
        storage_base = os.getenv("STORAGE_BASE_DIR", "data")
        chroma_dir = Path(storage_base) / "users" / user_id / "notebooks" / notebook_id / "chroma"
        if not chroma_dir.exists():
            return ""

        store = ChromaAdapter(persist_directory=str(chroma_dir))
        if topic_focus:
            queries = [topic_focus]
        else:
            queries = [
                "main ideas and summary",
                "key evidence and facts",
                "conclusions and action items",
            ]

        chunks: list[str] = []
        for query in queries:
            try:
                results = store.query(user_id, notebook_id, query, top_k=6)
            except Exception:
                continue
            for _, _, chunk_data in results:
                text = str(chunk_data.get("document", "")).strip()
                if text:
                    chunks.append(text)

        if not chunks:
            return ""

        unique_chunks = list(dict.fromkeys(chunks))
        return "\n\n".join(unique_chunks[:14])

    def _generate_markdown(self, context: str, detail_level: str, topic_focus: str | None) -> str:
        target_words = {
            "short": 400,
            "medium": 800,
            "long": 1400,
        }.get(detail_level, 800)

        focus_line = f"Focus area: {topic_focus}" if topic_focus else "Focus area: broad summary of notebook content"
        prompt = f"""
Write a polished Markdown report from the notebook context below.

{focus_line}
Target length: around {target_words} words.

Notebook context:
{context}

Requirements:
- Use Markdown headings and concise sections.
- Include: Executive Summary, Key Insights, Evidence/Examples, Risks or Open Questions, Next Steps.
- Stay faithful to provided context. Do not fabricate unsupported claims.
- Keep tone professional and clear.
- Return Markdown only (no code fences).
"""

        try:
            return self._generate_report_content(prompt)
        except Exception:
            return ""

    def _generate_report_content(self, prompt: str) -> str:
        if self.llm_provider == "openai":
            assert self._openai_client is not None
            response = self._openai_client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": REPORT_SYSTEM_PROMPT},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.4,
            )
            return str(response.choices[0].message.content or "").strip()

        if self.llm_provider == "groq":
            assert self._groq_client is not None
            response = self._groq_client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": REPORT_SYSTEM_PROMPT},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.4,
            )
            return str(response.choices[0].message.content or "").strip()

        payload = {
            "model": self.model,
            "system": REPORT_SYSTEM_PROMPT,
            "prompt": prompt,
            "stream": False,
            "options": {"temperature": 0.4},
        }
        response = requests.post(
            f"{self._ollama_base_url}/api/generate",
            json=payload,
            timeout=120,
        )
        response.raise_for_status()
        body = response.json()
        return str(body.get("response", "")).strip()

    def save_report(self, markdown_text: str, user_id: str, notebook_id: str) -> str:
        storage_base = os.getenv("STORAGE_BASE_DIR", "data")
        report_dir = Path(storage_base) / "users" / user_id / "notebooks" / notebook_id / "artifacts" / "reports"
        report_dir.mkdir(parents=True, exist_ok=True)

        timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
        path = report_dir / f"report_{timestamp}.md"
        path.write_text(markdown_text, encoding="utf-8")
        return str(path)