File size: 6,563 Bytes
334c1a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5edd1c1
334c1a6
 
5edd1c1
334c1a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import time
from typing import Dict, List
import cohere
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from ..config.settings import LLM_MODEL


class DocumentSummarizer:
    """Build a sectioned Markdown summary of one document via the Cohere Chat API.

    For every entry in ``self.components`` the summarizer asks the retriever
    for relevant chunks, sends them to the chat model together with that
    component's prompt, and finally stitches the per-component answers into a
    single Markdown string.
    """

    # Order in which sections appear in the compiled summary. Result keys that
    # are not listed here are not rendered.
    _SECTIONS_ORDER = (
        'basic_info', 'abstract',
        'methods', 'results', 'equations', 'technical',
        'related_work', 'applications', 'limitations',
    )

    def __init__(self, retriever, batch_size=4):
        """
        Args:
            retriever: Object exposing ``get_relevant_docs(chromdb_query=...,
                rerank_query=..., filter=..., chunk_size=...)``.
            batch_size: Stored on the instance; no method of this class
                reads it.
        """
        self.batch_size = batch_size
        self.retriever = retriever

        # Constructed without arguments — presumably the SDK picks up its
        # credentials from the environment; verify against deployment config.
        self.cohere_client = cohere.ClientV2()

        # Component key -> human-readable section title.
        self.components = {
            'basic_info': "Basic Paper Information",
            'abstract': "Abstract Summary",
            'methods': "Methodology Summary",
            'results': "Key Results",
            'limitations': "Limitations & Future Work",
            'related_work': "Related Work",
            'applications': "Practical Applications",
            'technical': "Technical Details",
            'equations': "Key Equations",
        }

        # Component key -> prompt text sent as the user message.
        self.prompts = self._initialize_prompts()

    def _initialize_prompts(self):
        """Return the mapping of component key to its prompt.

        Only the prompts that are actually used are imported, so a missing
        unrelated symbol in the prompt module cannot break initialisation.
        """
        from ..summarization.prompt2 import (
            basic_info_prompt,
            abstract_prompt,
            methods_prompt,
            results_prompt,
            limitations_prompt,
            related_work_prompt,
            applications_prompt,
            technical_prompt,
            equations_prompt,
        )
        return {
            'basic_info': basic_info_prompt,
            'abstract': abstract_prompt,
            'methods': methods_prompt,
            'results': results_prompt,
            'limitations': limitations_prompt,
            'related_work': related_work_prompt,
            'applications': applications_prompt,
            'technical': technical_prompt,
            'equations': equations_prompt,
        }

    def summarize_text(self, documents: List[Dict], prompt: str, language: str):
        """Summarize ``documents`` with the Cohere Chat API.

        Args:
            documents: List of dicts passed unchanged as the ``documents``
                argument of the chat call.
            prompt: Text sent as the user message.
            language: Language the system message asks the model to answer in.

        Returns:
            The text of the first content item of the response, or ``None``
            when the response carries no text or the call raised. Exceptions
            are printed and swallowed, never propagated.
        """
        try:
            response = self.cohere_client.chat(
                model=LLM_MODEL,
                documents=documents,
                messages=[
                    {"role": "system", "content": f"You are an expert summarization AI. Please respond in {language}."},
                    {"role": "user", "content": f"{prompt}"}
                ],
            )
            if not (response and response.message and response.message.content):
                return None
            first_item = response.message.content[0]
            if first_item and first_item.text:
                return first_item.text
            return None
        except Exception as e:
            print(f"Error during Cohere API call: {e}")
            return None

    def extract_relevant_documents(self, component: str, filename: str, chunk_size: int):
        """Fetch the chunks relevant to ``component`` from the retriever.

        The same query string is used for both the vector-store query and the
        rerank query, and results are filtered to ``filename``.

        Returns:
            Whatever ``retriever.get_relevant_docs`` returns, or ``[]`` when
            the retriever raised (the error is printed).
        """
        section_title = self.components.get(component, component)
        query = f"Analyze the {section_title} section from the document titled '{filename}'."
        try:
            return self.retriever.get_relevant_docs(
                chromdb_query=query,
                rerank_query=query,
                filter={'filename': filename},
                chunk_size=chunk_size
            )
        except Exception as e:
            print(f"Error during document retrieval for component {component}: {e}")
            return []

    def summerize_document(self, filename: str, language: str, chunk_size: int):
        """Summarize a document by processing every component in parallel.

        Each component is retrieved and summarized in its own worker thread.
        Components that fail (no prompt, no retrieved chunks, no text from
        the model, or an unexpected exception) are left out of the summary
        and reported on stdout instead of being dropped silently.

        Args:
            filename: Document name used in the retrieval filter and heading.
            language: Language the model is asked to answer in.
            chunk_size: Forwarded to the retriever.

        Returns:
            The compiled Markdown summary (see ``compile_summary``).
        """
        start_total = time.time()
        results = {}
        errors = {}

        def process_component(comp):
            """Return ``(component, summary_or_None, error_or_None)``."""
            try:
                prompt = self.prompts.get(comp)
                if not prompt:
                    # Without this guard the literal string "None" would be
                    # sent to the model as the user prompt.
                    return comp, None, "no prompt configured for this component"

                document_chunks = self.extract_relevant_documents(comp, filename, chunk_size)
                if not document_chunks:
                    # Nothing to ground the answer on; do not ask the model
                    # to summarize an empty document set.
                    return comp, None, "no relevant documents retrieved"

                summary = self.summarize_text(document_chunks, prompt, language)
                if not summary:
                    return comp, None, "summarization returned no text"
                return comp, summary, None
            except Exception as e:
                return comp, None, str(e)

        with ThreadPoolExecutor(max_workers=None) as executor:
            future_to_component = {
                executor.submit(process_component, comp): comp
                for comp in self.components
            }

            for future in as_completed(future_to_component):
                comp = future_to_component[future]
                try:
                    comp_name, result, error = future.result()
                except Exception as exc:
                    errors[comp] = str(exc)
                    continue

                if result is not None:
                    results[comp_name] = result
                else:
                    errors[comp_name] = error or "unknown error"

        end_total = time.time()
        print(f"\n--- Total summarization time for {filename}: {end_total - start_total:.2f} seconds ---\n")
        for comp_name, reason in errors.items():
            print(f"Warning: component '{comp_name}' skipped for {filename}: {reason}")

        return self.compile_summary(filename, results)

    def summarize_document(self, filename: str, language: str, chunk_size: int):
        """Correctly spelled alias that delegates to ``summerize_document``."""
        return self.summerize_document(filename, language, chunk_size)

    def compile_summary(self, filename: str, results: Dict[str, str]) -> str:
        """Concatenate component results into one Markdown document.

        The output starts with a ``# Summary of <filename>`` heading and a
        ``Generated on:`` timestamp line, followed by one ``##`` section per
        component that has a non-empty result, in ``_SECTIONS_ORDER`` order.

        Args:
            filename: Document name shown in the top-level heading.
            results: Mapping of component key to summary text.

        Returns:
            The Markdown text, with its parts joined by newlines.
        """
        lines = [f"# Summary of {filename}", f"Generated on: {time.strftime('%Y-%m-%d %H:%M:%S')}\n"]
        for section in self._SECTIONS_ORDER:
            body = results.get(section)
            if not body:
                continue
            # Fall back to the raw key if the section has no configured title.
            title = self.components.get(section, section).title()
            lines.append(f"## {title}\n")
            lines.append(f"{body}\n")

        return "\n".join(lines)