cd@bziiit.com commited on
Commit
1d77a32
·
1 Parent(s): 08069bc

feat : Store in local db

Browse files
Files changed (4) hide show
  1. db/__init__.py +0 -0
  2. db/db.py +121 -0
  3. pages/chapter_params.py +16 -9
  4. pages/chatbot.py +13 -2
db/__init__.py ADDED
File without changes
db/db.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from contextlib import contextmanager
3
+
4
+ DATABASE_PATH = "database.db" # Chemin de la base de données
5
+
6
+ # Gestionnaire de contexte pour gérer les connexions
7
+ @contextmanager
8
+ def connect_db():
9
+ conn = sqlite3.connect(DATABASE_PATH)
10
+ try:
11
+ yield conn
12
+ finally:
13
+ conn.commit()
14
+ conn.close()
15
+
16
+ # Classe responsable de la base de données
17
+ class DatabaseHandler:
18
+ def __init__(self, db_path=DATABASE_PATH):
19
+ self.db_path = db_path
20
+ self._initialize_db()
21
+
22
+ def _initialize_db(self):
23
+ with connect_db() as conn:
24
+ cursor = conn.cursor()
25
+ cursor.execute('''CREATE TABLE IF NOT EXISTS prompts (
26
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
27
+ type TEXT NOT NULL CHECK(type IN ('installation', 'difficult')),
28
+ num INTEGER NOT NULL,
29
+ title TEXT NOT NULL,
30
+ prompt TEXT NOT NULL
31
+ );''')
32
+ if not self._is_data_present(cursor):
33
+ self._load_default_data(cursor)
34
+
35
+ def _is_data_present(self, cursor):
36
+ cursor.execute("SELECT COUNT(*) FROM prompts;")
37
+ return cursor.fetchone()[0] > 0
38
+
39
+ def _load_default_data(self, cursor):
40
+ data = [
41
+ {"type": "installation", "num": 1, "title": "Contexte et objectifs", "prompt": "context_objectives"},
42
+ {"type": "installation", "num": 2, "title": "Statut social, juridique et fiscal", "prompt": "social_legal_fiscal_status"},
43
+ {"type": "installation", "num": 3, "title": "Moyens humains", "prompt": "human_resources"},
44
+ {"type": "installation", "num": 4, "title": "Moyens de production", "prompt": "production_resources"},
45
+ {"type": "installation", "num": 5, "title": "Production par atelier", "prompt": "workshop_production"},
46
+ {"type": "difficult", "num": 11, "title": "Contexte et objectifs", "prompt": "difficult context_objectives"},
47
+ {"type": "difficult", "num": 12, "title": "Statut social, juridique et fiscal", "prompt": "difficult social_legal_fiscal_status"},
48
+ {"type": "difficult", "num": 13, "title": "Moyens humains", "prompt": "difficult human_resources"},
49
+ {"type": "difficult", "num": 14, "title": "Moyens de production", "prompt": "difficult production_resources"},
50
+ {"type": "difficult", "num": 15, "title": "Production par atelier", "prompt": "difficult workshop_production"}
51
+ ]
52
+ for record in data:
53
+ cursor.execute(
54
+ "INSERT INTO prompts (type, num, title, prompt) VALUES (:type, :num, :title, :prompt);",
55
+ record
56
+ )
57
+ cursor.connection.commit()
58
+
59
+ def _rows_to_dicts(self, rows):
60
+ result = []
61
+ for r in rows:
62
+ result.append({
63
+ "id": r[0],
64
+ "type": r[1],
65
+ "num": r[2],
66
+ "title": r[3],
67
+ "prompt_system": r[4],
68
+ })
69
+ return result
70
+
71
+ def get_prompts(self):
72
+ with connect_db() as conn:
73
+ cursor = conn.cursor()
74
+ cursor.execute("SELECT * FROM prompts;")
75
+ rows = cursor.fetchall()
76
+ return self._rows_to_dicts(rows)
77
+
78
+ def get_prompt_by_filters(self, type_=None, num=None):
79
+ with connect_db() as conn:
80
+ cursor = conn.cursor()
81
+ query = "SELECT * FROM prompts"
82
+ conditions = []
83
+ params = []
84
+
85
+ if type_:
86
+ conditions.append("type = ?")
87
+ params.append(type_)
88
+ if num:
89
+ conditions.append("num = ?")
90
+ params.append(num)
91
+
92
+ if conditions:
93
+ query += " WHERE " + " AND ".join(conditions)
94
+
95
+ cursor.execute(query, params)
96
+ rows = cursor.fetchall()
97
+ return self._rows_to_dicts(rows)
98
+
99
+ def add_prompt(self, type_, num, title, prompt):
100
+ with connect_db() as conn:
101
+ cursor = conn.cursor()
102
+ cursor.execute("INSERT INTO prompts (type, num, title, prompt) VALUES (?, ?, ?);", (type_, num, title, prompt))
103
+
104
+ def update_prompt(self, prompt_id, type_=None, num=None, title=None, prompt=None):
105
+ with connect_db() as conn:
106
+ cursor = conn.cursor()
107
+ query = "UPDATE prompts SET "
108
+ fields = []
109
+ params = []
110
+ if type_:
111
+ fields.append("type = ?")
112
+ params.append(type_)
113
+ if title:
114
+ fields.append("title = ?")
115
+ params.append(title)
116
+ if prompt:
117
+ fields.append("prompt = ?")
118
+ params.append(prompt)
119
+ query += ", ".join(fields) + " WHERE num = ?;"
120
+ params.append(prompt_id)
121
+ cursor.execute(query, params)
pages/chapter_params.py CHANGED
@@ -1,22 +1,29 @@
1
  import streamlit as st
 
 
 
 
