Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Compliance Q&A HTTP server - answers POSTed questions using chunks retrieved from a PDF and a Groq-hosted LLM | |
| """ | |
| import logging | |
| import sys | |
| import sentence_transformers | |
| from flask import Flask, request, jsonify | |
| from flask_cors import CORS | |
| from groq import Groq | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from langchain_community.embeddings.sentence_transformer import ( | |
| SentenceTransformerEmbeddings | |
| ) | |
| from langchain_community.vectorstores import Chroma | |
# Root logging: INFO and above, written to stdout through a single handler.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    handlers=[_stdout_handler],
)

# Module-level logger used by every view and helper in this file.
logger = logging.getLogger(__name__)

# The Flask application object, wrapped by flask_cors with its default settings.
app = Flask(__name__)
CORS(app)
# NOTE(review): no route registration for this view is visible in the file, so
# Flask would answer 404 for it. '/' is assumed as the health-check path --
# confirm against the deployment's expectations.
@app.route('/')
def index():
    """Health check.

    Returns:
        A JSON response with a fixed ``status`` / ``message`` payload showing
        that the server process is up.
    """
    logger.info("test1")
    return jsonify({
        'status': 'running',
        'message': 'Hello World API Server',
    })
# The path matches the POST endpoint announced in the startup log message.
@app.route('/api/v1/transcript/process', methods=['POST'])
def process():
    """Answer the question carried in the POSTed JSON body.

    The body must be a JSON object with a non-empty string under the
    ``"question"`` key; the whole object is forwarded to ``callLlm``.

    Returns:
        On success, a JSON response with ``message`` and ``received_data``
        (the string returned by ``callLlm``). On a missing, malformed or
        incomplete body, a JSON ``error`` payload with HTTP status 400.
    """
    # silent=True yields None for an absent or unparsable JSON body instead of
    # raising, so every bad-input case is handled by the single check below.
    data = request.get_json(silent=True)
    question = data.get("question") if isinstance(data, dict) else None
    if not isinstance(question, str) or not question.strip():
        return jsonify({
            'error': 'Request body must be a JSON object with a non-empty '
                     '"question" string',
        }), 400

    logger.info("test2")
    return jsonify({
        'message': 'Hello World',
        'received_data': callLlm(data),
    })
# Retrieval / generation settings used by the helpers below.
_PDF_PATH = "ComplianceFile.pdf"
_COLLECTION_NAME = 'compliance_collection'
_PERSIST_DIRECTORY = './compliance_db'
_EMBEDDING_MODEL_NAME = 'thenlper/gte-large'
_LLM_MODEL_NAME = 'openai/gpt-oss-20b'
_TOP_K = 5  # number of chunks retrieved per question

_QNA_SYSTEM_MESSAGE = """
You are an assistant to a firm who checks if the user input is compliant based on the doc provided.
User input will need to be compared with the compliant document provided in the context and find the relevant response.
This context will begin with the token: ###Context.
The context contains references to specific portions of a document relevant to the user query.
User questions will begin with the token: ###Question.
Please answer user questions only using the context provided in the input.
Do not mention anything about the context in your final answer. Your response should only contain the answer to the question.
If the answer is not found in the context, respond "I don't know".
"""

_QNA_USER_MESSAGE_TEMPLATE = """
###Context
Here are some documents that are relevant to the question mentioned below.
{context}
###Question
{question}
"""

# Vector store shared by all requests in this process; built on first use.
_vectorstore = None


def _get_vectorstore():
    """Return the shared Chroma store, creating or indexing it on first use.

    The persisted collection is opened first. Only when it holds no documents
    is the PDF loaded, split and embedded into it. Re-indexing on every call
    (as was done previously) appended a fresh copy of every chunk to the
    persisted collection each time, so searches returned duplicates and each
    request paid the full embedding cost.

    Delete the persist directory to force a re-index after the PDF changes.

    NOTE(review): initialisation is not lock-protected; two concurrent first
    requests could both index the PDF. Pre-warm at startup if that matters.
    """
    global _vectorstore
    if _vectorstore is not None:
        return _vectorstore

    embedding_model = SentenceTransformerEmbeddings(
        model_name=_EMBEDDING_MODEL_NAME
    )
    store = Chroma(
        collection_name=_COLLECTION_NAME,
        persist_directory=_PERSIST_DIRECTORY,
        embedding_function=embedding_model,
    )

    if not store.get(limit=1)['ids']:
        # Empty collection: index the PDF once.
        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            encoding_name='cl100k_base',
            chunk_size=512,
            chunk_overlap=16,
        )
        chunks = PyPDFLoader(_PDF_PATH).load_and_split(text_splitter)
        store = Chroma.from_documents(
            chunks,
            embedding_model,
            collection_name=_COLLECTION_NAME,
            persist_directory=_PERSIST_DIRECTORY,
        )
        store.persist()
        logger.info("Indexed %d chunks from %s", len(chunks), _PDF_PATH)

    _vectorstore = store
    return store


def _build_messages(question, chunks):
    """Build the system + user chat messages from the question and its context chunks."""
    context_for_query = ". ".join(chunk.page_content for chunk in chunks)
    return [
        {'role': 'system', 'content': _QNA_SYSTEM_MESSAGE},
        {
            'role': 'user',
            'content': _QNA_USER_MESSAGE_TEMPLATE.format(
                context=context_for_query,
                question=question,
            ),
        },
    ]


def callLlm(data):
    """Answer ``data["question"]`` using chunks retrieved from the compliance PDF.

    The top ``_TOP_K`` most similar chunks are fetched from the vector store,
    placed into the prompt as context, and sent to the Groq chat-completions
    API.

    The Groq client is created without arguments, so the API key must be
    supplied through the process environment (``GROQ_API_KEY``, per the Groq
    SDK). The key that used to be hard-coded here has been removed and should
    be revoked, since it was committed in plain text.

    Args:
        data: Mapping holding the user's question under the ``"question"`` key.

    Returns:
        The model's answer with surrounding whitespace stripped, or a
        ``'Sorry, I encountered the following error: ...'`` string if the LLM
        call fails.

    Raises:
        KeyError: ``data`` has no ``"question"`` key.
        TypeError: ``data`` is not subscriptable (e.g. ``None``).
    """
    user_input = data["question"]

    relevant_chunks = _get_vectorstore().similarity_search(user_input, k=_TOP_K)
    logger.info("Retrieved %d chunks for the question", len(relevant_chunks))
    for position, chunk in enumerate(relevant_chunks, start=1):
        logger.debug(
            "Retrieved chunk %d: %s",
            position,
            chunk.page_content.replace('\t', ' '),
        )

    prompt = _build_messages(user_input, relevant_chunks)
    logger.debug("Prompt: %s", prompt)

    # Deliberate best-effort boundary: any failure while talking to the LLM
    # (including a missing API key) is logged and reported back as text.
    try:
        client = Groq()
        response = client.chat.completions.create(
            model=_LLM_MODEL_NAME,
            messages=prompt,
            temperature=0,
        )
        prediction = response.choices[0].message.content.strip()
    except Exception as e:
        logger.exception("LLM request failed")
        prediction = f'Sorry, I encountered the following error: \n {e}'

    logger.info(prediction)
    return prediction
if __name__ == '__main__':
    import os

    # Listen on the port named by the PORT environment variable; 7860 is the
    # fallback (the port Hugging Face uses).
    port = int(os.getenv("PORT", 7860))

    logger.info(f"Starting server on port {port}")
    logger.info(f"POST endpoint: http://0.0.0.0:{port}/api/v1/transcript/process")

    # Bind on all interfaces; debug stays off for production.
    app.run(host='0.0.0.0', port=port, debug=False)