| import streamlit as st |
| import pandas as pd |
| import os |
| from crewai import Agent, Task, Crew |
| from langchain_groq import ChatGroq |
| import streamlit_ace as st_ace |
| import traceback |
| import contextlib |
| import io |
| from crewai_tools import FileReadTool |
| import matplotlib.pyplot as plt |
| import glob |
| from dotenv import load_dotenv |
|
|
| |
# Let python-dotenv populate the process environment (typically from a local
# .env file) before any setting is read.
load_dotenv()

# Groq API key handed to the LLM client; None when the variable is not set.
groq_api_key = os.environ.get("GROQ_API_KEY")
|
|
|
|
def extract_code(result):
    """Return the Python source embedded in an agent reply.

    The reply may be bare code, or code wrapped in a Markdown code fence with
    an optional language tag and surrounding chatter. When a fence is present
    only the contents of the first fenced block are kept; the fence markers
    and a one-word language tag (for example ``python``) are dropped so the
    result can be executed directly.

    Args:
        result: Value returned by ``Crew.kickoff()``. It is converted with
            ``str`` so both plain strings and richer result objects work.

    Returns:
        The extracted code with surrounding whitespace removed.
    """
    text = str(result).strip()
    fence_start = text.find("```")
    if fence_start == -1:
        return text

    body = text[fence_start + 3:]
    # Whatever shares a line with the opening fence is a language tag, not
    # code, as long as it is empty or a single word.
    info, newline, remainder = body.partition("\n")
    if newline and (not info.strip() or info.strip().isalnum()):
        body = remainder

    fence_end = body.find("```")
    if fence_end != -1:
        body = body[:fence_end]
    return body.strip()


def _run_crew(agents, tasks, spinner_text):
    """Run *tasks* with every agent while showing a spinner; return the result."""
    with st.spinner(spinner_text):
        crew = Crew(
            agents=list(agents.values()),
            tasks=tasks,
            verbose=2
        )
        return crew.kickoff()


def _execute_code(source, df):
    """Execute *source* and return ``(output_text, succeeded)``.

    The uploaded DataFrame is exposed to the executed code as ``dataset``.
    On failure the text contains whatever was printed before the error
    followed by the traceback, so it can be pasted into the debug box.
    """
    # SECURITY: this runs LLM-generated (and user-edited) code with the full
    # privileges of the app process. Only run this app in a trusted or
    # sandboxed environment.
    globals().update({'dataset': df})
    captured = io.StringIO()
    try:
        with contextlib.redirect_stdout(captured):
            exec(source, globals())
    except Exception:
        return captured.getvalue() + traceback.format_exc(), False
    return captured.getvalue(), True


def main():
    """Render the Streamlit UI and drive the CrewAI workflow.

    Flow: the user describes a problem and optionally uploads a CSV;
    "Process" generates starter code, "Modify" and "Debug" refine it, and
    "Run" executes whatever is currently in the editor. Files found in
    ``Output/`` are offered for download in the sidebar.
    """
    set_custom_css()

    # Session state survives Streamlit reruns; seed it on first load.
    if 'edited_code' not in st.session_state:
        st.session_state['edited_code'] = ""
    if 'code_generated' not in st.session_state:
        st.session_state['code_generated'] = False

    st.markdown("""
    <div class="header">
    <h1>CrewAI Machine Learning Assistant</h1>
    <p>Your AI-powered partner for machine learning projects.</p>
    </div>
    """, unsafe_allow_html=True)

    st.sidebar.title('Customization')
    model = st.sidebar.selectbox(
        'Choose a model',
        ['llama3-8b-8192', "llama3-70b-8192"]
    )
    llm = initialize_llm(model)

    user_question = st.text_area("Describe your ML problem:", key="user_question")
    uploaded_file = st.file_uploader("Upload a sample .csv of your data (optional)", key="uploaded_file")
    # Without an upload, fall back to a placeholder name for the prompts.
    file_name = uploaded_file.name if uploaded_file is not None else "dataset.csv"

    agents = initialize_agents(llm, file_name)

    df = None
    data_upload = False
    if uploaded_file:
        try:
            df = pd.read_csv(uploaded_file)
        except Exception as e:
            st.error(f"Error reading the file: {e}")
        else:
            st.write("Data successfully uploaded:")
            st.dataframe(df.head())
            data_upload = True

    process_clicked = st.button('Process')
    if process_clicked:
        tasks = create_tasks("Process", user_question, file_name, data_upload, df, None, st.session_state['edited_code'], None, agents)
        result = _run_crew(agents, tasks, 'Processing...')
        if result:
            st.session_state['edited_code'] = extract_code(result)
            st.session_state['code_generated'] = True

    # Reserve the editor's position here but fill it after the Modify/Debug
    # handlers below, so it always shows the newest code. A single editor
    # rendered on every rerun keeps it visible after any interaction.
    editor_slot = st.container()
    editor_caption = None

    if st.session_state['code_generated']:
        suggestion = st.text_area("Suggest modifications to the generated code (optional):", key="suggestion")
        if st.button('Modify') and st.session_state['edited_code'] and suggestion:
            tasks = create_tasks("Modify", user_question, file_name, data_upload, df, suggestion, st.session_state['edited_code'], None, agents)
            result = _run_crew(agents, tasks, 'Modifying code...')
            if result:
                st.session_state['edited_code'] = extract_code(result)
            editor_caption = "Modified code:"

        debugger = st.text_area("Paste error message here for debugging (optional):", key="debugger")
        if st.button('Debug') and st.session_state['edited_code'] and debugger:
            tasks = create_tasks("Debug", user_question, file_name, data_upload, df, None, st.session_state['edited_code'], debugger, agents)
            result = _run_crew(agents, tasks, 'Debugging code...')
            if result:
                st.session_state['edited_code'] = extract_code(result)
            editor_caption = "Debugged code:"

    if process_clicked or st.session_state['code_generated']:
        with editor_slot:
            if editor_caption:
                st.write(editor_caption)
            st.session_state['edited_code'] = st_ace.st_ace(
                value=st.session_state['edited_code'],
                language='python',
                theme='monokai',
                keybinding='vscode',
                min_lines=20,
                max_lines=50
            )

    if st.session_state['code_generated'] and st.button('Run'):
        final_code = st.session_state["edited_code"]
        with st.expander("Final Code"):
            st.code(final_code, language='python')

        result, success = _execute_code(final_code, df)

        st.subheader('Output:')
        st.text(result)

        figs = [plt.figure(num) for num in plt.get_fignums()]
        if figs:
            st.subheader('Generated Plots:')
            for fig in figs:
                st.pyplot(fig)
        # pyplot keeps figures alive across reruns; close them so the next
        # run does not show these plots again.
        plt.close('all')

        if success:
            st.success("Code executed successfully!")
        else:
            st.error("Code execution failed! Waiting for debugging input...")

    with st.sidebar:
        st.header('Output Files:')
        for path in sorted(glob.glob(os.path.join("Output/", '*'))):
            if os.path.isfile(path):
                with open(path, 'rb') as f:
                    st.download_button(label=f'Download {os.path.basename(path)}', data=f, file_name=os.path.basename(path))
