twissamodi committed on
Commit
be16ac4
·
1 Parent(s): 243d1a2

handle multiple logins

Browse files
Files changed (5) hide show
  1. app.py +63 -7
  2. audio_handler.py +7 -5
  3. chat_handler.py +8 -1
  4. requirements.txt +2 -1
  5. user_data.py +88 -0
app.py CHANGED
@@ -5,9 +5,11 @@ from tools import MedicalTools
5
  from graph_setup import GraphSetup
6
  from chat_handler import ChatHandler
7
  from audio_handler import AudioHandler
8
-
 
9
 
10
  load_dotenv(override=True)
 
11
 
12
  rag = RAG_Setup()
13
  medical_tools = MedicalTools(rag)
@@ -18,20 +20,75 @@ chat_handler = ChatHandler(graph, rag)
18
  audio_handler = AudioHandler()
19
 
20
 
21
- def transcribe_audio_wrapper(audio, current_text, file_input, message_history):
 
 
 
 
 
 
 
22
  return audio_handler.transcribe_audio(
23
  audio,
24
  current_text,
25
  file_input,
26
  message_history,
 
 
27
  chat_handler.chat
28
  )
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  with gr.Blocks(title="Medical Assistant") as demo:
 
 
 
32
  gr.Markdown("# 🏥 Medical Assistant")
33
  gr.Markdown("Ask questions using text, voice, or upload medical documents")
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  chatbot = gr.Chatbot(label="Conversation", height=400)
36
 
37
  with gr.Row():
@@ -62,22 +119,21 @@ with gr.Blocks(title="Medical Assistant") as demo:
62
 
63
  submit_btn.click(
64
  chat_handler.chat,
65
- inputs=[text_input, file_input, chatbot],
66
- outputs=[chatbot, text_input, file_input]
67
  )
68
 
69
  text_input.submit(
70
  chat_handler.chat,
71
- inputs=[text_input, file_input, chatbot],
72
  outputs=[chatbot, text_input, file_input]
73
  )
74
 
75
  audio_input.change(
76
  transcribe_audio_wrapper,
77
- inputs=[audio_input, text_input, file_input, chatbot],
78
  outputs=[chatbot, text_input, audio_input, file_input]
79
  )
80
 
81
-
82
  if __name__ == "__main__":
83
  demo.launch(share=True)
 
5
  from graph_setup import GraphSetup
6
  from chat_handler import ChatHandler
7
  from audio_handler import AudioHandler
8
+ from user_data import initialize_db, add_user, create_session
9
+ import uuid
10
 
11
  load_dotenv(override=True)
12
+ initialize_db()
13
 
14
  rag = RAG_Setup()
15
  medical_tools = MedicalTools(rag)
 
20
  audio_handler = AudioHandler()
21
 
22
 
23
def transcribe_audio_wrapper(audio, current_text, file_input, message_history, user_state, session_state):
    """Bridge the Gradio audio widget to the chat pipeline, requiring a login first.

    Returns exactly four values, matching the components wired as outputs on
    ``audio_input.change``: (chatbot history, text box, audio widget, file widget).

    BUG FIX: the previous version returned six values (it appended
    ``user_state`` and ``session_state``), but the event is wired to only four
    output components, so Gradio would fail with a return-count mismatch.
    """
    # Guard: voice input requires an authenticated user and an open session.
    if not user_state or not session_state:
        warning = {
            "role": "assistant",
            "content": "Please log in and start a session before using voice input."
        }
        # Clear the audio widget (None) and leave the other widgets untouched.
        return message_history + [warning], current_text, None, file_input

    result = audio_handler.transcribe_audio(
        audio,
        current_text,
        file_input,
        message_history,
        user_state,
        session_state,
        chat_handler.chat
    )
    # Keep only (history, text, audio, file); drop any trailing state values
    # the handler may return, so the count always matches the 4 wired outputs.
    return result[:4]
40
 
41
def handle_login(user_identifier):
    """Start a session for the given identifier.

    Returns a (markdown status, user_state, session_state) triple for the
    Gradio outputs; both state payloads are None when the identifier is blank.
    """
    # Reject empty / whitespace-only identifiers without touching the DB.
    if not user_identifier or not user_identifier.strip():
        return "Please enter a user name or email.", None, None

    # Normalise the id (case-insensitive) but keep the raw text as the name.
    user_id = user_identifier.strip().lower()
    session_id = str(uuid.uuid4())

    # Persist the user (idempotent upsert) and the freshly minted session row.
    add_user(user_id, user_identifier)
    create_session(user_id, session_id)

    status = f"**Active user:** {user_id}<br>**Session:** {session_id}"
    user_payload = {"user_id": user_id, "name": user_identifier}
    session_payload = {"session_id": session_id}
    return status, user_payload, session_payload
