import os
import gradio as gr
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.chains import LLMChain
# Set environment variables (OPENAI_API_KEY is expected to be supplied, e.g. as a Space secret)
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if OPENAI_API_KEY is None:
    raise RuntimeError("OPENAI_API_KEY environment variable is not set")
os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY
os.environ['LANGCHAIN_VERBOSE'] = 'true'
# Reference code injected for "histogram"-style queries (rendered as a seaborn bar chart)
hist_prompt_template = """```python
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
simulated_data = {
    'Month': ['January', 'February', 'March', 'April', 'May', 'June',
              'July', 'August', 'September', 'October', 'November', 'December'],
    'Executed_Operations': [5, 8, 6, 7, 9, 10, 12, 11, 8, 7, 6, 5]
}
df = pd.DataFrame(simulated_data)
plt.figure(figsize=(10, 6))
sns.barplot(x='Month', y='Executed_Operations', data=df, palette='viridis')
plt.xticks(rotation=45)
plt.title('Executed Science Operations per Month in 2006')
plt.xlabel('Month')
plt.ylabel('Number of Executed Operations')
plt.tight_layout()
plt.show()
```"""
# Line graph-specific reference code
graph_prompt_template = """```python
import matplotlib.pyplot as plt
import pandas as pd
data = {
    'Month': ['January', 'February', 'March', 'April', 'May', 'June',
              'July', 'August', 'September', 'October', 'November', 'December'],
    'Executed_Science_Operations': [5, 8, 7, 6, 9, 10, 8, 7, 6, 5, 9, 10],
    'Calibration_Operations': [2, 3, 2, 4, 3, 2, 3, 4, 3, 2, 3, 4]
}
df = pd.DataFrame(data)
plt.figure(figsize=(10, 6))
plt.plot(df['Month'], df['Executed_Science_Operations'], marker='o', label='Executed Science Operations')
plt.plot(df['Month'], df['Calibration_Operations'], marker='s', label='Calibration Operations')
plt.title('Spitzer Space Telescope Operations - 2006')
plt.xlabel('Month')
plt.ylabel('Number of Operations')
plt.xticks(rotation=45)
plt.legend()
plt.tight_layout()
plt.show()
```"""
# Master prompt with code injection
default_prompt_template = """You are a Python data visualization assistant.
The following is the content of a text file uploaded by a user:
-------------------
{file_content}
-------------------
{user_query}
-------------------
Use the following code as a reference when generating the code (it may be empty):
{generated_code}
-------------------
Your task:
- Assume this content is tabular or semi-structured data.
- Generate a valid Python script using matplotlib, seaborn, or plotly based on the user query to visualize the data.
- Modify the code as required to match the user query.
- Only output the code block (no extra explanation).
- The code should be executable as-is.
- Include data parsing if required.
Python Code:"""
prompt = PromptTemplate(input_variables=["file_content", "user_query", "generated_code"], template=default_prompt_template)
human_prompt = HumanMessagePromptTemplate(prompt=prompt)
chat_prompt_template = ChatPromptTemplate.from_messages([human_prompt])
chat_model = ChatOpenAI(temperature=0, model_name="gpt-4o")
chain = LLMChain(prompt=chat_prompt_template, llm=chat_model)
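# Optional sanity check of the prompt wiring (a minimal sketch, not part of the app flow;
# the sample values below are made-up placeholders). format_messages() only renders the
# template locally and does not call the OpenAI API.
def _preview_prompt():
    messages = chat_prompt_template.format_messages(
        file_content="Month,Operations\nJan,5\nFeb,8",
        user_query="Plot the operations per month",
        generated_code="",
    )
    return messages[0].content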
def generate_and_plot(file, query):
    try:
        # Read file content
        if hasattr(file, "read"):
            file_content = file.read().decode("utf-8")[:1000]
        elif isinstance(file, str) or hasattr(file, "name"):
            file_path = file.name if hasattr(file, "name") else file
            with open(file_path, "r", encoding="utf-8") as f:
                file_content = f.read()[:1000]
        else:
            return "Unsupported file type."

        # Choose reference code
        if "histogram" in query.lower():
            generated_code_hint = hist_prompt_template
        elif "line graph" in query.lower() or "graph plot" in query.lower():
            generated_code_hint = graph_prompt_template
        else:
            generated_code_hint = ""

        # Generate code
        generated_code = chain.run(file_content=file_content, user_query=query, generated_code=generated_code_hint)

        # Execute code
        global_env = {}
        cleaned_code = generated_code.replace("```python", "").replace("```", "")
        exec(cleaned_code, global_env)

        # Return plot
        if "fig" in global_env:
            fig = global_env["fig"]
            try:
                import plotly.graph_objects as go
                if isinstance(fig, go.Figure):
                    return fig
            except ImportError:
                pass
        if "plt" in global_env:
            import matplotlib.pyplot as plt
            return plt.gcf()

        import matplotlib.pyplot as plt
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "No figure was generated", ha='center', va='center')
        ax.axis("off")
        return fig
    except Exception as e:
        return f"Error: {e}"
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import RetrievalQA
from tempfile import NamedTemporaryFile
# Initialize LLM
llm = ChatOpenAI(temperature=0, model_name="gpt-4o")
# Keep global QA chain
qa_chain = None
def load_pdf_and_create_qa_chain(pdf_file):
    global qa_chain

    # Save uploaded file to temp
    if hasattr(pdf_file, 'read'):
        with NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
            tmp_file.write(pdf_file.read())
            tmp_file_path = tmp_file.name
    else:
        tmp_file_path = pdf_file.name if hasattr(pdf_file, "name") else pdf_file

    # Load document
    loader = PyPDFLoader(tmp_file_path)
    documents = loader.load()

    # Split into chunks
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    texts = text_splitter.split_documents(documents)

    from langchain.vectorstores import Chroma
    from langchain.embeddings.openai import OpenAIEmbeddings

    # Embed and store in vector DB
    embeddings = OpenAIEmbeddings()
    db = Chroma.from_documents(texts, embeddings)

    # Set up retriever
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 2})

    # Create RAG QA chain
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True
    )

    return "PDF loaded and ready! You can now ask questions."
def ask_question(query):
    global qa_chain
    if qa_chain is None:
        return "Please upload a PDF first."
    try:
        # qa_chain.run(query) is not usable here: with return_source_documents=True the
        # chain returns multiple output keys, so it must be invoked with a dict.
        response = qa_chain.invoke({"query": query})
        result = response["result"]
        return result
    except Exception as e:
        return f"Error answering question: {e}"
# Assuming `generate_and_plot`, `load_pdf_and_create_qa_chain`, `ask_question` are already defined above
def process_file(file, query):
    if file is None:
        return "Please upload a file.", None
    filename = file.name if hasattr(file, "name") else ""
    if filename.endswith(".pdf"):
        load_status = load_pdf_and_create_qa_chain(file)
        answer = ask_question(query)
        return answer, None
    elif filename.endswith(".txt") or filename.endswith(".csv"):
        plot = generate_and_plot(file, query)
        return "Here is your plot:", plot
    else:
        return "Unsupported file type. Upload a .pdf, .txt, or .csv file.", None
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("Astronomy ChatBot with Plotting and Summarizer")
    with gr.Row():
        file_input = gr.File(label="Upload your file (.txt, .csv, .pdf)", file_types=[".txt", ".csv", ".pdf"])
        query_input = gr.Textbox(label="Enter your question or plotting instruction")
    submit_btn = gr.Button("Submit")
    output_text = gr.Textbox(label="Response")
    output_plot = gr.Plot(label="Generated Plot")
    submit_btn.click(fn=process_file,
                     inputs=[file_input, query_input],
                     outputs=[output_text, output_plot])

demo.launch(debug=True, share=True)