|
|
|
|
|
|
| |
def set_custom_css():
    """Inject the app's custom stylesheet into the page.

    Styles the page body, the ``.header`` banner, Streamlit buttons
    (including their hover state) and the ``.spinner`` class.
    """
    stylesheet = """
    <style>
    body {
    background: #0e0e0e;
    color: #e0e0e0;
    font-family: 'Roboto', sans-serif;
    }
    .header {
    background: linear-gradient(135deg, #6e3aff, #b839ff);
    padding: 10px;
    border-radius: 10px;
    }
    .header h1, .header p {
    color: white;
    text-align: center;
    }
    .stButton button {
    background-color: #b839ff;
    color: white;
    border-radius: 10px;
    font-size: 16px;
    padding: 10px 20px;
    }
    .stButton button:hover {
    background-color: #6e3aff;
    color: #e0e0e0;
    }
    .spinner {
    display: flex;
    justify-content: center;
    align-items: center;
    }
    </style>
    """
    # Raw HTML must be explicitly allowed for the <style> tag to be emitted.
    st.markdown(stylesheet, unsafe_allow_html=True)
|
|
| |
def initialize_llm(model):
    """Build the Groq chat client for the given model name.

    Args:
        model: Model identifier passed to ``ChatGroq`` as ``model_name``.

    Returns:
        A ``ChatGroq`` instance configured with temperature 0 and the
        module-level ``groq_api_key``.
    """
    settings = {
        "temperature": 0,
        "groq_api_key": groq_api_key,
        "model_name": model,
    }
    return ChatGroq(**settings)
|
|
| |
def initialize_agents(llm,file_name):
    """Create every agent used by the app, keyed by agent name.

    Args:
        llm: Language model shared by all agents.
        file_name: Dataset file name embedded in the starter-code agent's goal.

    Returns:
        A dict mapping agent name to ``Agent``. All agents are verbose and
        have delegation disabled; only the data reader is given a tool.
    """
    reader_tool = FileReadTool()

    # (dict key, role, goal, backstory, extra Agent keyword arguments)
    blueprints = [
        (
            "Data_Reader_Agent",
            'Data_Reader_Agent',
            "Read the uploaded dataset and provide it to other agents.",
            "Responsible for reading the uploaded dataset.",
            {"tools": [reader_tool]},
        ),
        (
            "Problem_Definition_Agent",
            'Problem_Definition_Agent',
            "Clarify the machine learning problem the user wants to solve.",
            "Expert in defining machine learning problems.",
            {},
        ),
        (
            "EDA_Agent",
            'EDA_Agent',
            "Perform all possible Exploratory Data Analysis (EDA) on the data provided by the user.",
            "Specializes in conducting comprehensive EDA to understand the data characteristics, distributions, and relationships.",
            {},
        ),
        (
            "Feature_Engineering_Agent",
            'Feature_Engineering_Agent',
            "Perform feature engineering on the data based on the EDA results provided by the EDA agent.",
            "Expert in deriving new features, transforming existing features, and preprocessing data to prepare it for modeling.",
            {},
        ),
        (
            "Model_Recommendation_Agent",
            'Model_Recommendation_Agent',
            "Suggest the most suitable machine learning models.",
            "Expert in recommending machine learning algorithms.",
            {},
        ),
        (
            "Starter_Code_Generator_Agent",
            'Starter_Code_Generator_Agent',
            f"Generate starter Python code for the project. Always give dataset name as {file_name}",
            "Code wizard for generating starter code templates.",
            {},
        ),
        (
            "Code_Modification_Agent",
            'Code_Modification_Agent',
            "Modify the generated Python code based on user suggestions.",
            "Expert in adapting code according to user feedback.",
            {},
        ),
        (
            "Code_Debugger_Agent",
            'Code_Debugger_Agent',
            "Debug the generated Python code.",
            "Seasoned code debugger.",
            {},
        ),
        (
            "Compiler_Agent",
            "Code_compiler",
            "Extract only the python code.",
            "You are the compiler which extract only the python code.",
            {},
        ),
    ]

    return {
        key: Agent(
            role=role,
            goal=goal,
            backstory=backstory,
            verbose=True,
            allow_delegation=False,
            llm=llm,
            **extra
        )
        for key, role, goal, backstory, extra in blueprints
    }
