# TestGradio / main.py
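"""SaphiraGPT demo: a Gradio chat app that answers questions with RAG over
uploaded requirement files (.txt, .reqif, .reqifz) via llama-index, can
optionally pull project data from S3, and falls back to the OpenAI chat
completions API when no documents are loaded."""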
import json
import os
import shutil
import zipfile
from typing import Any, List, Optional

import boto3
import gradio as gr
import httpx
from bs4 import BeautifulSoup
from dotenv import load_dotenv
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from loguru import logger
from pydantic import BaseModel, Field
from reqif.parser import ReqIFParser

load_dotenv()
API_KEY = os.getenv("OPENAI_API_KEY")


class Message(BaseModel):
    """A single chat message (defined but not used below)."""
    role: str
    content: str


class FileContext(BaseModel):
    """Tracks whether an upload has been processed, plus the resulting
    llama-index documents and vector index (excluded from serialization)."""
    processed: bool = False
    indexed_documents: Any = Field(default=None, exclude=True)
    index: Any = Field(default=None, exclude=True)


def extract_text_data(reqif_data):
    """Flatten a parsed ReqIF bundle into a list of plain-text attribute values."""
    text_data = []
    for spec_object in reqif_data.core_content.req_if_content.spec_objects:
        for attribute in spec_object.attributes:
            # Check if the attribute value contains XHTML content
            if '<xhtml:div>' in str(attribute.value):
                # Strip the XHTML markup and keep only the text content
                soup = BeautifulSoup(str(attribute.value), 'html.parser')
                text_data.append(soup.get_text())
            else:
                # No XHTML content: append the value directly
                text_data.append(str(attribute.value))
    return text_data
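
# Example usage ("example.reqif" is a hypothetical path):
#   bundle = ReqIFParser().parse("example.reqif")
#   requirement_lines = extract_text_data(bundle)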


def reqif_file_processing(root, filename, user_id):
    """Parse a .reqif file and write its extracted text to data/<filename>.txt.

    Note: this helper is not wired into the Gradio app below
    (process_and_index_file covers the same work for uploads),
    and user_id is currently unused.
    """
    reqif_file_path = os.path.join(root, filename)
    reqif_data = ReqIFParser().parse(reqif_file_path)
    text_data = extract_text_data(reqif_data)
    # Join as plain text; json.dumps here would wrap the text in quotes
    # and escape the newlines
    extracted_text = '\n'.join(text_data)
    # Make sure the data folder exists
    data_folder = 'data'
    os.makedirs(data_folder, exist_ok=True)
    # Write the extracted text to a file in the data folder.
    # Eventually these will be written to S3 and fetched per individual user.
    file_path = os.path.join(data_folder, f'{filename}.txt')
    with open(file_path, 'w') as file:
        file.write(extracted_text)


def make_completion(history: List[dict], nb_retries: int = 3, delay: int = 30) -> Optional[str]:
    """
    Send the conversation history to the OpenAI chat completions API and
    return the assistant's reply, retrying on transient failures.
    """
    if not history:
        logger.error("History is empty, cannot make LLM completion.")
        return "No prior conversation to base the response on."
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}"
    }
    # Convert history to the wire format expected by the API
    formatted_messages = [{"role": msg["role"], "content": msg["content"]} for msg in history]
    with httpx.Client(headers=headers) as client:
        for attempt in range(1, nb_retries + 1):
            logger.debug(f"Attempt {attempt} for chat/completions")
            try:
                resp = client.post(
                    url="https://api.openai.com/v1/chat/completions",
                    json={
                        "model": "gpt-4",
                        "top_p": 1.0,
                        "n": 1,
                        "stream": False,
                        "messages": formatted_messages
                    },
                    timeout=delay
                )
                if resp.status_code == 200:
                    content = resp.json()["choices"][0]["message"]["content"]
                    logger.debug(f"LLM Response: {content}")
                    return content
                # Non-200 responses (bad key, rate limit, ...) are not retried
                logger.warning(f"API Error: {resp.text}")
                break
            except Exception as e:
                logger.error(f"Exception during API call: {e}")
    return "Failed to get a response from the LLM."


def process_and_index_file(uploaded_file, file_processed_state: FileContext):
    """Copy an uploaded .txt/.reqif/.reqifz file into the data folder,
    extract plain text where needed, and (re)build the vector index."""
    if uploaded_file is None:
        return "No file uploaded", None
    filename = os.path.basename(uploaded_file.name)
    data_folder = 'data'
    os.makedirs(data_folder, exist_ok=True)
    if not file_processed_state.processed:
        os.makedirs('temp', exist_ok=True)
        temp_file_path = os.path.join('temp', filename)
        shutil.copyfile(uploaded_file.name, temp_file_path)

        def index_data_folder():
            # Rebuild the index over everything in the data folder and
            # record the result on the shared FileContext
            documents = SimpleDirectoryReader(data_folder).load_data()
            index = VectorStoreIndex.from_documents(documents)
            file_processed_state.indexed_documents = documents
            file_processed_state.index = index
            file_processed_state.processed = True

        if filename.endswith('.txt'):
            shutil.move(temp_file_path, os.path.join(data_folder, filename))
            index_data_folder()
        elif filename.endswith('.reqif'):
            reqif_data = ReqIFParser().parse(temp_file_path)
            extracted_text = '\n'.join(extract_text_data(reqif_data))
            with open(os.path.join(data_folder, f"{filename}.txt"), 'w') as file:
                file.write(extracted_text)
            shutil.move(temp_file_path, os.path.join(data_folder, filename))
            index_data_folder()
        elif filename.endswith('.reqifz'):
            # A .reqifz archive is a zip; extract the first .reqif inside it
            with zipfile.ZipFile(temp_file_path, 'r') as zip_ref:
                for member in zip_ref.namelist():
                    if member.endswith('.reqif'):
                        zip_ref.extract(member, data_folder)
                        temp_file_path = os.path.join(data_folder, member)
                        # Use the basename so the extracted text lands at the
                        # top of the data folder even if the member is nested
                        filename = os.path.basename(member)
                        break
            if filename.endswith('.reqif'):
                reqif_data = ReqIFParser().parse(temp_file_path)
                extracted_text = '\n'.join(extract_text_data(reqif_data))
                with open(os.path.join(data_folder, f"{filename}.txt"), 'w') as file:
                    file.write(extracted_text)
                index_data_folder()
    # Cleanup: keep only .txt files in the data folder; remove everything
    # else, including directories extracted from archives
    for file in os.listdir(data_folder):
        file_path = os.path.join(data_folder, file)
        if not file.endswith('.txt'):
            if os.path.isdir(file_path):
                shutil.rmtree(file_path)
            else:
                os.remove(file_path)
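
# Supported uploads: .txt files are indexed directly; .reqif files are parsed
# into text first; .reqifz archives have their first .reqif extracted, parsed,
# and indexed.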


def predict(user_input, history):
    """Answer with RAG over the data folder when documents exist,
    otherwise fall back to a plain chat completion."""
    data_folder = 'data'
    response = None
    if os.path.isdir(data_folder) and os.listdir(data_folder):
        documents = SimpleDirectoryReader(data_folder).load_data()
        index = VectorStoreIndex.from_documents(documents)
        query_engine = index.as_query_engine()
        llm_response = query_engine.query(user_input)
        response = str(llm_response)
        print("RAG Response:", response)
    history.append({"role": "user", "content": user_input})
    if response:
        history.append({"role": "assistant", "content": response})
    else:
        response = make_completion(history)
        print("LLM Response:", response)
        history.append({"role": "assistant", "content": response})
    # Pair up alternating (user, assistant) turns for gr.Chatbot
    messages = [(history[i]["content"], history[i + 1]["content"]) for i in range(0, len(history) - 1, 2)]
    return messages, history
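
# Example: after two exchanges, history holds four messages and the chatbot
# receives [("question 1", "answer 1"), ("question 2", "answer 2")].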


def s3_upload(s3_checkbox, file_processed_state):
    """Download the project's JSON from S3 into the data folder so the
    next query can index it."""
    if not s3_checkbox:
        return  # Nothing to do if the checkbox is not checked
    # This needs to be pulled dynamically somehow and integrated with this
    # service; perhaps via the main Saphira input. To be experimented with.
    project = "SubmarineSpec"
    object_key = f'{project}.json'
    # AWS credentials are expected to already be in the environment
    aws_access_key = os.getenv('AWS_ACCESS_KEY_ID')
    aws_secret = os.getenv('AWS_SECRET_ACCESS_KEY')
    aws_region = os.getenv('AWS_DEFAULT_REGION')
    # Create an S3 client
    s3 = boto3.client('s3', aws_access_key_id=aws_access_key, aws_secret_access_key=aws_secret, region_name=aws_region)
    bucket_name = 'saphira-userprojects'
    try:
        # Download and decode the object from S3
        response = s3.get_object(Bucket=bucket_name, Key=object_key)
        data = json.loads(response['Body'].read().decode('utf-8'))
        # Make sure the data folder exists
        data_folder = 'data'
        os.makedirs(data_folder, exist_ok=True)
        # Re-serialize the JSON to a .txt file the indexer can read
        with open(os.path.join(data_folder, f"{project}.txt"), 'w') as file:
            file.write(json.dumps(data))
        file_processed_state.processed = True
    except Exception as e:
        print(f"Error downloading from S3: {e}")
        file_processed_state.processed = False
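
# The S3 object is assumed to be UTF-8 JSON; it is round-tripped through
# json.loads/json.dumps and saved as data/<project>.txt for indexing.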


# Gradio interface with file input
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="SaphiraGPT")
    history_state = gr.State([])
    file_processed_state = gr.State(FileContext())
    s3_checkbox = gr.Checkbox(label="Load your project data into SaphiraGPT")
    with gr.Row():
        txt = gr.Textbox(lines=1, show_label=False, placeholder="Enter text and press enter")
    with gr.Row():
        file_input = gr.File(label="Select file for SaphiraGPT context")
    s3_checkbox.change(fn=s3_upload, inputs=[s3_checkbox, file_processed_state], outputs=None)
    file_input.change(fn=process_and_index_file, inputs=[file_input, file_processed_state], outputs=None)
    # predict returns (chatbot pairs, updated history), so the second output
    # must be history_state, not file_processed_state
    txt.submit(predict, inputs=[txt, history_state], outputs=[chatbot, history_state])

demo.launch(share=True)