File size: 3,919 Bytes
417d8a6
 
 
 
2f92c73
417d8a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f92c73
 
 
 
 
 
 
 
 
 
 
 
 
 
417d8a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f92c73
417d8a6
2f92c73
 
552b8ff
417d8a6
2f92c73
 
417d8a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Configure page
st.set_page_config(page_title="Data Augmentation App", layout="wide")
# Inject CSS for a background image and left-aligned footer text.
# Plain string (not an f-string): nothing is interpolated, so the CSS
# braces no longer need to be doubled.
st.markdown(
    """
    <style>
        .reportview-container {
            background: url("https://cdn.pixabay.com/photo/2016/06/02/02/33/triangles-1430105_1280.png");
            background-size: cover;
        }
        footer {
            text-align: left;
        }
    </style>
    """,
    unsafe_allow_html=True,
)

st.title("Data Augmentation and Analysis App")
st.sidebar.title("Upload Your File")
st.sidebar.markdown("Supported formats: CSV, Excel")

# Get the Hugging Face API key from secrets
hf_api_key = st.secrets.get("HUGGINGFACE_KEY")

# Pre-declare the pipeline so later code can safely test whether it is
# ready. The original left this name undefined when the key was missing
# or initialization failed, causing a NameError downstream.
llm_pipeline = None

if not hf_api_key:
    st.error("Hugging Face API key not found in secrets.")
else:
    # Initialize the model and tokenizer using the API key
    try:
        # NOTE(review): "llama3-70b-8192" looks like a Groq model id, not a
        # Hugging Face repo id -- from_pretrained will likely fail to
        # resolve it; confirm the intended HF model name.
        model_name = "llama3-70b-8192"  # Replace with the correct model name if needed
        # `token=` replaces the deprecated `use_auth_token=` argument.
        model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_api_key)
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_api_key)
        llm_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
        st.success(f"Model {model_name} initialized successfully!")
    except Exception as e:
        st.error(f"Error initializing model: {e}")

def load_file(uploaded_file):
    """Load an uploaded CSV or Excel file into a DataFrame.

    Args:
        uploaded_file: File-like object with a ``name`` attribute, as
            returned by ``st.file_uploader``.

    Returns:
        ``pandas.DataFrame`` on success, or ``None`` for unsupported
        formats (after showing a Streamlit error message).
    """
    # Compare case-insensitively so "DATA.CSV" / "Report.XLSX" also load;
    # the original check rejected upper/mixed-case extensions.
    name = uploaded_file.name.lower()
    if name.endswith(".csv"):
        return pd.read_csv(uploaded_file)
    if name.endswith(".xlsx"):
        return pd.read_excel(uploaded_file)
    st.error("Unsupported file format. Please upload a CSV or Excel file.")
    return None

def generate_graph(data, query):
    """Render a graph chosen from keywords in the user's query.

    Supports "correlation" (heatmap of the numeric-column correlation
    matrix) and "histogram" (distribution of a user-selected column);
    any other request produces an error message.

    Args:
        data: ``pandas.DataFrame`` to plot.
        query: Free-text request; keywords are matched case-insensitively.
    """
    try:
        # Lowercase once instead of per-branch.
        lowered = query.lower()
        fig, ax = plt.subplots(figsize=(10, 6))
        if "correlation" in lowered:
            # numeric_only=True: DataFrame.corr raises on non-numeric
            # columns in pandas >= 2.0, so restrict to numeric data.
            sns.heatmap(data.corr(numeric_only=True), annot=True, cmap="coolwarm", ax=ax)
            st.pyplot(fig)
        elif "histogram" in lowered:
            column = st.selectbox("Select a column for the histogram", data.columns)
            sns.histplot(data[column], kde=True, ax=ax)
            st.pyplot(fig)
        else:
            st.error("Unsupported graph type. Try asking for a correlation matrix or histogram.")
    except Exception as e:
        st.error(f"Error generating graph: {e}")

def handle_query(data, query):
    """Answer a free-text question about the dataset via the LLM pipeline.

    Serializes the dataset into the prompt and shows the first generated
    response; errors are surfaced through ``st.error``.

    Args:
        data: ``pandas.DataFrame`` the question refers to.
        query: The user's free-text question.
    """
    # Guard first, outside the try: if initialization never ran, the bare
    # name check in the original raised a NameError that the broad except
    # swallowed with a misleading "LLM processing" message.
    if globals().get("llm_pipeline") is None:
        st.error("LLM pipeline is not initialized. Check for errors in setup.")
        return
    try:
        # NOTE(review): embedding the full dataset can exceed the model's
        # context window for large files -- consider sampling rows.
        prompt = f"Given the dataset: {data.to_dict(orient='records')}, answer the following: {query}"
        response = llm_pipeline(prompt, max_length=200, num_return_sequences=1)
        st.write("Response:", response[0]['generated_text'])
    except Exception as e:
        st.error(f"Error in LLM processing: {e}")

# --- Main app flow: upload a file, preview it, then route the query ---
uploaded_file = st.sidebar.file_uploader("Upload your file here", type=["csv", "xlsx"])
if uploaded_file:
    df = load_file(uploaded_file)
    if df is not None:
        # Show the raw data up front so the user can sanity-check it.
        st.write("Dataset Preview")
        st.dataframe(df)

        question = st.text_area("Ask your question about the dataset")
        if question:
            # Keyword routing: explicit keywords win over free-form
            # LLM handling, which is the fallback.
            lowered = question.lower()
            if "table" in lowered:
                st.write("Table Preview")
                st.write(df)
            elif "graph" in lowered:
                generate_graph(df, question)
            elif "predict" in lowered:
                st.write("Prediction functionality is in progress.")
            else:
                handle_query(df, question)

# Page footer rendered as raw HTML (left-aligned credit line).
footer_html = """
    <div style='text-align: left; padding: 10px;'>
        <footer>Created by: Shamil Shahbaz</footer>
    </div>
"""
st.markdown(footer_html, unsafe_allow_html=True)