YuraBodnar commited on
Commit
7df1fda
·
verified ·
1 Parent(s): abfad7f

Upload 14 files

Browse files

import our solution

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ weather.db filter=lfs diff=lfs merge=lfs -text
agent.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Optional
2
+ from langgraph.graph import StateGraph, END
3
+ from pydantic import BaseModel
4
+ from langchain_openai import ChatOpenAI
5
+ from dotenv import load_dotenv
6
+ from os import getenv
7
+ from utils import AgentState
8
+ from agents_utils import make_generate_sql_node, run_sql_node, make_generate_summary_node
9
+
10
# Load variables from a local .env file into the environment before the
# client below reads OPENROUTER_API_KEY.
load_dotenv()


# Chat model reached through OpenRouter's OpenAI-compatible endpoint.
llm = ChatOpenAI(
    model="openai/gpt-oss-20b:free",
    api_key=getenv("OPENROUTER_API_KEY"),
    base_url="https://openrouter.ai/api/v1",
)

# Both LLM-backed nodes are closures over the same client.
generate_sql_node = make_generate_sql_node(llm)
generate_summary_node = make_generate_summary_node(llm)

# Linear pipeline over AgentState: generate_sql -> run_sql -> summary.
graph = StateGraph(AgentState)

graph.add_node("generate_sql", generate_sql_node)
graph.add_node("run_sql", run_sql_node)
graph.add_node("summary", generate_summary_node)

graph.set_entry_point("generate_sql")
graph.add_edge("generate_sql", "run_sql")
graph.add_edge("run_sql", "summary")
graph.set_finish_point("summary")

# Compiled graph; app.py imports this and calls agent.invoke({"question": ...}).
agent = graph.compile()
agents_utils.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import AgentState, SQLGenerationResult
2
+ from prompt_generation import render_sql_generation_prompts, render_summary_generation_prompts
3
+ from typing import Dict
4
+ import sqlite3
5
+ from config import DB_PATH
6
+
7
+
8
def make_generate_sql_node(llm):
    """Build the graph node that turns the user's question into a SQL query.

    Args:
        llm: Chat model. It must provide ``with_structured_output``, which is
            called here with ``SQLGenerationResult`` as the target schema.

    Returns:
        A node function that takes the agent state and returns a partial
        state update with the keys ``sql_query`` and ``reasoning``.
    """
    # The structured-output wrapper depends only on ``llm`` and the schema,
    # so build it once per factory call instead of on every node invocation.
    structured_llm = llm.with_structured_output(SQLGenerationResult)

    def generate_sql_node(state: AgentState) -> Dict:
        system_prompt, user_prompt = render_sql_generation_prompts(state.question)

        result = structured_llm.invoke([
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ])

        return {
            "sql_query": result.sql_query,
            "reasoning": result.reasoning,
        }

    return generate_sql_node
23
+
24
def make_generate_summary_node(llm):
    """Build the graph node that writes the natural-language answer.

    Args:
        llm: Chat model; its ``invoke`` is called with a system and a user
            message, and the reply's ``content`` becomes the answer.

    Returns:
        A node function that takes the agent state and returns a partial
        state update with the single key ``answer``.
    """
    def generate_summary_node(state: AgentState) -> Dict:
        # The prompts are rendered from the question, the SQL that was run
        # and the rows it returned, all read from the state.
        prompts = render_summary_generation_prompts(
            question=state.question,
            sql_query=state.sql_query,
            rows=state.rows,
        )
        system_prompt, user_prompt = prompts

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        reply = llm.invoke(messages)

        return {"answer": reply.content}

    return generate_summary_node
