Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import networkx as nx | |
| import pandas as pd | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from pyArango.connection import Connection | |
| from arango import ArangoClient | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_community.graphs import ArangoGraph | |
| from langchain.chains import ArangoGraphQAChain | |
| from langchain_core.prompts import ChatPromptTemplate | |
| import os | |
# Initialize Google Gemini.
# The model ID defaults to the original experimental release but can be
# overridden with the GEMINI_MODEL environment variable: "-exp" model IDs are
# experimental and may be withdrawn, which would otherwise need a code change.
llm = ChatGoogleGenerativeAI(
    model=os.environ.get("GEMINI_MODEL", "gemini-2.0-flash-thinking-exp"),
    temperature=0,  # minimise sampling randomness for query generation
    max_tokens=None,
    timeout=None,
    max_retries=2,
)
# ArangoDB connection settings, read from the environment.
# Every required variable is checked up front so a misconfigured deployment
# reports all missing names at once, instead of a bare KeyError for the first.
_REQUIRED_ENV_VARS = ("ARANGO_URL", "ARANGO_USER", "ARANGO_PASSWORD", "DB_NAME")
_missing_env = [name for name in _REQUIRED_ENV_VARS if name not in os.environ]
if _missing_env:
    raise RuntimeError(
        "Missing required environment variables: " + ", ".join(_missing_env)
    )
arangoURL = os.environ["ARANGO_URL"]
username = os.environ["ARANGO_USER"]
password = os.environ["ARANGO_PASSWORD"]
dbName = os.environ["DB_NAME"]
# Streamlit app configuration. Keep this ahead of every other st.* call:
# many Streamlit versions reject set_page_config once other elements exist.
st.set_page_config(layout="wide", page_title="Shark Tank Analytics")

# The chat transcript lives in session_state so it survives reruns.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []

# Page layout: title banner followed by the two tabs.
st.title("🦈 Shark Tank Analytics Platform")
st.write("Powered by ArangoDB")
_tab_labels = ["Chat Interface", "Analytics Dashboard"]
tab1, tab2 = st.tabs(_tab_labels)
# Chat Tab
@st.cache_resource
def _get_qa_chain(url, db_name, user, pwd, _llm):
    """Build the ArangoDB-backed LangChain QA chain once and reuse it.

    Streamlit re-executes the whole script on every interaction, and building
    the chain opens a verified DB connection (``verify=True``) and wraps it in
    ``ArangoGraph`` (which presumably introspects the schema — confirm). Caching
    avoids repeating that work on each rerun. The leading underscore on
    ``_llm`` tells Streamlit not to hash the model object.
    """
    db = ArangoClient(hosts=url).db(db_name, user, pwd, verify=True)
    graph = ArangoGraph(db)
    # NOTE(review): allow_dangerous_requests lets the chain run LLM-generated
    # AQL derived from free-text user input. ARANGO_USER should be a
    # read-only account so a crafted prompt cannot modify data.
    return ArangoGraphQAChain.from_llm(
        _llm, graph=graph, allow_dangerous_requests=True
    )


def _render_chain_response(response):
    """Draw a QA-chain response in the current container.

    Returns the text stored in the chat transcript. It is always a ``str``
    because the transcript is later re-rendered with ``st.markdown``.
    """
    if isinstance(response, dict) and "result" in response:
        result = response["result"]
        if isinstance(result, list):
            # Tabular answers are shown as a table; only a summary is kept.
            df = pd.DataFrame(result)
            st.dataframe(df)
            return f"Found {len(df)} results"
        st.write(result)
        return str(result)
    st.write(response)
    return str(response)


with tab1:
    st.header("Chat with Shark Tank Data")
    # Chat Input
    user_query = st.chat_input("Ask about investments, sharks, or startups...")
    # LangChain Q&A chain (cached across reruns)
    chain = _get_qa_chain(arangoURL, dbName, username, password, llm)

    # The transcript is newest-first. Snapshot it before this run's exchange is
    # added; the new exchange is drawn above the older messages.
    earlier_messages = list(st.session_state.chat_history)

    if user_query:
        with st.chat_message("user"):
            st.markdown(user_query)
        with st.chat_message("assistant"):
            try:
                # invoke() replaces the deprecated Chain.run(), which returned
                # a bare string and left the dict handling above unreachable.
                response = chain.invoke({chain.input_key: user_query})
                reply = _render_chain_response(response)
            except Exception as e:
                st.error(f"Error processing query: {str(e)}")
                reply = f"Error: {str(e)}"
        # Prepend user-then-assistant so the next rerun shows the exchange in
        # the same order as now (question above its answer).
        st.session_state.chat_history[:0] = [
            {"role": "user", "content": user_query},
            {"role": "assistant", "content": reply},
        ]

    # Older messages, newest first, below the current exchange.
    for message in earlier_messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