|
|
| |
def summarize_dataframe(df):
    """Return ``(info_text, column_names)`` describing *df*.

    ``DataFrame.info()`` prints its report and returns ``None``, so the
    report is captured through its ``buf`` argument instead.

    Args:
        df: A pandas DataFrame, or ``None`` when no dataset was uploaded.

    Returns:
        A tuple of the ``info()`` report text and the list of column names;
        ``("", [])`` when *df* is ``None``.
    """
    if df is None:
        return "", []
    report = io.StringIO()
    df.info(buf=report)
    return report.getvalue(), df.columns.tolist()


def create_tasks(func_call,user_question,file_name, data_upload, df, suggestion, edited_code, debugger, agents):
    """Build the ordered CrewAI task list for one user action.

    Args:
        func_call: Which action is running: "Process", "Modify" or "Debug".
        user_question: The user's description of the ML problem.
        file_name: Name of the uploaded dataset file.
        data_upload: True when a dataset was uploaded and parsed.
        df: The uploaded DataFrame, or None when there is none.
        suggestion: Requested change to the code (used by "Modify").
        edited_code: Current code from the editor.
        debugger: Error message to fix (used by "Debug").
        agents: Mapping of agent name to Agent, as built by
            ``initialize_agents``.

    Returns:
        A list of ``Task`` objects. The last task always asks the compiler
        agent to extract only the Python code.
    """
    # Safe when df is None, so "Process" also works without an upload.
    info, columns = summarize_dataframe(df)
    tasks = []

    if func_call == "Process":
        tasks.append(
            Task(
                description=f"Clarify the ML problem: {user_question}",
                agent=agents["Problem_Definition_Agent"],
                expected_output="A clear and concise definition of the ML problem."
            )
        )

        if data_upload:
            tasks.extend([
                Task(
                    description=f"Evaluate the data provided by the file name {file_name}. This is the data: {df}",
                    agent=agents["EDA_Agent"],
                    expected_output="An assessment of the EDA and preprocessing like dataset info, missing value, duplication, outliers etc. on the data provided"
                ),
                Task(
                    description=f"Feature Engineering on data {df} based on EDA output: {info}",
                    agent=agents["Feature_Engineering_Agent"],
                    expected_output="An assessment of the Featuring Engineering and preprocessing like handling missing values, handling duplication, handling outliers, feature encoding, feature scaling etc. on the data provided"
                )
            ])

        tasks.extend([
            Task(
                description="Suggest suitable ML models.",
                agent=agents["Model_Recommendation_Agent"],
                expected_output="A list of suitable ML models."
            ),
            Task(
                description=f"Generate starter Python code based on feature engineering, where column names are {columns}. Generate only the code without any extra text",
                agent=agents["Starter_Code_Generator_Agent"],
                expected_output="Starter Python code."
            ),
        ])

    if func_call == "Modify" and suggestion:
        tasks.append(
            Task(
                description=f"Modify the already generated code {edited_code} according to the suggestion: {suggestion} \n\n Do not generate entire new code.",
                agent=agents["Code_Modification_Agent"],
                expected_output="Modified code."
            )
        )

    if func_call == "Debug" and debugger:
        tasks.append(
            Task(
                description=f"Debug and fix any errors for data with column names {columns} with data as {df} in the generated code: {edited_code} \n\n According to the debugging: {debugger}. \n\n Do not generate entire new code. Just remove the error in the code by modifying only necessary parts of the code.",
                agent=agents["Code_Debugger_Agent"],
                expected_output="Debugged and successfully executed code."
            )
        )

    # Every flow ends by reducing the previous answer to plain Python code.
    tasks.append(
        Task(
            description="Your job is to only extract python code from string",
            agent=agents["Compiler_Agent"],
            expected_output="Running python code."
        )
    )

    return tasks
|
|
# Run the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|