42
+
43
def run_sql_node(state: AgentState) -> Dict:
    """Run the generated SQL against the SQLite database at ``DB_PATH``.

    Args:
        state: Agent state; only ``state.sql_query`` is read.

    Returns:
        A partial state update:
        * ``{"error": "SQL query was not generated"}`` when there is no query;
        * ``{"rows": [...], "error": None}`` on success, with each row as a
          column-name -> value dict and at most 50 rows;
        * ``{"error": <message>, "rows": None}`` when execution fails.
    """
    max_rows = 50  # cap on the rows handed to the later graph steps

    if not state.sql_query:
        return {"error": "SQL query was not generated"}

    try:
        conn = sqlite3.connect(DB_PATH)
        try:
            conn.row_factory = sqlite3.Row
            # The statement is written by a model from free-form user input,
            # so refuse anything that would modify the database.
            conn.execute("PRAGMA query_only = ON")

            cursor = conn.cursor()
            cursor.execute(state.sql_query)
            # Fetch only what is kept instead of materialising every row and
            # slicing afterwards.
            rows = [dict(row) for row in cursor.fetchmany(max_rows)]
        finally:
            # Close on the failure path too; previously a failing query
            # left the connection open.
            conn.close()
    except Exception as e:
        # Deliberately broad: any failure caused by the generated SQL is
        # reported through the state instead of crashing the graph.
        return {
            "error": str(e),
            "rows": None
        }

    return {
        "rows": rows,
        "error": None
    }
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import html

import streamlit as st
import pandas as pd
from agent import agent, AgentState

# =========================================
# STREAMLIT SETTINGS
# =========================================
st.set_page_config(page_title="Weather AI Assistant", page_icon="🌤️", layout="wide")

# Custom CSS
st.markdown("""
<style>
    body { background-color: #f0f4ff; }
    .chat-message {
        padding: 12px; border-radius: 12px; margin-bottom: 12px;
        max-width: 80%; line-height: 1.5;
    }
    .user-msg {
        background: #e3edff; color: #0f1e46;
        align-self: flex-end; text-align: right; margin-left: auto;
    }
    .assistant-msg {
        background: #d5e8ff; color: #0b1b33;
        align-self: flex-start; margin-right: auto;
    }
    .chat-container {
        display: flex; flex-direction: column; gap: 10px;
    }
</style>
""", unsafe_allow_html=True)

# =========================================
# SESSION STATE
# =========================================
# messages: list of (role, text) tuples; last_details: debug info for the
# most recent agent run (or None before the first run).
if "messages" not in st.session_state:
    st.session_state.messages = []
if "last_details" not in st.session_state:
    st.session_state.last_details = None

# =========================================
# MAIN TITLE
# =========================================
st.title("🌤️ Weather Data Chat Assistant")
st.write("Ask questions about weather data — I will generate SQL, run it, and answer.")

# =========================================
# CHAT MESSAGES RENDER
# =========================================
# Draw the chat history first.
st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
for role, msg in st.session_state.messages:
    # Messages are raw user input or model output and are injected into HTML
    # rendered with unsafe_allow_html=True, so escape them; otherwise any
    # markup in the text is interpreted by the browser. Line breaks are kept
    # as <br> so a multi-line answer stays inside its bubble.
    safe_msg = html.escape(str(msg)).replace("\n", "<br>")
    if role == "user":
        st.markdown(f"<div class='chat-message user-msg'><b>You:</b> {safe_msg}</div>", unsafe_allow_html=True)
    else:
        st.markdown(f"<div class='chat-message assistant-msg'><b>Assistant:</b> {safe_msg}</div>", unsafe_allow_html=True)
st.markdown("</div>", unsafe_allow_html=True)

# =========================================
# USER INPUT & LOGIC
# =========================================
user_input = st.chat_input("Type your question here...")

if user_input:
    # Store the question and rerun so it is drawn before the agent is called.
    st.session_state.messages.append(("user", user_input))
    st.rerun()

# A trailing user message means it has not been answered yet.
if st.session_state.messages and st.session_state.messages[-1][0] == "user":
    last_user_msg = st.session_state.messages[-1][1]

    with st.spinner("Thinking and querying database..."):
        try:
            raw_state = agent.invoke({"question": last_user_msg})
        except Exception as e:
            # Top-level UI boundary: show the failure as the assistant's reply.
            st.session_state.messages.append(("assistant", f"❌ Error: {e}"))
        else:
            # Fall back to the placeholder when the key is absent or None.
            answer = raw_state.get("answer") or "No answer generated."

            st.session_state.messages.append(("assistant", answer))

            st.session_state.last_details = {
                "sql": raw_state.get("sql_query"),
                "rows": raw_state.get("rows"),
                "reasoning": raw_state.get("reasoning"),
                # Error recorded by the SQL step; previously never surfaced.
                "error": raw_state.get("error"),
            }

    # Rerun in both outcomes so the new assistant message is rendered.
    st.rerun()

