File size: 9,430 Bytes
aaf767a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
"""
Testset Generation module for the RAG system using Ragas.
This script generates question-answer pairs from documents to be used in evaluation.

How to run:
    python generate_testset.py <pdf_path1> <pdf_path2> ...
"""

# pylint: disable=import-error,no-name-in-module,invalid-name,broad-except,missing-function-docstring,missing-class-docstring,too-many-return-statements,ungrouped-imports,line-too-long,logging-fstring-interpolation,duplicate-code,too-few-public-methods

import os
import sys
import logging
from typing import List, Any
from PyPDF2 import PdfReader

try:
    # Newer langchain versions expose Document in langchain.schema
    from langchain.schema import Document
except Exception:
    try:
        # Older versions used langchain.docstore.document
        from langchain.docstore.document import Document
    except Exception:
        # Minimal fallback Document for environments without langchain.
        # It mirrors only the two attributes this script actually uses:
        # page_content and metadata.
        from dataclasses import dataclass

        # NOTE: the `dict | None` union syntax requires Python 3.10+.
        @dataclass
        class Document:
            page_content: str
            metadata: dict | None = None


from ragas.testset.synthesizers.generate import TestsetGenerator

try:
    from langchain.chat_models import ChatOpenAI
except Exception:
    from langchain_openai import ChatOpenAI

try:
    from langchain_huggingface import HuggingFaceEmbeddings
except Exception:
    from langchain_community.embeddings import HuggingFaceEmbeddings

# BUG FIX: the message-class fallbacks and the two helper functions below
# were previously nested inside the ``except`` branch of the
# HuggingFaceEmbeddings import above, so they only existed when
# ``langchain_huggingface`` failed to import. They are now defined
# unconditionally at module level.
try:
    from langchain.schema import SystemMessage, HumanMessage
except Exception:
    # Minimal stand-ins if langchain.schema isn't available
    from dataclasses import dataclass

    @dataclass
    class SystemMessage:
        content: str

    @dataclass
    class HumanMessage:
        content: str


def _extract_chat_response(resp) -> str:
    """Robustly extract text from the various shapes a chat LLM may return.

    Tries, in order: objects with ``.content`` (AIMessage), objects with
    ``.generations`` (ChatResult, where generations may be nested one or
    two levels deep), non-empty lists of messages/dicts, and plain dicts
    with a ``content`` or ``text`` key. Falls back to ``str(resp)``.
    """
    try:
        # langchain newer: AIMessage with .content
        if hasattr(resp, "content"):
            return resp.content
        # langchain older/other: ChatResult with .generations
        if hasattr(resp, "generations"):
            gens = resp.generations
            # gens may be list[list[Generation]] or list[Generation]
            try:
                return gens[0][0].text
            except Exception:
                try:
                    return gens[0].text
                except Exception:
                    pass
        # fallback dict/list shapes
        if isinstance(resp, list) and resp:
            first = resp[0]
            if hasattr(first, "content"):
                return first.content
            if isinstance(first, dict) and "content" in first:
                return first["content"]
        if isinstance(resp, dict):
            for k in ("content", "text"):
                if k in resp:
                    return resp[k]
    except Exception:
        pass
    return str(resp)


def summarize_documents(docs, llm, max_summary_chars: int = 2000) -> List[Document]:
    """Summarize each Document using the provided LLM into shorter Documents.

    This is optional and controlled by the `USE_CHUNK_SUMMARIZATION` env var.

    Parameters
    ----------
    docs : iterable of Document
        Documents to summarize; documents with empty text are skipped.
    llm : callable
        Chat model. Invoked first with a message list; if that raises, it
        is retried with the raw prompt string.
    max_summary_chars : int, optional
        Requested maximum summary length; also the truncation length used
        when summarization fails or returns an empty result.

    Returns
    -------
    List[Document]
        One summarized Document per non-empty input, with a ``chunk``
        index added to a copy of the original metadata.
    """
    summaries: List[Document] = []
    for i, doc in enumerate(docs):
        text = (doc.page_content or "").strip()
        if not text:
            continue
        # Construct a concise summarization prompt
        prompt = (
            f"Summarize the following text into a concise summary (preserve key facts, numbers, and named entities). "
            f"Aim for no more than {max_summary_chars} characters. Return only the summary, no commentary.\n\nText:\n"
            + text
        )
        try:
            messages = [
                SystemMessage(content="You are a concise summarizer."),
                HumanMessage(content=prompt),
            ]
            resp = llm(messages)
            summary = _extract_chat_response(resp)
        except Exception:
            try:
                # Some LLM wrappers only accept a plain string prompt.
                resp = llm(prompt)
                summary = _extract_chat_response(resp)
            except Exception as e:
                logging.debug(f"Summarization failed for chunk {i}: {e}")
                # Fallback: truncate
                summary = text[:max_summary_chars]

        summary = (summary or "").strip()
        if not summary:
            summary = text[:max_summary_chars]

        # Copy metadata so the source Document is never mutated.
        meta = dict(doc.metadata) if getattr(doc, "metadata", None) else {}
        meta.update({"chunk": i})
        summaries.append(Document(page_content=summary, metadata=meta))
    return summaries