# Dashboard Tab
@st.cache_resource
def _get_dashboard_db(url, db_name, user, pwd):
    """Return a pyArango handle to *db_name*, created once and reused.

    Streamlit re-executes the script on every interaction; caching avoids
    opening a new ``Connection`` on each rerun.
    """
    conn = Connection(arangoURL=url, username=user, password=pwd)
    return conn[db_name]


@st.cache_data(ttl=60, show_spinner=False)
def _run_aql(query):
    """Execute *query* on the dashboard database and return all rows as a list.

    Every tab body runs on every rerun (including each chat message sent in
    the other tab), so results are cached for 60 seconds. Charts can
    therefore lag database writes by up to a minute.
    """
    db = _get_dashboard_db(arangoURL, dbName, username, password)
    # list() consumes the whole query result, as the original code's
    # iteration over the pyArango query object did.
    return list(db.AQLQuery(query, rawResults=True))


def _render_investment_distribution():
    """Draw a histogram of ``investment_amount`` over the ``investments`` collection."""
    st.subheader("Investment Distribution")
    investment_query = """
    FOR investment IN investments
    RETURN investment.investment_amount
    """
    investments = _run_aql(investment_query)
    if not investments:
        return
    # Raw sample kept for inspection, but collapsed so it doesn't clutter the page.
    with st.expander("Sample investment data"):
        st.write(investments[:5])
    # Flatten in case an element is itself a list, then drop nulls (AQL yields
    # null for documents that lack the attribute).
    flat_investments = []
    for item in investments:
        flat_investments.extend(item if isinstance(item, list) else [item])
    amounts = [amount for amount in flat_investments if amount is not None]
    if not amounts:
        st.info("No investment amounts recorded.")
        return
    df_investments = pd.DataFrame({"investment_amount": amounts})
    fig = px.histogram(
        df_investments,
        x="investment_amount",
        nbins=20,
        labels={"investment_amount": "Investment Amount"},
        title="Distribution of Investment Amounts",
    )
    st.plotly_chart(fig, use_container_width=True)


def _render_shark_activity():
    """Chart each investor's deal count and total invested amount."""
    st.subheader("Shark Investment Activity")
    # Return the edge (e), not the startup vertex (v): investment_amount is an
    # attribute of the `investments` edges (the network query reads
    # edge.investment_amount), so summing v.investment_amount read the wrong
    # documents. One edge per deal keeps deal_count unchanged.
    shark_query = """
    WITH startups
    FOR shark IN investors
    LET deals = (
        FOR v, e IN 1..1 OUTBOUND shark investments
        RETURN e
    )
    RETURN {
        shark: shark.name,
        deal_count: LENGTH(deals),
        total_invested: SUM(deals[*].investment_amount)
    }
    """
    shark_data = _run_aql(shark_query)
    if not shark_data:
        return
    df = pd.DataFrame(shark_data)
    # A count and a currency total are on very different scales; on a shared
    # y-axis the deal-count bars were flattened, so each gets its own chart.
    count_col, amount_col = st.columns(2)
    with count_col:
        fig = px.bar(df, x="shark", y="deal_count", title="Deals per Shark")
        st.plotly_chart(fig, use_container_width=True)
    with amount_col:
        fig = px.bar(df, x="shark", y="total_invested", title="Total Invested per Shark")
        st.plotly_chart(fig, use_container_width=True)


def _render_network():
    """Draw the investor -> startup graph built from the `investments` edges."""
    st.subheader("Investor-Startup Network")
    graph_query = """
    FOR edge IN investments
    RETURN {
        source: SPLIT(edge._from, '/')[1],
        target: SPLIT(edge._to, '/')[1],
        amount: edge.investment_amount
    }
    """
    edges = _run_aql(graph_query)
    if not edges:
        return
    G = nx.from_pandas_edgelist(
        pd.DataFrame(edges),
        source="source",
        target="target",
        edge_attr="amount",
    )
    # Fixed seed keeps node positions stable across Streamlit reruns.
    pos = nx.spring_layout(G, seed=42)

    edge_x, edge_y = [], []
    for u, v in G.edges():
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        # The trailing None breaks the polyline so each edge is its own segment.
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=0.5, color="#888"),
        hoverinfo="none",
        mode="lines",
    )

    nodes = list(G.nodes())
    node_trace = go.Scatter(
        x=[pos[node][0] for node in nodes],
        y=[pos[node][1] for node in nodes],
        mode="markers+text",
        text=nodes,
        textposition="bottom center",
        marker=dict(
            showscale=True,
            colorscale="YlGnBu",
            size=10,
            # Colour by number of connections; an empty list left the colour
            # scale with nothing to map.
            color=[G.degree(node) for node in nodes],
            colorbar=dict(title=dict(text="Connections")),
            line_width=2,
        ),
    )
    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            showlegend=False,
            hovermode="closest",
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        ),
    )
    st.plotly_chart(fig, use_container_width=True)


with tab2:
    st.header("Investment Analytics Dashboard")
    _render_investment_distribution()
    _render_shark_activity()
    _render_network()