File size: 8,104 Bytes
5ea96df
 
 
6f7c8e3
3ab44aa
34f340d
5ea96df
c5654cc
5ea96df
c5654cc
 
 
5ea96df
3ab44aa
c5654cc
 
6f7c8e3
 
 
3ab44aa
6f7c8e3
 
 
 
 
 
3ab44aa
5ea96df
 
6f7c8e3
 
3ab44aa
 
c5654cc
 
5ea96df
 
c5654cc
3ab44aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ea96df
c5654cc
3ab44aa
 
 
 
 
 
6f7c8e3
c5654cc
5ea96df
3ab44aa
 
 
 
 
 
 
 
 
c5654cc
3ab44aa
 
5ea96df
3ab44aa
c5654cc
3ab44aa
c5654cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ea96df
3ab44aa
c5654cc
3ab44aa
c5654cc
 
 
 
 
 
5ea96df
 
c5654cc
5ea96df
c5654cc
3ab44aa
 
c5654cc
3ab44aa
439b957
c5654cc
 
439b957
 
c5654cc
439b957
 
 
 
c5654cc
 
439b957
 
 
 
 
c5654cc
 
439b957
c5654cc
439b957
 
 
 
c5654cc
 
439b957
 
c5654cc
439b957
 
 
c5654cc
439b957
 
 
 
 
 
 
 
 
c5654cc
439b957
 
 
 
3ab44aa
 
c5654cc
3ab44aa
 
 
 
c5654cc
3ab44aa
 
 
 
 
6f7c8e3
 
c5654cc
3ab44aa
6f7c8e3
c5654cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import streamlit as st
from typing import List, Dict
import anthropic
import os
from datetime import datetime
from utils.legal_prompt_generator import LegalPromptGenerator