# =========================================
# DEBUG / DETAILS SECTION
# =========================================
if st.session_state.last_details:
    with st.expander("🔍 See Technical Details (SQL & Data)", expanded=False):
        details = st.session_state.last_details

        if details["reasoning"]:
            st.write("**Reasoning:**")
            st.info(details["reasoning"])

        if details["sql"]:
            st.write("**Generated SQL:**")
            st.code(details["sql"], language="sql")

        # .get(): details saved by an older session may lack this key.
        if details.get("error"):
            st.write("**SQL error:**")
            st.error(details["error"])

        if details["rows"]:
            st.write(f"**Data Found ({len(details['rows'])} rows):**")
            df = pd.DataFrame(details["rows"])
            st.dataframe(df)
        else:
            st.warning("No data returned from SQL.")
config.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
# API key for OpenRouter, read from the environment (None when unset).
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY")
# SQLite database file, relative to the working directory.
DB_PATH = 'weather.db'
MODEL = "openai/gpt-oss-20b:free"

# Table holding the weather records, and the timestamp range quoted in the
# SQL-generation prompt.
TABLE_NAME = "weather_daily"
START_DATE = "1980-01-01T000000Z"
END_DATE = "2019-12-31T230000Z"

# CSV columns loaded into the database: the timestamp followed by the same
# three measurements for each country code, in country order.
COLUMNS_TO_KEEP = ["utc_timestamp"] + [
    f"{country}_{measure}"
    for country in ("AT", "BE", "BG", "CH", "CZ")
    for measure in (
        "temperature",
        "radiation_direct_horizontal",
        "radiation_diffuse_horizontal",
    )
]
database.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ import pandas as pd
3
+ from config import DB_PATH, COLUMNS_TO_KEEP, TABLE_NAME
4
+
5
def init_db_from_csv(csv_path: str = "weather_data.csv"):
    """Load the selected CSV columns into the SQLite table ``TABLE_NAME``.

    Only the columns listed in ``COLUMNS_TO_KEEP`` are read. An existing
    table with the same name is replaced.

    Args:
        csv_path: Path of the CSV file to import.

    Raises:
        Whatever ``pandas.read_csv`` or ``DataFrame.to_sql`` raise, for
        example ``FileNotFoundError`` when the CSV does not exist.
    """
    # Read the CSV before touching the database: a missing or malformed file
    # then fails without opening a connection or creating the database file.
    df = pd.read_csv(csv_path, usecols=COLUMNS_TO_KEEP)

    conn = sqlite3.connect(DB_PATH)
    try:
        df.to_sql(TABLE_NAME, conn, if_exists="replace", index=False)
    finally:
        # Always release the connection, including when the write fails.
        conn.close()
12
+
13
+ # init_db_from_csv()
14
+
15
def get_connection():
    """Open and return a new SQLite connection to the database at ``DB_PATH``."""
    connection = sqlite3.connect(DB_PATH)
    return connection
prompt_generation.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from jinja2 import Environment, FileSystemLoader, select_autoescape, StrictUndefined
2
+ from config import TABLE_NAME, START_DATE, END_DATE
3
+ from typing import List, Dict
4
+
5
# Jinja2 environment for the prompt templates. Templates are looked up in
# the "templates" directory, resolved against the working directory.
env = Environment(
    undefined=StrictUndefined,
    autoescape=select_autoescape(disabled_extensions=("jinja2",)),
    loader=FileSystemLoader("templates"),
    lstrip_blocks=True,
    trim_blocks=True,
)

