File size: 3,363 Bytes
332f2d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import streamlit as st
from typing import List
from langchain_community.chat_models import ChatOpenAI
from langchain_core.messages import SystemMessage, HumanMessage
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from utils.database import get_collection_documents, get_all_documents, get_embeddings_model, initialize_qa_system

def generate_document_tags(content: str) -> List[str]:
    """Generate descriptive tags for a document using an LLM.

    Args:
        content: Raw document text. Only the first 2,000 characters are
            sent to the model to bound prompt size.

    Returns:
        A list of stripped, non-empty tag strings. Returns an empty list
        (after surfacing the error via st.error) if the model call fails.
    """
    try:
        llm = ChatOpenAI(temperature=0.2, model="gpt-3.5-turbo")

        prompt = """Analyze the following document content and generate relevant tags/keywords. 
        Focus on key themes, topics, and important terminology.
        Return only the tags as a comma-separated list.
        Content: {content}"""

        response = llm.invoke([
            SystemMessage(content="You are a document analysis assistant. Generate relevant tags as a comma-separated list only."),
            HumanMessage(content=prompt.format(content=content[:2000]))
        ])

        # The AI message body is expected to be a comma-separated tag string.
        tags_text = response.content
        # Strip whitespace and drop blank entries: a trailing comma or an
        # empty model response would otherwise produce "" tags in the result.
        return [tag.strip() for tag in tags_text.split(',') if tag.strip()]
    except Exception as e:
        st.error(f"Error generating tags: {e}")
        return []

def initialize_chat_system(collection_id=None) -> bool:
    """Build the vector store and QA system from stored documents.

    Loads documents (all of them, or only those in *collection_id*),
    splits each into overlapping chunks, embeds them into a FAISS index,
    and stores the resulting QA system in Streamlit session state.

    Args:
        collection_id: Optional collection to restrict indexing to; when
            None, all documents in the database are used.

    Returns:
        True on success; False if no documents/text were found or any
        step failed (errors are surfaced via st.error).
    """
    try:
        # Get documents based on collection or all documents.
        documents = (get_collection_documents(st.session_state.db_conn, collection_id)
                     if collection_id else get_all_documents(st.session_state.db_conn))

        if not documents:
            st.error("No documents found.")
            return False

        with st.spinner("Processing documents..."):
            # Initialize embeddings and the chunking strategy.
            embeddings = get_embeddings_model()
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=500,
                chunk_overlap=50,
                length_function=len,
            )

            # Split every document, tagging each chunk with provenance
            # metadata so answers can cite their source document.
            all_chunks = []
            for doc in documents:
                for chunk in text_splitter.split_text(doc['content']):
                    all_chunks.append({
                        'content': chunk,
                        'metadata': {
                            'source': doc['name'],
                            'document_id': doc['id'],
                            'collection_id': collection_id
                        }
                    })

            # Guard: documents with empty content yield no chunks, and
            # FAISS.from_texts cannot build an index from zero texts --
            # fail with a clear message instead of an opaque exception.
            if not all_chunks:
                st.error("Documents contain no indexable text.")
                return False

            # Create the vector store; pass metadatas by keyword so the
            # call stays correct if the positional signature changes.
            vector_store = FAISS.from_texts(
                [chunk['content'] for chunk in all_chunks],
                embeddings,
                metadatas=[chunk['metadata'] for chunk in all_chunks]
            )

            # Initialize the QA system and mark the chat as ready.
            st.session_state.vector_store = vector_store
            st.session_state.qa_system = initialize_qa_system(vector_store)
            st.session_state.chat_ready = True
            return True

    except Exception as e:
        st.error(f"Error initializing chat system: {e}")
        return False