62
+
63
+
64
def handle_logout():
    """Reset the session banner and clear both Gradio state objects."""
    cleared = ("No active session.", None, None)
    return cleared
66
 
67
  with gr.Blocks(title="Medical Assistant") as demo:
68
+ user_state = gr.State(value=None)
69
+ session_state = gr.State(value=None)
70
+
71
  gr.Markdown("# 🏥 Medical Assistant")
72
  gr.Markdown("Ask questions using text, voice, or upload medical documents")
73
 
74
+ with gr.Accordion("User Login", open=True):
75
+ user_input = gr.Textbox(label="Enter email or username", placeholder="name@example.com")
76
+ with gr.Row():
77
+ login_button = gr.Button("Start Session", variant="primary")
78
+ logout_button = gr.Button("End Session", variant="stop")
79
+ session_display = gr.Markdown("No active session.")
80
+
81
+ login_button.click(
82
+ handle_login,
83
+ inputs=[user_input],
84
+ outputs=[session_display, user_state, session_state],
85
+ )
86
+
87
+ logout_button.click(
88
+ handle_logout,
89
+ outputs=[session_display, user_state, session_state],
90
+ )
91
+
92
  chatbot = gr.Chatbot(label="Conversation", height=400)
93
 
94
  with gr.Row():
 
119
 
120
  submit_btn.click(
121
  chat_handler.chat,
122
+ inputs=[text_input, file_input, chatbot, user_state, session_state],
123
+ outputs=[chatbot, text_input, file_input],
124
  )
125
 
126
  text_input.submit(
127
  chat_handler.chat,
128
+ inputs=[text_input, file_input, chatbot, user_state, session_state],
129
  outputs=[chatbot, text_input, file_input]
130
  )
131
 
132
  audio_input.change(
133
  transcribe_audio_wrapper,
134
+ inputs=[audio_input, text_input, file_input, chatbot, user_state, session_state],
135
  outputs=[chatbot, text_input, audio_input, file_input]
136
  )
137
 
 
138
  if __name__ == "__main__":
139
  demo.launch(share=True)
audio_handler.py CHANGED
@@ -5,16 +5,18 @@ class AudioHandler:
5
  def __init__(self):
6
  self.transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-small")
7
 
8
- def transcribe_audio(self, audio, current_text, file_input, message_history, chat_func):
9
  if audio is None:
10
- return message_history, current_text, None, file_input
11
 
12
  transcript = self.transcriber(audio)["text"].strip()
13
 
14
- updated_history, cleared_text, cleared_file = chat_func(
15
  transcript,
16
  file_input,
17
- message_history
 
 
18
  )
19
 
20
- return updated_history, current_text, None, cleared_file
 
5
  def __init__(self):
6
  self.transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-small")
7
 
8
+ def transcribe_audio(self, audio, current_text, file_input, message_history, user_state, session_state, chat_func):
9
  if audio is None:
10
+ return message_history, current_text, None, file_input, user_state, session_state
11
 
12
  transcript = self.transcriber(audio)["text"].strip()
13
 
14
+ updated_history, cleared_text, cleared_file, user_state, session_state = chat_func(
15
  transcript,
16
  file_input,
17
+ message_history,
18
+ user_state,
19
+ session_state
20
  )
21
 
22
+ return updated_history, cleared_text, None, cleared_file, user_state, session_state
chat_handler.py CHANGED
@@ -11,7 +11,14 @@ class ChatHandler:
11
  self.session_id = str(uuid.uuid4())
12
  print(self.session_id)
13
 
14
- def chat(self, user_message, uploaded_file, message_history):
 
 
 
 
 
 
 
15
  user_query_parts = []
16
  try:
17
  if user_message and user_message.strip():
 
11
  self.session_id = str(uuid.uuid4())
12
  print(self.session_id)
13
 
14
+ def chat(self, user_message, uploaded_file, message_history, user_state, session_state):
15
+ if not user_state or not session_state:
16
+ warning = {
17
+ "role": "assistant",
18
+ "content": "Please log in and start a session before chatting."
19
+ }
20
+ return message_history + [warning], user_message, uploaded_file, user_state, session_state
21
+
22
  user_query_parts = []
