# Hugging Face Spaces app (page capture; the Space reported "Build error").
import hashlib
import json
import os
import tempfile
from datetime import datetime, timezone

import gradio as gr
from openai import OpenAI

from config import get_settings
from core.pineconeqa import PineconeQA
from utils.models import DatabaseManager
class MedicalChatbot:
    """Medical Q&A assistant backed by Pinecone retrieval and OpenAI chat.

    Responsibilities:
      * classify incoming messages as small talk ('basic') or knowledge
        queries ('query'),
      * answer small talk directly with GPT-4,
      * answer knowledge queries by retrieving PubMed context from Pinecone
        and synthesizing an answer from it,
      * transcribe voice messages with Whisper,
      * persist every exchange through ``DatabaseManager``.
    """

    def __init__(self):
        self.settings = get_settings()
        # Retrieval pipeline over the PubMed Pinecone index.
        self.qa_system = PineconeQA(
            pinecone_api_key=self.settings.PINECONE_API_KEY,
            openai_api_key=self.settings.OPENAI_API_KEY,
            index_name=self.settings.INDEX_NAME
        )
        self.client = OpenAI(api_key=self.settings.OPENAI_API_KEY)
        self.db = DatabaseManager()
        # Per-conversation state; populated by handle_session()/respond().
        self.current_doctor = None
        self.current_session_id = None

    def handle_session(self, doctor_name):
        """Start a fresh database session for *doctor_name* and return its id.

        A new session is always created, even when the doctor is unchanged.
        """
        self.current_session_id = self.db.create_session(doctor_name)
        self.current_doctor = doctor_name
        return self.current_session_id

    def get_user_identifier(self, request: "gr.Request"):
        """Return a 32-char hex identifier derived from client IP + user agent.

        Falls back to the literal string "anonymous" when no request object
        is available (e.g. programmatic calls).
        """
        if request is None:
            return "anonymous"
        identifier = f"{request.client.host}_{request.headers.get('User-Agent', 'unknown')}"
        return hashlib.sha256(identifier.encode()).hexdigest()[:32]

    def detect_message_type(self, message):
        """Classify *message* as 'basic' small talk or a knowledge 'query'.

        Uses a short, low-temperature GPT-4 call. On any API failure the
        message is treated as a 'query' so it still reaches retrieval.
        """
        try:
            response = self.client.chat.completions.create(
                model="gpt-4",
                messages=[
                    {
                        "role": "system",
                        "content": """Analyze the following message and determine if it's:
1. A basic interaction like hello, thanks, how are you(greetings, thanks, farewell, etc.)
2. A question or request for information
return only 'basic' if the message is only for greeting, or return query
Respond with just the type: 'basic' or 'query'"""
                    },
                    {"role": "user", "content": message}
                ],
                temperature=0.3,
                max_tokens=10
            )
            return response.choices[0].message.content.strip().lower()
        except Exception as e:
            print(f'error encountered. returning query.\nError: {str(e)}')
            return "query"

    def get_chatgpt_response(self, message, history):
        """Answer *message* with GPT-4 alone, replaying prior chat turns.

        *history* is a list of (user, assistant) tuples as kept by the
        Gradio Chatbot component.
        """
        try:
            chat_history = []
            for human, assistant in history:
                chat_history.append({"role": "user", "content": human})
                # Skip unanswered placeholder turns: the audio flow may hold
                # a (message, None) pair while a reply is pending, and the
                # API rejects None content.
                if assistant is not None:
                    chat_history.append({"role": "assistant", "content": assistant})
            # FIX: the original prompt contained stray quote characters
            # (string-concatenation fragments pasted inside a triple-quoted
            # string) that were sent verbatim to the model.
            messages = [
                {
                    "role": "system",
                    "content": (
                        "You are an expert assistant for biomedical question-answering tasks. "
                        "You will be provided with context retrieved from medical literature. "
                        "The medical literature is all from PubMed Open Access Articles. "
                        "Use this context to answer the question as accurately as possible. "
                        "The response might not be added precisely, so try to derive the answers from it as much as possible. "
                        "If the context does not contain the required information, explain why. "
                        "Provide a concise and accurate answer."
                    )
                }
            ] + chat_history + [
                {"role": "user", "content": message}
            ]
            response = self.client.chat.completions.create(
                model="gpt-4",
                messages=messages,
                temperature=0.7,
                max_tokens=500
            )
            return response.choices[0].message.content
        except Exception as e:
            return f"I apologize, but I encountered an error: {str(e)}"

    def synthesize_answer(self, query, context_docs, history):
        """Synthesize an answer to *query* from retrieved *context_docs*.

        *history* is accepted for interface compatibility but is not
        currently folded into the prompt. Each doc must expose
        ``page_content`` (LangChain-style document).
        """
        try:
            context = "\n\n".join([doc.page_content for doc in context_docs])
            messages = [
                {
                    "role": "system",
                    "content": """You are a medical expert assistant. Using the provided context,
synthesize a comprehensive, accurate answer. If the context doesn't contain
enough relevant information, say so and provide general medical knowledge.
Always maintain a professional yet accessible tone."""
                },
                {
                    "role": "user",
                    "content": f"""Context information:\n{context}\n\n
Based on this context and your medical knowledge, please answer the following question:\n{query}"""
                }
            ]
            response = self.client.chat.completions.create(
                model="gpt-4",
                messages=messages,
                temperature=0.2,   # low temperature: stay close to the context
                max_tokens=1000
            )
            return response.choices[0].message.content
        except Exception as e:
            return f"I apologize, but I encountered an error synthesizing the answer: {str(e)}"

    def format_sources_for_db(self, sources):
        """Serialize source metadata to a JSON string for DB storage.

        Returns None when there are no sources, so the DB column stays NULL.
        """
        if not sources:
            return None
        sources_data = []
        for doc in sources:
            sources_data.append({
                'title': doc.metadata.get('title'),
                'source': doc.metadata.get('source'),
                # FIX: datetime.utcnow() is deprecated and naive; use an
                # explicit timezone-aware UTC timestamp.
                'timestamp': datetime.now(timezone.utc).isoformat()
            })
        return json.dumps(sources_data)

    def respond(self, message, history, doctor_name: str, request: "gr.Request" = None):
        """Main response function: classify, retrieve, answer, and log.

        Logs the user message and whichever reply path is taken to the
        current DB session (created lazily if missing).
        """
        try:
            # Ensure a session exists; do not reuse a stale one.
            if not hasattr(self, 'current_session_id') or not self.current_session_id:
                self.current_session_id = self.db.create_session(doctor_name)
            # Log user message
            self.db.log_message(
                session_id=self.current_session_id,
                message=message,
                is_user=True
            )
            message_type = self.detect_message_type(message)
            if message_type == "basic":
                # Small talk: no retrieval needed.
                response = self.get_chatgpt_response(message, history)
                self.db.log_message(
                    session_id=self.current_session_id,
                    message=response,
                    is_user=False
                )
                return response
            retriever_response = self.qa_system.ask(message)
            if "error" in retriever_response:
                # Retrieval failed: fall back to plain ChatGPT.
                response = self.get_chatgpt_response(message, history)
                self.db.log_message(
                    session_id=self.current_session_id,
                    message=response,
                    is_user=False
                )
                return response
            if retriever_response.get("context") and len(retriever_response["context"]) > 0:
                synthesized_answer = self.synthesize_answer(
                    message,
                    retriever_response["context"],
                    history
                )
                sources = self.format_sources(retriever_response["context"])
                final_response = synthesized_answer + sources
                self.db.log_message(
                    session_id=self.current_session_id,
                    message=final_response,
                    is_user=False,
                    sources=self.format_sources_for_db(retriever_response["context"])
                )
                return final_response
            else:
                # No context found: answer from general knowledge, flagged.
                response = self.get_chatgpt_response(message, history)
                fallback_response = "I couldn't find specific information about this in my knowledge base, but here's what I can tell you:\n\n" + response
                self.db.log_message(
                    session_id=self.current_session_id,
                    message=fallback_response,
                    is_user=False
                )
                return fallback_response
        except Exception as e:
            error_message = f"I apologize, but I encountered an error: {str(e)}"
            if self.current_session_id:
                self.db.log_message(
                    session_id=self.current_session_id,
                    message=error_message,
                    is_user=False
                )
            return error_message

    def format_sources(self, sources):
        """Format sources into a readable, deduplicated string for the chat."""
        if not sources:
            return ""
        formatted = "\n\n📚 Sources Used:\n"
        seen_sources = set()
        for doc in sources:
            # Deduplicate on (title, source) pairs.
            source_id = (doc.metadata.get('title', ''), doc.metadata.get('source', ''))
            if source_id not in seen_sources:
                seen_sources.add(source_id)
                formatted += f"\n• {doc.metadata.get('title', 'Untitled')}\n"
                if doc.metadata.get('source'):
                    formatted += f"  Link: {doc.metadata['source']}\n"
        return formatted

    def transcribe_audio(self, audio_path):
        """Transcribe the audio file at *audio_path* with OpenAI Whisper.

        Returns the transcript text, or None on any failure.
        """
        try:
            with open(audio_path, "rb") as audio_file:
                transcript = self.client.audio.transcriptions.create(
                    model="whisper-1",
                    file=audio_file
                )
            return transcript.text
        except Exception as e:
            print(f"Error transcribing audio: {str(e)}")
            return None

    def process_audio_input(self, audio_path, history, doctor_name):
        """Transcribe an audio file and answer it; return the text response.

        FIX: always returns a single string — the original returned a
        ("Sorry...", None) tuple on transcription failure but a bare string
        on every other path, so callers could not handle the result
        consistently.
        """
        try:
            transcription = self.transcribe_audio(audio_path)
            if not transcription:
                return "Sorry, I couldn't understand the audio."
            return self.respond(transcription, history, doctor_name)
        except Exception as e:
            return f"Error processing audio: {str(e)}"
