sehatguard commited on
Commit
8faf95d
·
verified ·
1 Parent(s): bc8c1a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -19
app.py CHANGED
@@ -1,39 +1,43 @@
1
  import gradio as gr
2
  from transformers import pipeline
3
 
4
- # Load the ClinicalBERT model
5
- model = pipeline("text-classification", model="emilyalsentzer/Bio_ClinicalBERT")
6
 
7
  # Function to diagnose based on symptoms
8
  def diagnose(symptoms):
9
- prediction = model(symptoms)
10
- LABEL_TO_DIAGNOSIS = {
11
- "LABEL_0": "Common Cold",
12
- "LABEL_1": "Flu",
13
- "LABEL_2": "Migraine",
14
- "LABEL_3": "Pneumonia",
15
- }
16
- label = prediction[0]['label']
17
- diagnosis = LABEL_TO_DIAGNOSIS.get(label, "Diagnosis not found")
18
- return f"Possible diagnosis: {diagnosis}"
19
 
20
- # Triage function based on symptoms
21
  def triage(symptoms):
22
- if "chest pain" in symptoms or "shortness of breath" in symptoms:
23
  return "Urgent: Seek immediate medical attention."
24
  elif "fever" in symptoms and "cough" in symptoms:
25
  return "Mild: Likely a viral infection, monitor symptoms."
26
  else:
27
  return "Non-urgent: No immediate concern."
28
 
29
- # Combine Diagnosis and Triage
30
  def full_check(symptoms):
31
- diagnosis = diagnose(symptoms)
32
- severity = triage(symptoms)
33
  return diagnosis, severity
34
 
35
- # Create the Gradio interface
36
- iface = gr.Interface(fn=full_check, inputs="text", outputs=["text", "text"], title="Sehat Guard - Symptom Checker")
 
 
 
 
 
 
37
 
38
  # Launch the app
39
  iface.launch()
 
1
  import gradio as gr
2
  from transformers import pipeline
3
 
4
+ # Load a general-purpose zero-shot classification model (can be BioBERT or another fine-tuned model)
5
+ model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
6
 
7
  # Function to diagnose based on symptoms
8
  def diagnose(symptoms):
9
+ # Using zero-shot classification, no predefined labels are needed; model dynamically predicts the best category
10
+ result = model(symptoms, candidate_labels=["disease", "illness", "symptom", "medical condition"])
11
+
12
+ # Dynamic prediction
13
+ diagnosis = result['labels'][0] # Top predicted label
14
+ confidence = result['scores'][0] # Confidence score
15
+
16
+ return f"Predicted diagnosis: {diagnosis} with confidence: {confidence:.2f}"
 
 
17
 
18
+ # Triage function to assess symptom severity
19
  def triage(symptoms):
20
+ if "shortness of breath" in symptoms or "chest pain" in symptoms:
21
  return "Urgent: Seek immediate medical attention."
22
  elif "fever" in symptoms and "cough" in symptoms:
23
  return "Mild: Likely a viral infection, monitor symptoms."
24
  else:
25
  return "Non-urgent: No immediate concern."
26
 
27
+ # Combine diagnosis and triage into one function
28
  def full_check(symptoms):
29
+ diagnosis = diagnose(symptoms) # Get diagnosis
30
+ severity = triage(symptoms) # Get severity
31
  return diagnosis, severity
32
 
33
+ # Create Gradio interface
34
+ iface = gr.Interface(
35
+ fn=full_check,
36
+ inputs="text",
37
+ outputs=["text", "text"],
38
+ title="Sehat Guard - Symptom Checker",
39
+ description="Enter your symptoms to get a possible diagnosis and severity of the condition."
40
+ )
41
 
42
  # Launch the app
43
  iface.launch()