# Prompt pair for the SQL-generation step.
system_sql_template, user_sql_template = (
    env.get_template("system_prompts/system_prompt_sql_generation.jinja2"),
    env.get_template("user_prompts/user_prompt_sql_generation.jinja2"),
)

# Prompt pair for the summary step.
system_summary_template, user_summary_template = (
    env.get_template("system_prompts/system_prompt_summary_generation.jinja2"),
    env.get_template("user_prompts/user_prompt_summary_generation.jinja2"),
)
18
+
19
def render_sql_generation_prompts(question: str) -> tuple[str, str]:
    """Render the prompts for the SQL-generation step.

    Args:
        question: The user's question, inserted into the user prompt.

    Returns:
        ``(system_prompt, user_prompt)``. The system prompt is rendered with
        the table name and the start/end dates from the config module.
    """
    schema_context = {
        "table_name": TABLE_NAME,
        "start_date": START_DATE,
        "end_date": END_DATE,
    }
    rendered_system = system_sql_template.render(**schema_context)
    rendered_user = user_sql_template.render(question=question)
    return rendered_system, rendered_user
29
+
30
def render_summary_generation_prompts(question: str, sql_query: str, rows: List[Dict]) -> tuple[str, str]:
    """Render the prompts for the summary step.

    Args:
        question: The user's original question.
        sql_query: The SQL that was generated for it.
        rows: The rows returned by that query.

    Returns:
        ``(system_prompt, user_prompt)``. The system prompt takes no
        variables; the user prompt embeds all three arguments.
    """
    rendered_system = system_summary_template.render()
    user_context = {
        "question": question,
        "sql_query": sql_query,
        "rows": rows,
    }
    rendered_user = user_summary_template.render(**user_context)
    return rendered_system, rendered_user
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ pandas
3
+ openai
4
+ langchain
5
+ langchain-core
6
+ langchain-openai
7
+ langgraph
8
+ langgraph-checkpoint
9
+ langgraph-prebuilt
10
+ python-dotenv
11
+ tiktoken
templates/system_prompts/system_prompt_sql_generation.jinja2 ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ You are an SQL generator for an SQLite weather database.
2
+
3
+ ** Database Information **
4
+ * General Information *
5
+ Table name: {{ table_name }}
6
+
7
+ * Available columns *
8
+ utc_timestamp :
9
+ Column_type: TEXT ISO8601.
10
+ Description: Start of interval in UTC (format YYYY-MM-DDTHHMMSSZ, unique per record).
11
+
12
+ AT_temperature :
13
+ Column_type: REAL.
14
+ Description: Average temperature for Austria (AT) in °C.
15
+ AT_radiation_direct_horizontal :
16
+ Column_type: REAL.
17
+ Description: Direct horizontal solar radiation for Austria (AT) in W/m².
18
+ AT_radiation_diffuse_horizontal :
19
+ Column_type: REAL.
20
+ Description: Diffuse horizontal solar radiation for Austria (AT) in W/m².
21
+
22
+ BE_temperature :
23
+ Column_type: REAL.
24
+ Description: Average temperature for Belgium (BE) in °C.
25
+ BE_radiation_direct_horizontal :
26
+ Column_type: REAL.
27
+ Description: Direct horizontal solar radiation for Belgium (BE) in W/m².
28
+ BE_radiation_diffuse_horizontal :
29
+ Column_type: REAL.
30
+ Description: Diffuse horizontal solar radiation for Belgium (BE) in W/m².
31
+
32
+ BG_temperature :
33
+ Column_type: REAL.
34
+ Description: Average temperature for Bulgaria (BG) in °C.
35
+ BG_radiation_direct_horizontal :
36
+ Column_type: REAL.
37
+ Description: Direct horizontal solar radiation for Bulgaria (BG) in W/m².
38
+ BG_radiation_diffuse_horizontal :
39
+ Column_type: REAL.
40
+ Description: Diffuse horizontal solar radiation for Bulgaria (BG) in W/m².
41
+
42
+ CH_temperature :
43
+ Column_type: REAL.
44
+ Description: Average temperature for Switzerland (CH) in °C.
45
+ CH_radiation_direct_horizontal :
46
+ Column_type: REAL.
47
+ Description: Direct horizontal solar radiation for Switzerland (CH) in W/m².
48
+ CH_radiation_diffuse_horizontal :
49
+ Column_type: REAL.
50
+ Description: Diffuse horizontal solar radiation for Switzerland (CH) in W/m².
51
+
52
+ CZ_temperature :
53
+ Column_type: REAL.
54
+ Description: Average temperature for Czechia (CZ) in °C.
55
+ CZ_radiation_direct_horizontal :
56
+ Column_type: REAL.
57
+ Description: Direct horizontal solar radiation for Czechia (CZ) in W/m².
58
+ CZ_radiation_diffuse_horizontal :
59
+ Column_type: REAL.
60
+ Description: Diffuse horizontal solar radiation for Czechia (CZ) in W/m².
61
+
62
+ ** SQL query generation rules **
63
+ - Return only SQL.
64
+ - Ensure date filters stay within {{ start_date }} and {{ end_date }}.
65
+ - Do not invent column names that are not listed above.
templates/system_prompts/system_prompt_summary_generation.jinja2 ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ You are a data analysis assistant. Your job is to create a clear and friendly summary
2
+ for the user based on the provided data.
3
+
4
+ Guidelines:
5
+ - Avoid SQL terminology
6
+ - Do not mention rows, tables, or SQL queries
7
+ - Explain insights in simple natural language
8
+ - Be concise (3–6 sentences)
9
+ - If the dataset is small, interpret values directly
10
+ - If the dataset contains statistics like averages, describe them clearly
11
+ - If data is missing, say that no information is available
templates/user_prompts/user_prompt_sql_generation.jinja2 ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ User question:
2
+ {{ question }}
templates/user_prompts/user_prompt_summary_generation.jinja2 ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ The user asked the following question:
2
+ {{ question }}
3
+
4
+ The generated SQL query:
5
+ {{ sql_query }}
6
+
7
+ Here is the data retrieved from the database:
8
+ {{ rows }}
9
+
10
+ Write a short summary in natural language describing the result.
test.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Optional
2
+ from langgraph.graph import StateGraph, END
3
+ from pydantic import BaseModel
4
+ from langchain_openai import ChatOpenAI
5
+ from dotenv import load_dotenv
6
+ import os
7
+ from os import getenv
8
+ from prompt_generation import render_sql_generation_prompts
9
+ from utils import SQLGenerationResult
10
+
11
# Manual smoke test: send one prompt through OpenRouter and print the reply.
# Loads .env first so OPENROUTER_API_KEY is available to getenv below.
load_dotenv()