class ChatInterface:
    """Streamlit chat UI that answers questions about analyzed legal documents.

    Combines vector-store retrieval, the Anthropic Messages API, and a
    LegalPromptGenerator to produce grounded answers with HTML-formatted
    document references. Conversation state lives in ``st.session_state``.
    """

    def __init__(self, case_manager, vector_store, document_processor):
        """Initialize ChatInterface with all required components.

        Args:
            case_manager: Exposes ``get_case(case_id)`` for case metadata.
            vector_store: Exposes ``similarity_search`` for chunk retrieval.
            document_processor: Document-processing helper (stored for use
                by callers; not invoked directly in this class).
        """
        self.case_manager = case_manager
        self.vector_store = vector_store
        self.document_processor = document_processor
        self.prompt_generator = LegalPromptGenerator()

        try:
            api_key = os.getenv("ANTHROPIC_API_KEY")
            if not api_key:
                st.error("Please set the ANTHROPIC_API_KEY environment variable.")
                st.stop()
            self.client = anthropic.Anthropic(api_key=api_key)
        except Exception as e:
            st.error(f"Error initializing Anthropic client: {str(e)}")
            st.stop()

        # Seed the session-state keys this class reads, so a fresh session
        # never hits a missing attribute on rerun.
        if "messages" not in st.session_state:
            st.session_state.messages = []
        if "analyzed_documents" not in st.session_state:
            st.session_state.analyzed_documents = []
        if "context_chunks" not in st.session_state:
            st.session_state.context_chunks = []
        if "current_case" not in st.session_state:
            st.session_state.current_case = None

    def render(self):
        """Render the chat interface with document and context management."""
        st.markdown("""
        <style>
        .chat-message {
            padding: 1.5rem;
            border-radius: 0.5rem;
            margin-bottom: 1rem;
            box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        }
        .user-message {
            background-color: #f0f7ff;
            border-left: 4px solid #2B547E;
        }
        .assistant-message {
            background-color: #ffffff;
            border-left: 4px solid #4CAF50;
        }
        .reference-box {
            background-color: #f5f5f5;
            padding: 0.8rem;
            border-radius: 0.3rem;
            font-size: 0.9em;
            margin-top: 0.5rem;
        }
        .document-chunk {
            border-left: 3px solid #2196F3;
            padding-left: 1rem;
            margin: 0.5rem 0;
        }
        </style>
        """, unsafe_allow_html=True)

        # Display active documents in the sidebar.
        with st.sidebar:
            st.subheader("📚 Active Documents")
            for doc in st.session_state.analyzed_documents:
                with st.expander(f"📄 {doc['name']}", expanded=False):
                    st.write(f"Type: {doc.get('metadata', {}).get('type', 'Unknown')}")
                    st.write(f"Added: {doc.get('metadata', {}).get('added_at', 'Unknown')}")

        # Display chat history.
        # NOTE(review): message content is interpolated into raw HTML with
        # unsafe_allow_html=True and is not escaped — user input containing
        # markup will be rendered. Consider html.escape() if that matters.
        for message in st.session_state.messages:
            message_class = "user-message" if message["role"] == "user" else "assistant-message"
            with st.container():
                st.markdown(f"""
                <div class="chat-message {message_class}">
                    {message["content"]}
                    {'<div class="reference-box">' + message.get("references", "") + '</div>' if message.get("references") else ""}
                </div>
                """, unsafe_allow_html=True)

        # Chat input
        if prompt := st.chat_input("Ask about your documents..."):
            self._handle_chat_input(prompt)

    def _handle_chat_input(self, prompt: str):
        """Process user input and generate a response.

        Appends the user message, retrieves relevant chunks, asks the LLM,
        and stores the assistant reply (with references) in session state.
        """
        st.session_state.messages.append({"role": "user", "content": prompt})

        with st.spinner("Analyzing documents and generating a response..."):
            try:
                # Retrieve relevant document chunks, restricted to the types
                # of the currently analyzed documents. Use .get() so a doc
                # with missing metadata cannot raise KeyError here.
                context_chunks = self.vector_store.similarity_search(
                    query=prompt,
                    k=5,
                    filter_criteria={"metadata.type": [
                        doc.get("metadata", {}).get("type")
                        for doc in st.session_state.analyzed_documents
                    ]}
                )

                # Generate the response
                response, references = self.generate_response(prompt, context_chunks)

                # Add assistant response
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": response,
                    "references": references
                })

                # Update context for future queries
                st.session_state.context_chunks = context_chunks
            except Exception as e:
                st.error(f"Error generating response: {str(e)}")

    def generate_response(self, prompt: str, context_chunks: List[Dict]) -> tuple[str, str]:
        """Generate a response using the LLM and LegalPromptGenerator.

        Returns:
            (response_text, references_html); on failure, a generic error
            string and empty references (error is surfaced via st.error).
        """
        try:
            # Build the system prompt and the conversational messages.
            system_message, messages = self._generate_messages(prompt, context_chunks)

            # Call the LLM. The Anthropic Messages API takes the system
            # prompt as a top-level parameter, NOT as a "system" role inside
            # `messages` (a system role there is rejected by the API).
            response = self.client.messages.create(
                model="claude-3-5-sonnet-latest",  # "claude-3" alone is not a valid model id
                max_tokens=2000,
                temperature=0.7,
                system=system_message,
                messages=messages
            )

            # Format references
            references_html = self._format_references(context_chunks)
            return response.content[0].text, references_html

        except Exception as e:
            st.error(f"Error generating response: {str(e)}")
            return "An error occurred while processing your query.", ""

    def _generate_messages(self, prompt: str, context_chunks: List[Dict]) -> tuple[str, List[Dict]]:
        """Build the system prompt and user messages for the LLM call.

        Returns:
            (system_message, messages) where ``messages`` contains only the
            user turn — the system prompt must be passed separately to the
            Anthropic API (see generate_response).
        """
        # Get case metadata if available
        case_metadata = None
        if st.session_state.current_case:
            case_metadata = self.case_manager.get_case(st.session_state.current_case)

        # Generate system message
        system_message = self.prompt_generator.generate_system_message(
            context_chunks=context_chunks,
            query=prompt,
            case_metadata=case_metadata
        )

        # Generate user message
        context = "\n".join([
            f"Document: {chunk['metadata'].get('title', 'Untitled')}\n"
            f"Section: {chunk['text']}\n"
            for chunk in context_chunks
        ])
        user_message = self.prompt_generator.generate_user_message(prompt, context)

        # Handle follow-up questions. The caller appends the CURRENT prompt
        # to session history before reaching this point, so exclude it —
        # otherwise every query is treated as a follow-up of itself.
        history = st.session_state.messages
        if history and history[-1].get("role") == "user" and history[-1].get("content") == prompt:
            history = history[:-1]
        if history:
            previous_query = next(
                (m["content"] for m in reversed(history) if m["role"] == "user"),
                None
            )
            previous_response = next(
                (m["content"] for m in reversed(history) if m["role"] == "assistant"),
                None
            )
            if previous_query and previous_response:
                user_message = self.prompt_generator.generate_follow_up_prompt(
                    original_query=previous_query,
                    follow_up_query=prompt,
                    previous_response=previous_response,
                    context_chunks=context_chunks
                )

        return system_message, [
            {"role": "user", "content": user_message}
        ]

    def _format_references(self, chunks: List[Dict]) -> str:
        """Format references as HTML for display.

        Each chunk becomes a numbered ``document-chunk`` div showing the
        title (or 'Untitled') and the first 200 characters of its text.
        """
        references = []
        for i, chunk in enumerate(chunks, 1):
            references.append(f"""
            <div class="document-chunk">
                <strong>Reference {i}:</strong> {chunk['metadata'].get('title', 'Untitled')}
                <br/>
                <em>Section:</em> {chunk['text'][:200]}...
            </div>
            """)
        return "\n".join(references)

    def add_analyzed_document(self, doc: Dict):
        """Add a document to session state with metadata tracking.

        De-duplicates by document name: stamping ``added_at`` before an
        identity check (as the old code did) meant a re-added document never
        compared equal to its stored copy, so duplicates accumulated.
        """
        existing_names = {d.get('name') for d in st.session_state.analyzed_documents}
        if doc.get('name') not in existing_names:
            doc.setdefault('metadata', {})['added_at'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            st.session_state.analyzed_documents.append(doc)