23
  try:
24
  if user_message and user_message.strip():
requirements.txt CHANGED
@@ -6,4 +6,5 @@ langgraph-checkpoint-sqlite
6
  langchain
7
  transformers
8
  sentence-transformers
9
- torch
 
 
6
  langchain
7
  transformers
8
  sentence-transformers
9
+ torch
10
+ pypdf
user_data.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from pathlib import Path
3
+
4
# Location of the on-disk SQLite database; the folder is created on demand.
DB_PATH = Path("data/user_data.db")

def get_connection():
    """Open a connection to the user-data SQLite database.

    Ensures the containing directory exists, then returns a connection whose
    rows behave like mappings (``sqlite3.Row``).
    """
    db_file = DB_PATH
    db_file.parent.mkdir(parents=True, exist_ok=True)
    connection = sqlite3.connect(db_file)
    connection.row_factory = sqlite3.Row
    return connection
11
+
12
def initialize_db():
    """Create the SQLite schema (users, sessions, document_classifications) if absent.

    Safe to call on every startup: all statements use IF NOT EXISTS.
    """
    conn = get_connection()
    cursor = conn.cursor()
    # executescript runs all three DDL statements in one call.
    # NOTE(review): sqlite3 does not enforce FOREIGN KEY constraints unless
    # PRAGMA foreign_keys=ON is set per connection — confirm this is intended.
    cursor.executescript("""
    CREATE TABLE IF NOT EXISTS users (
        id TEXT PRIMARY KEY,
        name TEXT NOT NULL,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
    );

    CREATE TABLE IF NOT EXISTS sessions (
        id TEXT PRIMARY KEY,
        user_id TEXT NOT NULL,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        FOREIGN KEY (user_id) REFERENCES users(id)
    );

    CREATE TABLE IF NOT EXISTS document_classifications (
        file_hash TEXT PRIMARY KEY,
        doc_type TEXT NOT NULL,
        classified_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
    );
    """)
    conn.commit()
    conn.close()
37
+
38
def add_user(user_id, name):
    """Insert a user row, or refresh the stored name if the id already exists."""
    conn = get_connection()
    # Connection.execute is the sqlite3 shortcut for cursor() + execute().
    conn.execute(
        """INSERT INTO users (id, name) VALUES (?, ?)
                   ON CONFLICT(id) DO UPDATE SET name = excluded.name""",
        (user_id, name),
    )
    conn.commit()
    conn.close()
45
+
46
def create_session(user_id, session_id):
    """Record a new chat session belonging to the given user."""
    conn = get_connection()
    conn.execute(
        "INSERT INTO sessions (id, user_id) VALUES (?, ?)",
        (session_id, user_id),
    )
    conn.commit()
    conn.close()
52
+
53
def user_exists(user_id):
    """Return True if a user row with this id is present in the database."""
    conn = get_connection()
    found = conn.execute(
        "SELECT 1 FROM users WHERE id = ?", (user_id,)
    ).fetchone()
    conn.close()
    return found is not None
60
+
61
+
62
def get_document_label(file_hash: str):
    """Return the cached doc_type for *file_hash*, or None if never classified."""
    conn = get_connection()
    record = conn.execute(
        "SELECT doc_type FROM document_classifications WHERE file_hash = ?",
        (file_hash,),
    ).fetchone()
    conn.close()
    # Rows are sqlite3.Row (see get_connection), so column access by name works.
    return record["doc_type"] if record else None
72
+
73
+
74
def save_document_label(file_hash: str, doc_type: str):
    """Upsert a document classification, refreshing its timestamp on conflict.

    Args:
        file_hash: Content hash identifying the document (primary key).
        doc_type: Classification label to store for that document.
    """
    conn = get_connection()
    cursor = conn.cursor()
    # ON CONFLICT keeps one row per file_hash and bumps classified_at so the
    # most recent classification time is always recorded.
    cursor.execute(
        """
        INSERT INTO document_classifications (file_hash, doc_type)
        VALUES (?, ?)
        ON CONFLICT(file_hash) DO UPDATE SET
            doc_type = excluded.doc_type,
            classified_at = CURRENT_TIMESTAMP
        """,
        (file_hash, doc_type)
    )
    conn.commit()
    conn.close()