File size: 6,962 Bytes
2e7ce4e
 
 
 
 
84f14d7
3110ca7
0f05a4c
 
2e7ce4e
 
0f05a4c
 
2b37dd2
0f05a4c
 
 
 
2e7ce4e
84f14d7
 
 
 
 
 
 
 
 
0f05a4c
bc95eaa
2e7ce4e
 
 
 
 
 
84f14d7
2e7ce4e
 
 
 
 
 
 
 
 
 
84f14d7
2e7ce4e
 
0f05a4c
2e7ce4e
bc95eaa
0f05a4c
5a0e252
2b37dd2
0f05a4c
 
2b37dd2
0f05a4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc95eaa
 
5a0e252
 
2e7ce4e
 
0f05a4c
 
2e7ce4e
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
#!/usr/bin/env python3
"""
Simple HTTP Server - Hello World with POST data
"""

import functools
import logging
import sys
import threading

import sentence_transformers
from flask import Flask, request, jsonify
from flask_cors import CORS
from groq import Groq
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings
)
from langchain_community.vectorstores import Chroma

# Emit INFO-and-above log records to stdout using a timestamped line format.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=[_stdout_handler])
logger = logging.getLogger(__name__)


# Flask application object; CORS(app) is applied with its default settings.
app = Flask(import_name=__name__)
CORS(app)


@app.route('/', methods=['GET'])
def index():
    """Health-check endpoint.

    Returns:
        A JSON response with a fixed ``status`` of ``'running'`` and a
        greeting ``message``.
    """
    logger.info("test1")
    payload = {
        'status': 'running',
        'message': 'Hello World API Server',
    }
    return jsonify(payload)


@app.route('/api/v1/transcript/process', methods=['POST'])
def process():
    """Answer the compliance question carried in the JSON request body.

    Expects a JSON object with a non-empty string ``question`` field.

    Returns:
        200 with ``{'message': 'Hello World', 'received_data': <answer>}`` on
        success, or 400 with ``{'error': ...}`` when the body is not a JSON
        object carrying a non-empty ``question`` string.
    """
    # silent=True yields None (instead of raising) for a missing/invalid JSON
    # body, so every malformed request gets the same 400 response below.
    data = request.get_json(silent=True)
    logger.info("test2")

    question = data.get('question') if isinstance(data, dict) else None
    if not isinstance(question, str) or not question.strip():
        # Previously this reached callLlm and crashed with TypeError/KeyError (HTTP 500).
        return jsonify({
            'error': "Request body must be a JSON object with a non-empty 'question' string"
        }), 400

    return jsonify({
        'message': 'Hello World',
        'received_data': callLlm(data)
    })

# Compliance-RAG configuration (values carried over from the original notebook cells).
_PDF_PATH = "ComplianceFile.pdf"
_PERSIST_DIRECTORY = './compliance_db'
_COLLECTION_NAME = 'compliance_collection'
_EMBEDDING_MODEL_NAME = 'thenlper/gte-large'
_LLM_MODEL_NAME = 'openai/gpt-oss-20b'
_TOP_K = 5

# app.run() serves requests on multiple threads by default; this lock keeps
# concurrent first requests from building/indexing the vector store twice.
_VECTORSTORE_LOCK = threading.Lock()


def _index_pdf(vectorstore):
    """Split the compliance PDF into 512-token chunks and add them to *vectorstore*."""
    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        encoding_name='cl100k_base',
        chunk_size=512,
        chunk_overlap=16,
    )
    chunks = PyPDFLoader(_PDF_PATH).load_and_split(text_splitter)
    logger.info("Indexing %d chunks from %s", len(chunks), _PDF_PATH)
    if chunks:
        vectorstore.add_documents(chunks)
        vectorstore.persist()


