# Hugging Face Space — Gradio RAG app over 10-K filings.
# (Original scrape header: "Spaces: / Runtime error" — page status text, not code.)
| import os | |
| import openai | |
| import pandas as pd | |
| import gradio as gr | |
| import uuid | |
| import json | |
| from pathlib import Path | |
| from huggingface_hub import CommitScheduler, HfApi | |
| from openai import OpenAI | |
| from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings | |
| from langchain_community.vectorstores import Chroma | |
| #------------------------------------------------------------------------------------- | |
def get_answer(question, quotes, temperature, document):
    """Answer *question* from *document* using RAG over the Chroma store.

    Generator consumed by Gradio: each ``yield`` is a 3-tuple of
    (status text for the button, answer textbox, question textbox).

    Parameters
    ----------
    question : str
        The user's question.
    quotes : int
        Number of similar chunks to retrieve from the vector store.
    temperature : float
        Sampling temperature for the answer-generation call.
    document : str
        Basename of the source document to restrict retrieval to.
    """
    yield "Running... Analyzing Question", "", question

    # Load the prompt templates for the question-validity gate.
    with open('./templates/question_analysis.txt', 'r', encoding='utf-8') as file:
        question_analysis = file.read()
    with open('./templates/question_analysis_template.txt', 'r', encoding='utf-8') as file:
        question_analysis_template = file.read()

    q_analysis = [
        {"role": "system", "content": question_analysis},
        {"role": "user", "content": question_analysis_template.format(
            question=question,
        )
        }
    ]
    try:
        # Validity check runs at temperature 0 for a deterministic verdict.
        response = client.chat.completions.create(
            model=model_name,
            messages=q_analysis,
            max_tokens=2000,
            temperature=0.0
        )
        # strip() guards against stray whitespace/newlines in the model reply,
        # which would otherwise make a valid question fail the exact match.
        if response.choices[0].message.content.strip() == "Valid Question.":
            yield "Running... Question Analysis Done", "", question
        else:
            yield "Stopped: Question Analysis Done", "The question is not valid, stopping the process", ""
            return
    except openai.OpenAIError as e:
        print(f"An error occurred: {str(e)}")
        return

    # Load the Q&A prompt templates.
    with open('./templates/qna.txt', 'r', encoding='utf-8') as file:
        qna = file.read()
    with open('./templates/qna_template.txt', 'r', encoding='utf-8') as file:
        qna_template = file.read()

    # NOTE(review): hard-coded Colab-style path prefix — verify it matches
    # the 'source' metadata stored in the vector DB on this deployment.
    filename = "/content/dataset/" + document
    quotes = vector_db.similarity_search(question, k=quotes, filter={"source": filename})

    # Assemble the retrieved chunks (with page numbers) into the prompt context.
    context_for_query = ""
    for i, d in enumerate(quotes, start=1):
        context_for_query += f"Quote {i}:\n"
        context_for_query += d.page_content + "\n"
        context_for_query += f"(Page = {d.metadata.get('page', 'Unknown')})\n\n"

    answer_to_analyze = [
        {"role": "system", "content": qna},
        {"role": "user", "content": qna_template.format(
            context=context_for_query,
            question=question
        )
        }
    ]
    yield "Running... Getting best answer from AI", "", question
    try:
        answer_analyzed = client.chat.completions.create(
            model=model_name,
            messages=answer_to_analyze,
            max_tokens=2000,
            temperature=temperature
        )
        yield "Stopped... Process Finished", answer_analyzed.choices[0].message.content, ""
    except openai.OpenAIError as e:
        print(f"An error occurred: {str(e)}")
        return

    # Log the interaction to a HF dataset repo.
    # NOTE(review): a new CommitScheduler is constructed on every call; it
    # should ideally be created once at module scope — confirm before hoisting.
    log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
    log_folder = log_file.parent
    scheduler = CommitScheduler(
        repo_id="GL-Project3_Logs",
        repo_type="dataset",
        folder_path=log_folder,
        path_in_repo="data",
        every=2,
        token=hf_token
    )
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'user_input': question,
                    'retrieved_context': context_for_query,
                    'model_response': answer_analyzed.choices[0].message.content
                }
            ))
            f.write("\n")
| #------------------------------------------------------------------------------------- | |
# --- Module-level setup: credentials, model client, and vector store. ---
hf_token = os.getenv("HF_TOKEN")
openai_api = os.getenv("OPENAI_API_KEY")
client = OpenAI(
    api_key=openai_api
)
model_name = 'gpt-3.5-turbo'

# Embedding model must match the one used when the Chroma DB was built.
embedding_model = SentenceTransformerEmbeddings(model_name="thenlper/gte-large")
vectordb_location = './companies-10K-2023_db1'
collection_name = 'companies-10K-2023'
vector_db = Chroma(
    collection_name=collection_name,
    embedding_function=embedding_model,
    persist_directory=vectordb_location
)

# Collect the distinct document basenames stored in the DB metadata to
# populate the UI dropdown.  (Removed an unused `sources = set()` local.)
stored_documents = vector_db.get(include=["metadatas"])
document_names = set()
for metadata in stored_documents['metadatas']:
    source = metadata.get('source', 'No source found')
    document_names.add(os.path.basename(source))
document_list = list(document_names)
| #------------------------------------------------------------------------------------- | |
# --- Gradio UI: document picker, question box, retrieval/temperature knobs. ---
with gr.Blocks() as demo:
    gr.Markdown("GL - Project 3: RAG")

    with gr.Row():
        # Left column: which document to query and the question itself.
        with gr.Column(scale=1):
            document_dropdown = gr.Dropdown(choices=document_list, label="Document")
            question_input = gr.Textbox(
                label="Enter your question",
                placeholder="Type your question here...",
            )
        # Right column: retrieval depth and generation temperature.
        with gr.Column(scale=1):
            quotes_to_fetch = gr.Slider(
                minimum=1, maximum=10, step=1,
                label="How many quotes you want from the source",
            )
            temperature_slider = gr.Slider(
                minimum=0, maximum=1, step=0.1,
                label="Temperature",
                info="Controls randomness: 0 = deterministic, 1 = creative/unexpected answers. If you can't get an answer try increasing the temperature.",
            )

    with gr.Row():
        fetch_answer = gr.Button("Analyze and Answer")
    with gr.Row():
        answer_output = gr.Textbox(
            label="Answer",
            placeholder="Your answer will be displayed here...",
        )

    # get_answer is a generator: it streams status into the button label,
    # the answer into the textbox, and echoes/clears the question box.
    fetch_answer.click(
        fn=get_answer,
        inputs=[question_input, quotes_to_fetch, temperature_slider, document_dropdown],
        outputs=[fetch_answer, answer_output, question_input],
    )

demo.launch(share=True, show_error=True, debug=True)