File size: 4,449 Bytes
417d8a6
 
 
 
e89e6f7
 
381c281
 
e89e6f7
 
417d8a6
e89e6f7
417d8a6
e89e6f7
 
 
 
 
 
 
 
 
 
 
 
417d8a6
 
fe29d17
417d8a6
 
 
 
9bcc73a
417d8a6
9bcc73a
417d8a6
 
 
9bcc73a
417d8a6
 
 
9bcc73a
 
 
 
e89e6f7
9bcc73a
 
 
 
417d8a6
9bcc73a
417d8a6
 
 
e89e6f7
 
 
417d8a6
e89e6f7
 
 
 
381c281
e89e6f7
 
381c281
e89e6f7
 
417d8a6
e89e6f7
 
417d8a6
e89e6f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
417d8a6
e89e6f7
 
417d8a6
 
e89e6f7
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from io import StringIO
import openai
import os

# Groq API key, read from the process environment (e.g. a deployment secret).
# GROQ_API_KEY is the conventional name and is preferred; HUGGINGFACE_KEY is
# still honoured as a fallback so existing deployments keep working.
GROQ_API_KEY = os.getenv("GROQ_API_KEY") or os.getenv("HUGGINGFACE_KEY")

# Function to load the uploaded file (CSV or Excel)
def load_file(uploaded_file):
    """Load a DataFrame from an uploaded CSV or Excel (.xlsx) file.

    The reader is chosen from the file name's extension first and only falls
    back to the browser-reported MIME type when the extension is not
    recognised. Browsers are inconsistent about CSV MIME types (on Windows a
    ``.csv`` upload is often reported as ``application/vnd.ms-excel``), so
    trusting the MIME type alone rejects perfectly valid CSV files.

    Args:
        uploaded_file: The object returned by ``st.file_uploader``. Its
            ``name`` and ``type`` attributes are inspected, and it is passed to
            the pandas reader as a file-like object.

    Returns:
        The parsed DataFrame, or ``None`` if the file type is unsupported or
        parsing fails. In both cases an error is shown in the Streamlit UI.
    """
    xlsx_mime = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
    name = (getattr(uploaded_file, "name", None) or "").lower()
    mime = getattr(uploaded_file, "type", None) or ""

    if name.endswith(".csv"):
        reader = pd.read_csv
    elif name.endswith(".xlsx"):
        reader = pd.read_excel
    elif mime == "text/csv":
        reader = pd.read_csv
    elif mime == xlsx_mime:
        reader = pd.read_excel
    else:
        st.error("Unsupported file type.")
        return None

    try:
        return reader(uploaded_file)
    except Exception as e:  # UI boundary: surface any parse failure to the user
        st.error(f"Error loading file: {e}")
        return None

def _find_column(data, wanted):
    """Return the column of *data* matching *wanted*, ignoring case/whitespace, or None."""
    # An exact match wins so datasets that already use the canonical name behave as before.
    if wanted in data.columns:
        return wanted
    for column in data.columns:
        if str(column).strip().lower() == wanted:
            return column
    return None