@functools.lru_cache(maxsize=1)
def _build_vectorstore():
    """Open the persisted Chroma collection, indexing the PDF only if it is empty.

    The embedding model is loaded once here and reused for every query.
    NOTE: the index is not refreshed when the PDF changes, and a DB written by
    older versions of this server may already hold duplicated chunks; delete
    ``./compliance_db`` to rebuild it from scratch in either case.
    """
    vectorstore = Chroma(
        collection_name=_COLLECTION_NAME,
        persist_directory=_PERSIST_DIRECTORY,
        embedding_function=SentenceTransformerEmbeddings(model_name=_EMBEDDING_MODEL_NAME),
    )
    # Re-adding the PDF on every request stored a new copy of every chunk each
    # time, so top-k retrieval filled up with duplicates. Only ingest once.
    if not vectorstore.get(limit=1)['ids']:
        _index_pdf(vectorstore)
    return vectorstore


def _get_vectorstore():
    """Return the process-wide vector store, building it on first use."""
    with _VECTORSTORE_LOCK:
        return _build_vectorstore()


@functools.lru_cache(maxsize=1)
def _get_groq_client():
    """Return a shared Groq client.

    ``Groq()`` reads its API key from the ``GROQ_API_KEY`` environment
    variable; never hard-code the key in source.
    """
    return Groq()


def callLlm(data):
    """Answer ``data["question"]`` from the compliance PDF with a RAG prompt.

    Retrieves the ``_TOP_K`` most similar chunks from the persisted Chroma
    index, sends them with the question to the Groq chat model, and returns
    the model's answer.

    Args:
        data: Mapping containing the user's question under ``"question"``.

    Returns:
        The stripped model answer, or an apology string embedding the error
        text if the Groq call fails.

    Raises:
        KeyError: if ``data`` has no ``"question"`` key. Errors while loading
        or indexing the vector store also propagate.
    """
    question = data["question"]

    relevant_chunks = _get_vectorstore().similarity_search(question, k=_TOP_K)
    for i, doc in enumerate(relevant_chunks, start=1):
        logger.info("Retrieved chunk %d:\n%s", i, doc.page_content.replace('\t', ' '))

    qna_system_message = """
    You are an assistant to a firm who checks if the user input is compliant based on the doc provided.
    User input will need to be compared with the compliant document provided in the context and find the relevant response.
    This context will begin with the token: ###Context.
    The context contains references to specific portions of a document relevant to the user query.

    User questions will begin with the token: ###Question.

    Please answer user questions only using the context provided in the input.
    Do not mention anything about the context in your final answer. Your response should only contain the answer to the question.

    If the answer is not found in the context, respond "I don't know".
    """

    qna_user_message_template = """
    ###Context
    Here are some documents that are relevant to the question mentioned below.
    {context}

    ###Question
    {question}
    """

    context_for_query = ". ".join(doc.page_content for doc in relevant_chunks)
    prompt = [
        {'role': 'system', 'content': qna_system_message},
        {'role': 'user', 'content': qna_user_message_template.format(
            context=context_for_query,
            question=question,
        )},
    ]
    logger.debug("Prompt: %s", prompt)

    try:
        response = _get_groq_client().chat.completions.create(
            model=_LLM_MODEL_NAME,
            messages=prompt,
            temperature=0,
        )
        prediction = response.choices[0].message.content.strip()
    except Exception as e:  # request boundary: report the failure as text, but log it
        logger.exception("Groq chat completion failed")
        prediction = f'Sorry, I encountered the following error: \n {e}'
    logger.info(prediction)
    return prediction


if __name__ == '__main__':
    import os

    # Hugging Face uses port 7860; a PORT environment variable overrides it.
    port = int(os.environ.get("PORT", 7860))
    bind_host = '0.0.0.0'

    logger.info(f"Starting server on port {port}")
    logger.info(f"POST endpoint: http://{bind_host}:{port}/api/v1/transcript/process")

    # Debug mode stays off so the interactive debugger is never exposed.
    app.run(host=bind_host, port=port, debug=False)