2
 
3
  def area_change(key, ):
4
  new_value = st.session_state[key]
5
-
6
- key = key[5:]
7
- chapter_key = f"chapter_{key}"
8
 
9
- if chapter_key in st.session_state:
10
- st.session_state[chapter_key]["prompt_system"] = new_value
11
 
12
 
13
  def page():
14
  st.subheader("Définissez vos paramètres")
15
 
16
- chapters = []
17
- for chapter in st.session_state["chapters"]:
18
- chapters.append(st.session_state[f"chapter_{chapter['num']}"])
19
-
 
 
 
 
 
 
20
  parts_sorted = sorted(chapters, key=lambda part: part.get('num', float('inf')))
21
 
22
  # Création de tabs pour chaque 'part' trié
 
1
  import streamlit as st
2
+ from db.db import DatabaseHandler
3
+
4
+ # Instanciation de la base de données
5
+ db = DatabaseHandler()
6
 
7
  def area_change(key, ):
8
  new_value = st.session_state[key]
9
+ key = key[5:] # remove 'area_' prefix
 
 
10
 
11
+ db.update_prompt(key, prompt=new_value)
 
12
 
13
 
14
  def page():
15
  st.subheader("Définissez vos paramètres")
16
 
17
+ type = "Installation"
18
+ # type = st.selectbox("Type de diagnostique ?", ["Installation", "Difficulté"], key="mode")
19
+ if(type == "Installation"):
20
+ diag_type = "installation"
21
+ else:
22
+ diag_type = "difficult"
23
+
24
+ # Chargement de l'enregistrement correspondant aux filtres
25
+ chapters = db.get_prompt_by_filters(type_=diag_type)
26
+
27
  parts_sorted = sorted(chapters, key=lambda part: part.get('num', float('inf')))
28
 
29
  # Création de tabs pour chaque 'part' trié
pages/chatbot.py CHANGED
@@ -1,11 +1,14 @@
1
  import streamlit as st
2
  import uuid
3
- import os
4
 
5
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
6
  from model import selector
7
  from st_copy_to_clipboard import st_copy_to_clipboard
8
 
 
 
 
9
  chapter_num = 0
10
  chapter_session_key = f"chapter_{chapter_num}"
11
 
@@ -85,6 +88,14 @@ def page():
85
  return
86
 
87
 
 
 
 
 
 
 
 
 
88
  #################
89
  # Some controls #
90
  #################
@@ -96,7 +107,7 @@ def page():
96
  st.session_state[chapter_session_key]["messages"] = [ ]
97
 
98
  if len(st.session_state[chapter_session_key]["messages"]) < 2 :
99
- st.session_state[chapter_session_key]["messages"] = [ SystemMessage(content=chapter["prompt_system"]) ]
100
 
101
 
102
  ############
 
1
  import streamlit as st
2
  import uuid
3
+ from db.db import DatabaseHandler
4
 
5
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
6
  from model import selector
7
  from st_copy_to_clipboard import st_copy_to_clipboard
8
 
9
+ # Instanciation de la base de données
10
+ db = DatabaseHandler()
11
+
12
  chapter_num = 0
13
  chapter_session_key = f"chapter_{chapter_num}"
14
 
 
88
  return
89
 
90
 
91
+ # Chargement de l'enregistrement correspondant aux filtres
92
+ chapterDB = db.get_prompt_by_filters(num=chapter_num)
93
+ if len(chapterDB) == 0:
94
+ st.text("Chapitre non trouvé")
95
+ return
96
+
97
+ chapterDB = chapterDB[0]
98
+
99
  #################
100
  # Some controls #
101
  #################
 
107
  st.session_state[chapter_session_key]["messages"] = [ ]
108
 
109
  if len(st.session_state[chapter_session_key]["messages"]) < 2 :
110
+ st.session_state[chapter_session_key]["messages"] = [ SystemMessage(content=chapterDB['prompt_system']) ]
111
 
112
 
113
  ############