Spaces:
Sleeping
Sleeping
cd@bziiit.com commited on
Commit ·
1d77a32
1
Parent(s): 08069bc
feat : Store in local db
Browse files- db/__init__.py +0 -0
- db/db.py +121 -0
- pages/chapter_params.py +16 -9
- pages/chatbot.py +13 -2
db/__init__.py
ADDED
|
File without changes
|
db/db.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sqlite3
|
| 2 |
+
from contextlib import contextmanager
|
| 3 |
+
|
| 4 |
+
DATABASE_PATH = "database.db" # Chemin de la base de données
|
| 5 |
+
|
| 6 |
+
# Gestionnaire de contexte pour gérer les connexions
|
| 7 |
+
@contextmanager
|
| 8 |
+
def connect_db():
|
| 9 |
+
conn = sqlite3.connect(DATABASE_PATH)
|
| 10 |
+
try:
|
| 11 |
+
yield conn
|
| 12 |
+
finally:
|
| 13 |
+
conn.commit()
|
| 14 |
+
conn.close()
|
| 15 |
+
|
| 16 |
+
# Classe responsable de la base de données
|
| 17 |
+
class DatabaseHandler:
|
| 18 |
+
def __init__(self, db_path=DATABASE_PATH):
|
| 19 |
+
self.db_path = db_path
|
| 20 |
+
self._initialize_db()
|
| 21 |
+
|
| 22 |
+
def _initialize_db(self):
|
| 23 |
+
with connect_db() as conn:
|
| 24 |
+
cursor = conn.cursor()
|
| 25 |
+
cursor.execute('''CREATE TABLE IF NOT EXISTS prompts (
|
| 26 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 27 |
+
type TEXT NOT NULL CHECK(type IN ('installation', 'difficult')),
|
| 28 |
+
num INTEGER NOT NULL,
|
| 29 |
+
title TEXT NOT NULL,
|
| 30 |
+
prompt TEXT NOT NULL
|
| 31 |
+
);''')
|
| 32 |
+
if not self._is_data_present(cursor):
|
| 33 |
+
self._load_default_data(cursor)
|
| 34 |
+
|
| 35 |
+
def _is_data_present(self, cursor):
|
| 36 |
+
cursor.execute("SELECT COUNT(*) FROM prompts;")
|
| 37 |
+
return cursor.fetchone()[0] > 0
|
| 38 |
+
|
| 39 |
+
def _load_default_data(self, cursor):
|
| 40 |
+
data = [
|
| 41 |
+
{"type": "installation", "num": 1, "title": "Contexte et objectifs", "prompt": "context_objectives"},
|
| 42 |
+
{"type": "installation", "num": 2, "title": "Statut social, juridique et fiscal", "prompt": "social_legal_fiscal_status"},
|
| 43 |
+
{"type": "installation", "num": 3, "title": "Moyens humains", "prompt": "human_resources"},
|
| 44 |
+
{"type": "installation", "num": 4, "title": "Moyens de production", "prompt": "production_resources"},
|
| 45 |
+
{"type": "installation", "num": 5, "title": "Production par atelier", "prompt": "workshop_production"},
|
| 46 |
+
{"type": "difficult", "num": 11, "title": "Contexte et objectifs", "prompt": "difficult context_objectives"},
|
| 47 |
+
{"type": "difficult", "num": 12, "title": "Statut social, juridique et fiscal", "prompt": "difficult social_legal_fiscal_status"},
|
| 48 |
+
{"type": "difficult", "num": 13, "title": "Moyens humains", "prompt": "difficult human_resources"},
|
| 49 |
+
{"type": "difficult", "num": 14, "title": "Moyens de production", "prompt": "difficult production_resources"},
|
| 50 |
+
{"type": "difficult", "num": 15, "title": "Production par atelier", "prompt": "difficult workshop_production"}
|
| 51 |
+
]
|
| 52 |
+
for record in data:
|
| 53 |
+
cursor.execute(
|
| 54 |
+
"INSERT INTO prompts (type, num, title, prompt) VALUES (:type, :num, :title, :prompt);",
|
| 55 |
+
record
|
| 56 |
+
)
|
| 57 |
+
cursor.connection.commit()
|
| 58 |
+
|
| 59 |
+
def _rows_to_dicts(self, rows):
|
| 60 |
+
result = []
|
| 61 |
+
for r in rows:
|
| 62 |
+
result.append({
|
| 63 |
+
"id": r[0],
|
| 64 |
+
"type": r[1],
|
| 65 |
+
"num": r[2],
|
| 66 |
+
"title": r[3],
|
| 67 |
+
"prompt_system": r[4],
|
| 68 |
+
})
|
| 69 |
+
return result
|
| 70 |
+
|
| 71 |
+
def get_prompts(self):
|
| 72 |
+
with connect_db() as conn:
|
| 73 |
+
cursor = conn.cursor()
|
| 74 |
+
cursor.execute("SELECT * FROM prompts;")
|
| 75 |
+
rows = cursor.fetchall()
|
| 76 |
+
return self._rows_to_dicts(rows)
|
| 77 |
+
|
| 78 |
+
def get_prompt_by_filters(self, type_=None, num=None):
|
| 79 |
+
with connect_db() as conn:
|
| 80 |
+
cursor = conn.cursor()
|
| 81 |
+
query = "SELECT * FROM prompts"
|
| 82 |
+
conditions = []
|
| 83 |
+
params = []
|
| 84 |
+
|
| 85 |
+
if type_:
|
| 86 |
+
conditions.append("type = ?")
|
| 87 |
+
params.append(type_)
|
| 88 |
+
if num:
|
| 89 |
+
conditions.append("num = ?")
|
| 90 |
+
params.append(num)
|
| 91 |
+
|
| 92 |
+
if conditions:
|
| 93 |
+
query += " WHERE " + " AND ".join(conditions)
|
| 94 |
+
|
| 95 |
+
cursor.execute(query, params)
|
| 96 |
+
rows = cursor.fetchall()
|
| 97 |
+
return self._rows_to_dicts(rows)
|
| 98 |
+
|
| 99 |
+
def add_prompt(self, type_, num, title, prompt):
|
| 100 |
+
with connect_db() as conn:
|
| 101 |
+
cursor = conn.cursor()
|
| 102 |
+
cursor.execute("INSERT INTO prompts (type, num, title, prompt) VALUES (?, ?, ?);", (type_, num, title, prompt))
|
| 103 |
+
|
| 104 |
+
def update_prompt(self, prompt_id, type_=None, num=None, title=None, prompt=None):
|
| 105 |
+
with connect_db() as conn:
|
| 106 |
+
cursor = conn.cursor()
|
| 107 |
+
query = "UPDATE prompts SET "
|
| 108 |
+
fields = []
|
| 109 |
+
params = []
|
| 110 |
+
if type_:
|
| 111 |
+
fields.append("type = ?")
|
| 112 |
+
params.append(type_)
|
| 113 |
+
if title:
|
| 114 |
+
fields.append("title = ?")
|
| 115 |
+
params.append(title)
|
| 116 |
+
if prompt:
|
| 117 |
+
fields.append("prompt = ?")
|
| 118 |
+
params.append(prompt)
|
| 119 |
+
query += ", ".join(fields) + " WHERE num = ?;"
|
| 120 |
+
params.append(prompt_id)
|
| 121 |
+
cursor.execute(query, params)
|
pages/chapter_params.py
CHANGED
|
@@ -1,22 +1,29 @@
|
|
| 1 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
def area_change(key, ):
|
| 4 |
new_value = st.session_state[key]
|
| 5 |
-
|
| 6 |
-
key = key[5:]
|
| 7 |
-
chapter_key = f"chapter_{key}"
|
| 8 |
|
| 9 |
-
|
| 10 |
-
st.session_state[chapter_key]["prompt_system"] = new_value
|
| 11 |
|
| 12 |
|
| 13 |
def page():
|
| 14 |
st.subheader("Définissez vos paramètres")
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
parts_sorted = sorted(chapters, key=lambda part: part.get('num', float('inf')))
|
| 21 |
|
| 22 |
# Création de tabs pour chaque 'part' trié
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
from db.db import DatabaseHandler
|
| 3 |
+
|
| 4 |
+
# Instanciation de la base de données
|
| 5 |
+
db = DatabaseHandler()
|
| 6 |
|
| 7 |
def area_change(key, ):
|
| 8 |
new_value = st.session_state[key]
|
| 9 |
+
key = key[5:] # remove 'area_' prefix
|
|
|
|
|
|
|
| 10 |
|
| 11 |
+
db.update_prompt(key, prompt=new_value)
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
def page():
|
| 15 |
st.subheader("Définissez vos paramètres")
|
| 16 |
|
| 17 |
+
type = "Installation"
|
| 18 |
+
# type = st.selectbox("Type de diagnostique ?", ["Installation", "Difficulté"], key="mode")
|
| 19 |
+
if(type == "Installation"):
|
| 20 |
+
diag_type = "installation"
|
| 21 |
+
else:
|
| 22 |
+
diag_type = "difficult"
|
| 23 |
+
|
| 24 |
+
# Chargement de l'enregistrement correspondant aux filtres
|
| 25 |
+
chapters = db.get_prompt_by_filters(type_=diag_type)
|
| 26 |
+
|
| 27 |
parts_sorted = sorted(chapters, key=lambda part: part.get('num', float('inf')))
|
| 28 |
|
| 29 |
# Création de tabs pour chaque 'part' trié
|
pages/chatbot.py
CHANGED
|
@@ -1,11 +1,14 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
import uuid
|
| 3 |
-
import
|
| 4 |
|
| 5 |
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
| 6 |
from model import selector
|
| 7 |
from st_copy_to_clipboard import st_copy_to_clipboard
|
| 8 |
|
|
|
|
|
|
|
|
|
|
| 9 |
chapter_num = 0
|
| 10 |
chapter_session_key = f"chapter_{chapter_num}"
|
| 11 |
|
|
@@ -85,6 +88,14 @@ def page():
|
|
| 85 |
return
|
| 86 |
|
| 87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
#################
|
| 89 |
# Some controls #
|
| 90 |
#################
|
|
@@ -96,7 +107,7 @@ def page():
|
|
| 96 |
st.session_state[chapter_session_key]["messages"] = [ ]
|
| 97 |
|
| 98 |
if len(st.session_state[chapter_session_key]["messages"]) < 2 :
|
| 99 |
-
st.session_state[chapter_session_key]["messages"] = [ SystemMessage(content=
|
| 100 |
|
| 101 |
|
| 102 |
############
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import uuid
|
| 3 |
+
from db.db import DatabaseHandler
|
| 4 |
|
| 5 |
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
| 6 |
from model import selector
|
| 7 |
from st_copy_to_clipboard import st_copy_to_clipboard
|
| 8 |
|
| 9 |
+
# Instanciation de la base de données
|
| 10 |
+
db = DatabaseHandler()
|
| 11 |
+
|
| 12 |
chapter_num = 0
|
| 13 |
chapter_session_key = f"chapter_{chapter_num}"
|
| 14 |
|
|
|
|
| 88 |
return
|
| 89 |
|
| 90 |
|
| 91 |
+
# Chargement de l'enregistrement correspondant aux filtres
|
| 92 |
+
chapterDB = db.get_prompt_by_filters(num=chapter_num)
|
| 93 |
+
if len(chapterDB) == 0:
|
| 94 |
+
st.text("Chapitre non trouvé")
|
| 95 |
+
return
|
| 96 |
+
|
| 97 |
+
chapterDB = chapterDB[0]
|
| 98 |
+
|
| 99 |
#################
|
| 100 |
# Some controls #
|
| 101 |
#################
|
|
|
|
| 107 |
st.session_state[chapter_session_key]["messages"] = [ ]
|
| 108 |
|
| 109 |
if len(st.session_state[chapter_session_key]["messages"]) < 2 :
|
| 110 |
+
st.session_state[chapter_session_key]["messages"] = [ SystemMessage(content=chapterDB['prompt_system']) ]
|
| 111 |
|
| 112 |
|
| 113 |
############
|