def main():
    """Assemble the Gradio UI for the medical chatbot and launch the server."""
    med_chatbot = MedicalChatbot()

    with gr.Blocks(theme=gr.themes.Soft()) as interface:
        gr.Markdown("# Medical Knowledge Assistant")
        gr.Markdown("Ask me anything about medical topics using text or voice.")

        # Cross-event state: the active DB session id and the doctor it
        # belongs to.
        session_state = gr.State()
        doctor_state = gr.State()

        # Doctor Name Input
        with gr.Row():
            doctor_name = gr.Textbox(
                label="Doctor Name",
                placeholder="Enter your name",
                show_label=True,
                container=True,
                scale=2,
                interactive=True
            )

        # Main Chat Interface
        with gr.Row():
            with gr.Column(scale=4):
                chatbot = gr.Chatbot(height=400)

                # Text Input Area
                with gr.Row():
                    text_input = gr.Textbox(
                        placeholder="Type your message here...",
                        scale=8
                    )
                    send_button = gr.Button("Send", scale=1)

                # Audio Input Area
                with gr.Row():
                    audio = gr.Audio(
                        sources=["microphone"],
                        type="filepath",
                        label="Voice Message",
                        interactive=True
                    )

                # Audio Output Area (reserved; text handlers always emit None)
                audio_output = gr.Audio(
                    label="AI Voice Response",
                    visible=True,
                    interactive=False
                )

        def init_session(doctor, current_doctor):
            """Create a DB session when a new doctor name is entered.

            Returns (session_id, doctor); session_id is None when the name
            is empty or unchanged.
            """
            if not doctor or doctor == current_doctor:
                return None, current_doctor
            return med_chatbot.db.create_session(doctor), doctor

        def on_text_submit(message, history, doctor, session_id, current_doctor):
            """Handle a text message: route through the chatbot, update history."""
            if not session_id or doctor != current_doctor:
                session_id, current_doctor = init_session(doctor, current_doctor)
            med_chatbot.current_session_id = session_id
            response = med_chatbot.respond(message, history, doctor)
            history.append((message, response))
            # Clear the textbox, refresh the chat, no audio reply.
            return "", history, None, session_id, current_doctor

        def on_audio_submit(audio_path, history, doctor, session_id, current_doctor):
            """Handle a recorded voice message.

            FIX (vs. original):
              * every return now has exactly three values, matching the three
                wired outputs [chatbot, session_state, doctor_state] — the
                error/early paths used to return four values, which broke the
                event handler;
              * the transcription and the AI reply are no longer logged here,
                because ``MedicalChatbot.respond`` already logs both — the old
                code wrote every audio exchange to the DB twice;
              * the reply is computed before the turn is appended, so a
                (message, None) placeholder is never fed back into the
                ChatGPT history.
            """
            try:
                if audio_path is None:
                    return history, session_id, current_doctor
                # Initialize session if needed.
                if not session_id or doctor != current_doctor:
                    session_id, current_doctor = init_session(doctor, current_doctor)
                med_chatbot.current_session_id = session_id
                transcription = med_chatbot.transcribe_audio(audio_path)
                if not transcription:
                    return history, session_id, current_doctor
                ai_response = med_chatbot.respond(transcription, history, doctor)
                history.append((f"🎤 {transcription}", ai_response))
                return history, session_id, current_doctor
            except Exception as e:
                print(f"Error processing audio: {str(e)}")
                return history, session_id, current_doctor

        # Set up event handlers
        doctor_name.submit(
            fn=init_session,
            inputs=[doctor_name, doctor_state],
            outputs=[session_state, doctor_state]
        )
        send_button.click(
            fn=on_text_submit,
            inputs=[text_input, chatbot, doctor_name, session_state, doctor_state],
            outputs=[text_input, chatbot, audio_output, session_state, doctor_state]
        )
        text_input.submit(
            fn=on_text_submit,
            inputs=[text_input, chatbot, doctor_name, session_state, doctor_state],
            outputs=[text_input, chatbot, audio_output, session_state, doctor_state]
        )
        # Audio submission fires when the user stops recording.
        audio.stop_recording(
            fn=on_audio_submit,
            inputs=[audio, chatbot, doctor_name, session_state, doctor_state],
            outputs=[chatbot, session_state, doctor_state]
        )

        # Examples
        gr.Examples(
            examples=[
                ["Hello, how are you?", "Dr. Smith"],
                ["What are the common causes of iron deficiency anemia?", "Dr. Smith"],
                ["What are the latest treatments for type 2 diabetes?", "Dr. Smith"],
                ["Can you explain the relationship between diet and heart disease?", "Dr. Smith"]
            ],
            inputs=[text_input, doctor_name]
        )

    interface.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860
    )


if __name__ == "__main__":
    main()