# Text splitting to avoid sending huge prompts to the LLM
try:
    from langchain.text_splitter import RecursiveCharacterTextSplitter
except Exception:
    # Minimal fallback splitter if langchain isn't available.
    class RecursiveCharacterTextSplitter:
        """Fixed-window character splitter: emits overlapping chunks."""

        def __init__(self, chunk_size: int = 8000, chunk_overlap: int = 500):
            self.chunk_size = chunk_size
            self.chunk_overlap = chunk_overlap

        def split_documents(self, docs):
            """Split each document into ``chunk_size``-character windows,
            advancing by ``chunk_size - chunk_overlap`` per step (min 1)."""
            chunks = []
            for source_doc in docs:
                content = source_doc.page_content or ""
                stride = max(1, self.chunk_size - self.chunk_overlap)
                start = 0
                while start < len(content):
                    piece = content[start : start + self.chunk_size]
                    chunks.append(
                        Document(page_content=piece, metadata=source_doc.metadata)
                    )
                    start += stride
            return chunks


def get_documents_from_pdfs(pdf_paths: List[str]) -> List[Document]:
    """
    Load PDFs and convert them to LangChain Document objects.

    Each PDF is read page by page and the extracted text is concatenated
    (no separator) into a single Document per file. Files that cannot be
    read are logged and skipped.

    Parameters
    ----------
    pdf_paths : List[str]
        List of paths to PDF files.

    Returns
    -------
    List[Document]
        One Document per successfully-read PDF, with ``source`` metadata
        set to the file's base name.
    """
    docs = []
    for pdf_path in pdf_paths:
        try:
            pages = PdfReader(pdf_path).pages
            full_text = "".join(page.extract_text() or "" for page in pages)
            docs.append(
                Document(
                    page_content=full_text,
                    metadata={"source": os.path.basename(pdf_path)},
                )
            )
        except Exception as e:
            logging.error(f"Error reading {pdf_path}: {e}")
    return docs


def generate_testset(
    pdf_paths: List[str], test_size: int = 10, output_path: str = "testset.csv"
) -> Any:
    """
    Generate a test set from the given PDFs.

    Parameters
    ----------
    pdf_paths : List[str]
        List of paths to PDF files.
    test_size : int, optional
        Number of QA pairs to generate.
    output_path : str, optional
        Path to save the generated test set (CSV).

    Returns
    -------
    Any
        The generated Ragas testset, or ``None`` when no documents could
        be loaded from ``pdf_paths``.
    """
    documents = get_documents_from_pdfs(pdf_paths)
    if not documents:
        logging.error("No documents found to generate testset from.")
        return None

    # Configure LLM and Embeddings consistent with the app
    # Use environment variables for API keys and Base URL (e.g. standard OPENAI_*, or manually set)

    # Allow overriding the LLM model via env var
    model_name = os.getenv("TESTSET_LLM_MODEL", "openai/gpt-4o-mini")
    logging.info(f"Using LLM model: {model_name}")

    # Prefer OpenRouter when available so generated LLM clients use it by default.
    _openrouter_key = os.getenv("OPENROUTER_API_KEY")
    if _openrouter_key:
        # BUG FIX: OpenRouter's documented OpenAI-compatible endpoint is
        # https://openrouter.ai/api/v1 — the previous value
        # "https://api.openrouter.ai/v1" is not a valid API host.
        os.environ["OPENAI_API_BASE"] = "https://openrouter.ai/api/v1"
        # Newer langchain_openai clients read OPENAI_BASE_URL instead of
        # the legacy OPENAI_API_BASE; set both for compatibility.
        os.environ["OPENAI_BASE_URL"] = os.environ["OPENAI_API_BASE"]
        os.environ["OPENAI_API_KEY"] = _openrouter_key
        logging.info(
            "OpenRouter detected; routing OpenAI calls via %s",
            os.environ["OPENAI_API_BASE"],
        )
        logging.info(
            "OPENAI_API_KEY loaded=%s",
            bool(os.environ.get("OPENAI_API_KEY")),
        )

    # Create LLM clients (will read credentials from environment)
    generator_llm = ChatOpenAI(model=model_name)
    # Note: critic_llm would be used for test evaluation if needed in future
    # critic_llm = ChatOpenAI(model=model_name)

    embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")

    # Initialize generator (provide the generator LLM and the embeddings)
    generator = TestsetGenerator.from_langchain(generator_llm, embeddings)

    # Split large documents into smaller chunks to avoid exceeding model context limits
    splitter = RecursiveCharacterTextSplitter(chunk_size=8000, chunk_overlap=500)
    split_docs = splitter.split_documents(documents)

    # Generate testset (use default query distribution)
    logging.info(
        f"Generating testset of size {test_size} from {len(split_docs)} chunks..."
    )
    testset = generator.generate_with_langchain_docs(split_docs, testset_size=test_size)

    # Export to CSV
    test_df = testset.to_pandas()
    test_df.to_csv(output_path, index=False)
    logging.info(f"Testset saved to {output_path}")

    return testset


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Everything after the script name is treated as a PDF path.
    pdf_args = sys.argv[1:]
    if pdf_args:
        generate_testset(pdf_args)
    else:
        print("Usage: python generate_testset.py <pdf_path1> <pdf_path2> ...")