Spaces:
Runtime error
Runtime error
Update chatbot.py
Browse files- chatbot.py +29 -41
chatbot.py
CHANGED
|
@@ -167,12 +167,8 @@ readable_patient_data = transform_patient_data(patient_data)
|
|
| 167 |
# Function to extract details from the input prompt
|
| 168 |
def extract_details_from_prompt(prompt):
|
| 169 |
pattern = re.compile(r"(Glaucoma|Cataract) (\d+)", re.IGNORECASE)
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
condition = match.group(1).capitalize()
|
| 173 |
-
patient_id = int(match.group(2))
|
| 174 |
-
return condition, patient_id
|
| 175 |
-
return None, None
|
| 176 |
|
| 177 |
# Function to fetch specific patient data based on the condition and ID
|
| 178 |
def get_specific_patient_data(patient_data, condition, patient_id):
|
|
@@ -191,18 +187,25 @@ def get_specific_patient_data(patient_data, condition, patient_id):
|
|
| 191 |
break
|
| 192 |
return specific_data
|
| 193 |
|
| 194 |
-
#
|
| 195 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
if input_type == "Voice":
|
| 197 |
return gr.update(visible=True), gr.update(visible=False)
|
| 198 |
else:
|
| 199 |
return gr.update(visible=False), gr.update(visible=True)
|
| 200 |
|
| 201 |
-
#
|
| 202 |
def cleanup_response(response):
|
| 203 |
# Extract only the part after "Answer:" and remove any trailing spaces
|
| 204 |
answer_start = response.find("Answer:")
|
| 205 |
-
if
|
| 206 |
response = response[answer_start + len("Answer:"):].strip()
|
| 207 |
return response
|
| 208 |
|
|
@@ -213,38 +216,23 @@ def chatbot(audio, input_type, text):
|
|
| 213 |
if "error" in transcription:
|
| 214 |
return "Error transcribing audio: " + transcription["error"], None
|
| 215 |
query = transcription['text']
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
return clean_response, None
|
| 231 |
-
|
| 232 |
-
elif input_type == "Text":
|
| 233 |
-
condition, patient_id = extract_details_from_prompt(text)
|
| 234 |
-
patient_history = ""
|
| 235 |
-
if condition and patient_id:
|
| 236 |
-
patient_history = get_specific_patient_data(patient_data, condition, patient_id)
|
| 237 |
-
payload = {
|
| 238 |
-
"inputs": f"role: ophthalmologist assistant patient history: {patient_history} question: {text}"
|
| 239 |
-
}
|
| 240 |
-
response = query_huggingface(payload)
|
| 241 |
-
if isinstance(response, list):
|
| 242 |
-
raw_response = response[0].get("generated_text", "Sorry, I couldn't generate a response.")
|
| 243 |
-
else:
|
| 244 |
-
raw_response = response.get("generated_text", "Sorry, I couldn't generate a response.")
|
| 245 |
|
| 246 |
-
|
| 247 |
-
|
| 248 |
|
| 249 |
# Gradio interface for generating voice response
|
| 250 |
def generate_voice_response(tts_model, text_response):
|
|
|
|
| 167 |
# Function to extract details from the input prompt
|
| 168 |
def extract_details_from_prompt(prompt):
|
| 169 |
pattern = re.compile(r"(Glaucoma|Cataract) (\d+)", re.IGNORECASE)
|
| 170 |
+
matches = pattern.findall(prompt)
|
| 171 |
+
return [(match[0].capitalize(), int(match[1])) for match in matches]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
# Function to fetch specific patient data based on the condition and ID
|
| 174 |
def get_specific_patient_data(patient_data, condition, patient_id):
|
|
|
|
| 187 |
break
|
| 188 |
return specific_data
|
| 189 |
|
| 190 |
+
# Function to aggregate patient history for all mentioned IDs in the question
|
| 191 |
+
def get_aggregated_patient_history(patient_data, details):
|
| 192 |
+
history = ""
|
| 193 |
+
for condition, patient_id in details:
|
| 194 |
+
history += get_specific_patient_data(patient_data, condition, patient_id) + "\n"
|
| 195 |
+
return history.strip()
|
| 196 |
+
|
| 197 |
+
# Toggle visibility of input elements based on input type
|
| 198 |
+
def toggle_visibility(input_type):
|
| 199 |
if input_type == "Voice":
|
| 200 |
return gr.update(visible=True), gr.update(visible=False)
|
| 201 |
else:
|
| 202 |
return gr.update(visible=False), gr.update(visible=True)
|
| 203 |
|
| 204 |
+
# Cleanup response text
|
| 205 |
def cleanup_response(response):
|
| 206 |
# Extract only the part after "Answer:" and remove any trailing spaces
|
| 207 |
answer_start = response.find("Answer:")
|
| 208 |
+
if answer_start != -1:
|
| 209 |
response = response[answer_start + len("Answer:"):].strip()
|
| 210 |
return response
|
| 211 |
|
|
|
|
| 216 |
if "error" in transcription:
|
| 217 |
return "Error transcribing audio: " + transcription["error"], None
|
| 218 |
query = transcription['text']
|
| 219 |
+
else:
|
| 220 |
+
query = text
|
| 221 |
+
|
| 222 |
+
details = extract_details_from_prompt(query)
|
| 223 |
+
patient_history = get_aggregated_patient_history(patient_data, details)
|
| 224 |
+
|
| 225 |
+
payload = {
|
| 226 |
+
"inputs": f"role: ophthalmologist assistant patient history: {patient_history} question: {query}"
|
| 227 |
+
}
|
| 228 |
+
response = query_huggingface(payload)
|
| 229 |
+
if isinstance(response, list):
|
| 230 |
+
raw_response = response[0].get("generated_text", "Sorry, I couldn't generate a response.")
|
| 231 |
+
else:
|
| 232 |
+
raw_response = response.get("generated_text", "Sorry, I couldn't generate a response.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
|
| 234 |
+
clean_response = cleanup_response(raw_response)
|
| 235 |
+
return clean_response, None
|
| 236 |
|
| 237 |
# Gradio interface for generating voice response
|
| 238 |
def generate_voice_response(tts_model, text_response):
|