# Function to generate graph based on user query
def generate_graph(data, query):
    """Render a chart selected by keywords in a free-text query.

    Keywords are checked case-insensitively, in this order:

    * ``"correlation"``: heatmap of the correlation matrix of the numeric columns.
    * ``"histogram"``: histogram (with KDE) of a column the user picks.
    * ``"bar"`` together with ``"country"``: bar chart of total ``units sold``
      per ``country``. Column names are matched ignoring case and surrounding
      whitespace (e.g. ``"Country"`` / ``"Units Sold"``).

    Any other query shows an "unsupported" message. All errors are reported in
    the Streamlit UI rather than raised.

    Args:
        data: The DataFrame to plot.
        query: The user's free-text request.
    """
    q = query.lower()
    fig = None
    try:
        fig, ax = plt.subplots(figsize=(10, 6))

        if "correlation" in q:
            # Correlation is only defined for numeric columns; current pandas
            # raises on text columns instead of silently dropping them.
            numeric = data.select_dtypes(include="number")
            if numeric.shape[1] == 0:
                st.error("The dataset has no numeric columns to correlate.")
            else:
                sns.heatmap(numeric.corr(), annot=True, cmap="coolwarm", ax=ax)
                st.pyplot(fig)
        elif "histogram" in q:
            column = st.selectbox("Select a column for the histogram", data.columns)
            sns.histplot(data[column], kde=True, ax=ax)
            st.pyplot(fig)
        elif "bar" in q and "country" in q:
            # Bar chart of total units sold per country.
            country_col = _find_column(data, "country")
            units_col = _find_column(data, "units sold")
            if country_col is not None and units_col is not None:
                country_data = data.groupby(country_col, as_index=False)[units_col].sum()
                sns.barplot(x=country_col, y=units_col, data=country_data, ax=ax, color='skyblue')
                plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
                st.pyplot(fig)
            else:
                st.error("The dataset must contain 'country' and 'units sold' columns.")
        else:
            st.error("Unsupported graph type. Try asking for a correlation matrix, histogram, or bar chart.")
    except Exception as e:
        st.error(f"Error generating graph: {e}")
    finally:
        # Streamlit re-runs the script on every interaction; without an explicit
        # close, each run leaves another pyplot figure alive in memory.
        if fig is not None:
            plt.close(fig)

# Groq serves an OpenAI-compatible REST API at this base URL, so the
# already-imported ``openai`` client can talk to it (``openai.OpenAI`` requires
# openai>=1.0, matching the ``client.chat.completions`` call style used here).
_GROQ_BASE_URL = "https://api.groq.com/openai/v1"


# Function to query the Groq model (Groq API)
def query_groq_model(prompt, model="llama-3.1-8b-instant"):
    """Send a single-turn chat prompt to a Groq-hosted model.

    Args:
        prompt: The user's message, sent as one ``user`` chat turn.
        model: Groq model id to query (defaults to ``llama-3.1-8b-instant``).

    Returns:
        The assistant's reply text, or ``None`` if no API key is configured or
        the request fails. In both cases an error is shown in the Streamlit UI.
    """
    if not GROQ_API_KEY:
        st.error("Groq API key is not configured.")
        return None
    try:
        client = openai.OpenAI(api_key=GROQ_API_KEY, base_url=_GROQ_BASE_URL)

        chat_completion = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=model,
        )

        return chat_completion.choices[0].message.content
    except Exception as e:  # UI boundary: network/auth/API errors go to the user
        st.error(f"Error querying Groq model: {e}")
        return None

# Streamlit App Interface
def main():
    """Build the Streamlit page.

    The page lets the user upload a dataset and preview it, draw a chart from a
    free-text query, and ask the Groq model a question.
    """
    st.set_page_config(page_title="Data Augmentation and Visualization", page_icon="📊", layout="wide")

    # Set background image
    st.markdown(
        """
        <style>
        .stApp {
            background-image: url('https://cdn.pixabay.com/photo/2016/06/02/02/33/triangles-1430105_1280.png');
            background-size: cover;
        }
        </style>
        """, unsafe_allow_html=True
    )

    st.title("Data Augmentation and Visualization with Groq API")
    # Plain text: no HTML needed, so don't opt into unsafe HTML rendering.
    st.markdown("Created by: Shamil Shahbaz")

    # File upload section
    uploaded_file = st.file_uploader("Upload a CSV or Excel file", type=["csv", "xlsx"])
    if uploaded_file is None:
        return

    data = load_file(uploaded_file)
    if data is None:
        return

    st.write("Dataset preview:", data.head())

    # User input for graph generation
    query = st.text_input("Enter your query (e.g., 'Generate a bar chart for countries and units sold')")
    if query:
        generate_graph(data, query)

    # User input for Groq model query
    model_query = st.text_input("Ask Groq model a question:")
    if model_query:
        # Streamlit re-runs this whole script on every widget interaction (e.g.
        # picking a histogram column). Remember successful answers per question
        # so the paid API is not called again for an unchanged question.
        answers = st.session_state.setdefault("groq_answers", {})
        response = answers.get(model_query)
        if response is None:
            response = query_groq_model(model_query)
            if response:
                answers[model_query] = response
        if response:
            st.write("Groq Model Response:", response)


if __name__ == "__main__":
    main()