import os

import gradio as gr
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.chains import LLMChain

# Read the API key from the environment (e.g. a Hugging Face Spaces secret)
# and fail early with a clear message if it is missing
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if OPENAI_API_KEY is None:
    raise EnvironmentError("OPENAI_API_KEY is not set.")
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
os.environ["LANGCHAIN_VERBOSE"] = "true"
# Histogram-specific reference code
hist_prompt_template = """```python
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

simulated_data = {
    'Month': ['January', 'February', 'March', 'April', 'May', 'June',
              'July', 'August', 'September', 'October', 'November', 'December'],
    'Executed_Operations': [5, 8, 6, 7, 9, 10, 12, 11, 8, 7, 6, 5]
}
df = pd.DataFrame(simulated_data)

plt.figure(figsize=(10, 6))
sns.barplot(x='Month', y='Executed_Operations', data=df, palette='viridis')
plt.xticks(rotation=45)
plt.title('Executed Science Operations per Month in 2006')
plt.xlabel('Month')
plt.ylabel('Number of Executed Operations')
plt.tight_layout()
plt.show()
```"""
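# Note: despite the "histogram" label, the reference above draws a seaborn bar
# chart, which suits categorical per-month data. A true histogram of the same
# values would use histplot instead (a hypothetical variant):
#
#   sns.histplot(df['Executed_Operations'], bins=6)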
# Line graph-specific reference code
graph_prompt_template = """```python
import matplotlib.pyplot as plt
import pandas as pd

data = {
    'Month': ['January', 'February', 'March', 'April', 'May', 'June',
              'July', 'August', 'September', 'October', 'November', 'December'],
    'Executed_Science_Operations': [5, 8, 7, 6, 9, 10, 8, 7, 6, 5, 9, 10],
    'Calibration_Operations': [2, 3, 2, 4, 3, 2, 3, 4, 3, 2, 3, 4]
}
df = pd.DataFrame(data)

plt.figure(figsize=(10, 6))
plt.plot(df['Month'], df['Executed_Science_Operations'], marker='o', label='Executed Science Operations')
plt.plot(df['Month'], df['Calibration_Operations'], marker='s', label='Calibration Operations')
plt.title('Spitzer Space Telescope Operations - 2006')
plt.xlabel('Month')
plt.ylabel('Number of Operations')
plt.xticks(rotation=45)
plt.legend()
plt.tight_layout()
plt.show()
```"""
# Master prompt: the chosen reference code is injected into {generated_code}
default_prompt_template = """You are a Python data visualization assistant.
The following is the content of a text file uploaded by a user:
-------------------
{file_content}
-------------------
{user_query}
-------------------
Use this code as a reference when generating your code.
{generated_code}
-------------------
Your task:
- Assume this content is tabular or semi-structured data.
- Generate a valid Python script using matplotlib, seaborn, or plotly, based on the user query, to visualize the data.
- Modify the code as required to match the user query.
- Only output the code block (no extra explanation).
- The code should be executable as-is.
- Include data parsing if required.

Python Code:"""
prompt = PromptTemplate(
    input_variables=["file_content", "user_query", "generated_code"],
    template=default_prompt_template,
)
human_prompt = HumanMessagePromptTemplate(prompt=prompt)
chat_prompt_template = ChatPromptTemplate.from_messages([human_prompt])

chat_model = ChatOpenAI(temperature=0, model_name="gpt-4o")
chain = LLMChain(prompt=chat_prompt_template, llm=chat_model)
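# A minimal sketch of how the chain is invoked (hypothetical sample values;
# running it for real requires a valid OPENAI_API_KEY and makes an API call):
#
#   code = chain.run(
#       file_content="Month,Executed_Operations\nJanuary,5\nFebruary,8",
#       user_query="Plot a histogram of executed operations per month",
#       generated_code=hist_prompt_template,
#   )
#   print(code)  # the model's reply, expected to be a fenced Python block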
def generate_and_plot(file, query):
    try:
        # Read file content (Gradio may pass a file-like object or a path)
        if hasattr(file, "read"):
            file_content = file.read().decode("utf-8")[:1000]
        elif isinstance(file, str) or hasattr(file, "name"):
            file_path = file.name if hasattr(file, "name") else file
            with open(file_path, "r", encoding="utf-8") as f:
                file_content = f.read()[:1000]
        else:
            return "Unsupported file type."

        # Choose reference code based on keywords in the query
        if "histogram" in query.lower():
            generated_code_hint = hist_prompt_template
        elif "line graph" in query.lower() or "graph plot" in query.lower():
            generated_code_hint = graph_prompt_template
        else:
            generated_code_hint = ""

        # Generate code
        generated_code = chain.run(
            file_content=file_content,
            user_query=query,
            generated_code=generated_code_hint,
        )

        # Execute the generated code. Caution: exec runs model output with no
        # sandboxing, so this is only appropriate for trusted, local use.
        global_env = {}
        cleaned_code = generated_code.replace("```python", "").replace("```", "")
        exec(cleaned_code, global_env)

        # Return a plot: prefer an explicit `fig` (plotly), else the current
        # matplotlib figure, else a placeholder figure
        if "fig" in global_env:
            fig = global_env["fig"]
            try:
                import plotly.graph_objects as go
                if isinstance(fig, go.Figure):
                    return fig
            except ImportError:
                pass
        if "plt" in global_env:
            import matplotlib.pyplot as plt
            return plt.gcf()
        import matplotlib.pyplot as plt
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "No figure was generated", ha='center', va='center')
        ax.axis("off")
        return fig
    except Exception as e:
        return f"Error: {e}"
import os

import gradio as gr
from langchain.chat_models import ChatOpenAI
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import RetrievalQA
from tempfile import NamedTemporaryFile

# Initialize LLM
llm = ChatOpenAI(temperature=0, model_name="gpt-4o")

# Keep a single global QA chain, rebuilt whenever a new PDF is uploaded
qa_chain = None
def load_pdf_and_create_qa_chain(pdf_file):
    global qa_chain

    # Save an uploaded file-like object to a temp file; otherwise use its path
    if hasattr(pdf_file, 'read'):
        with NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
            tmp_file.write(pdf_file.read())
            tmp_file_path = tmp_file.name
    else:
        tmp_file_path = pdf_file.name if hasattr(pdf_file, "name") else pdf_file

    # Load the document
    loader = PyPDFLoader(tmp_file_path)
    documents = loader.load()

    # Split into overlapping chunks
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    texts = text_splitter.split_documents(documents)

    # Imported here so the vector store is only loaded when a PDF is uploaded
    from langchain.vectorstores import Chroma
    from langchain.embeddings.openai import OpenAIEmbeddings

    # Embed the chunks and store them in a vector DB
    embeddings = OpenAIEmbeddings()
    db = Chroma.from_documents(texts, embeddings)

    # Set up a similarity retriever that returns the top 2 chunks
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 2})

    # Create the RAG QA chain ("stuff" packs all retrieved chunks into one prompt)
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
    return "PDF loaded and ready! You can now ask questions."
def ask_question(query):
    global qa_chain
    if qa_chain is None:
        return "Please upload a PDF first."
    try:
        # invoke returns a dict: "result" holds the answer text,
        # "source_documents" holds the retrieved chunks
        response = qa_chain.invoke({"query": query})
        return response["result"]
    except Exception as e:
        return f"Error answering question: {e}"
import gradio as gr
import os

# `generate_and_plot`, `load_pdf_and_create_qa_chain`, and `ask_question`
# are defined above

def process_file(file, query):
    if file is None:
        return "Please upload a file.", None
    filename = file.name if hasattr(file, "name") else ""

    if filename.endswith(".pdf"):
        # PDFs go through the RAG question-answering pipeline
        load_status = load_pdf_and_create_qa_chain(file)
        if not query:
            return load_status, None
        answer = ask_question(query)
        return answer, None
    elif filename.endswith((".txt", ".csv")):
        # Text/CSV files go through the plotting pipeline
        plot = generate_and_plot(file, query)
        return "Here is your plot:", plot
    else:
        return "Unsupported file type. Upload a .pdf, .txt, or .csv file.", None
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Astronomy ChatBot with Plotting and Summarizer")
    with gr.Row():
        file_input = gr.File(label="Upload your file (.txt, .csv, .pdf)", file_types=[".txt", ".csv", ".pdf"])
        query_input = gr.Textbox(label="Enter your question or plotting instruction")
    submit_btn = gr.Button("Submit")
    output_text = gr.Textbox(label="Response")
    output_plot = gr.Plot(label="Generated Plot")

    submit_btn.click(
        fn=process_file,
        inputs=[file_input, query_input],
        outputs=[output_text, output_plot],
    )

demo.launch(debug=True, share=True)