llm = ChatOpenAI(
    model="openai/gpt-oss-20b:free",
    api_key=getenv("OPENROUTER_API_KEY"),
    base_url="https://openrouter.ai/api/v1",
)

response = llm.invoke("Generate me randon SQL query")
print(response.content)
utils.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import Optional, List, Dict
3
+
4
class SQLGenerationResult(BaseModel):
    # Schema passed to llm.with_structured_output() in agents_utils.
    # Documented with comments rather than a docstring, because a model
    # docstring may be forwarded to the LLM as the schema description.

    # Required: the statement that run_sql_node will execute.
    sql_query: str = Field(default=..., description="SQL query to execute")
    # Optional: the model's explanation, shown in the UI details panel.
    reasoning: Optional[str] = Field(default=None, description="Optional explanation of the query")
7
+
8
class AgentState(BaseModel):
    # State carried through the graph. Only the question is required up
    # front; every other field starts as None and is filled in by a node.

    # Input: the user's question.
    question: str
    # Written by the generate_sql node.
    sql_query: Optional[str] = None
    reasoning: Optional[str] = None
    # Written by the run_sql node (rows on success, error on failure).
    rows: Optional[List[Dict]] = None
    # Written by the summary node.
    answer: Optional[str] = None
    error: Optional[str] = None
weather.db ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:accea05ffe97779e3b73b8b7e0067c96178c28658fc007556b8e376fe819f29c
3
+ size 45793280