sehatguard commited on
Commit
08894f0
·
verified ·
1 Parent(s): 6ad770b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -8
app.py CHANGED
@@ -1,16 +1,24 @@
1
  import gradio as gr
2
  from transformers import pipeline
3
 
4
- # Load a pre-trained question-answering model
5
- model = pipeline("question-answering", model="deepset/roberta-base-squad2")
6
 
7
- # Function to answer based on symptoms
 
 
 
 
 
 
 
8
  def diagnose(symptoms):
9
- question = f"Given these symptoms: {symptoms}, what is the possible diagnosis?"
10
- answer = model(question=question, context="The model will use this context to infer the diagnosis.") # Placeholder context
11
- return answer['answer']
 
12
 
13
- # Triage function (same as before)
14
  def triage(symptoms):
15
  if "shortness of breath" in symptoms or "chest pain" in symptoms:
16
  return "Urgent: Seek immediate medical attention."
@@ -21,7 +29,7 @@ def triage(symptoms):
21
 
22
  # Combine diagnosis and triage into one function
23
  def full_check(symptoms):
24
- diagnosis = diagnose(symptoms) # Get diagnosis from QA model
25
  severity = triage(symptoms) # Get severity level
26
  return diagnosis, severity
27
 
 
1
  import gradio as gr
2
  from transformers import pipeline
3
 
4
+ # Load a pre-trained zero-shot classification model (e.g., BART)
5
+ model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
6
 
7
+ # Broad dynamic categories for diagnosis (you can modify these if needed)
8
+ candidate_labels = [
9
+ "respiratory infection", "viral infection", "bacterial infection", "autoimmune disease",
10
+ "cardiovascular disease", "endocrine disorders", "gastrointestinal disorders",
11
+ "neurological disorders", "skin disorders", "genetic disorders", "cancer", "kidney disease"
12
+ ]
13
+
14
+ # Function to diagnose based on symptoms
15
  def diagnose(symptoms):
16
+ result = model(symptoms, candidate_labels=candidate_labels)
17
+ diagnosis_category = result['labels'][0] # Top predicted category
18
+ confidence = result['scores'][0] # Confidence score
19
+ return f"Predicted diagnosis category: {diagnosis_category} with confidence: {confidence:.2f}"
20
 
21
+ # Triage function to assess symptom severity
22
  def triage(symptoms):
23
  if "shortness of breath" in symptoms or "chest pain" in symptoms:
24
  return "Urgent: Seek immediate medical attention."
 
29
 
30
  # Combine diagnosis and triage into one function
31
  def full_check(symptoms):
32
+ diagnosis = diagnose(symptoms) # Get diagnosis category
33
  severity = triage(symptoms) # Get severity level
34
  return diagnosis, severity
35