VaneshDev commited on
Commit
fd393fb
·
verified ·
1 Parent(s): 7303d3d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -21
app.py CHANGED
@@ -1,12 +1,11 @@
1
  import gradio as gr
2
- from PIL import Image
3
  import torch
4
  from torchvision import models, transforms
5
- import fitz # PyMuPDF for better PDF parsing
6
- import logging
7
  import os
 
8
 
9
- # Set up logging
 
10
  logging.basicConfig(level=logging.DEBUG)
11
  logger = logging.getLogger(__name__)
12
 
@@ -20,23 +19,23 @@ conditions = [
20
  "Appendicitis", "Gallstones", "Kidney Stones", "Infections", "Abdominal Aortic Aneurysm", "Diverticulitis"
21
  ]
22
 
23
- # Load and configure the model
24
- model = models.densenet121(weights="IMAGENET1K_V1") # DenseNet pre-trained on ImageNet
25
- num_features = model.classifier.in_features
26
- model.classifier = torch.nn.Linear(num_features, len(conditions)) # Output for all 24 conditions
27
- model.eval()
28
-
29
- # Define device
30
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
- model = model.to(device)
 
 
 
 
 
32
 
33
- # Load model state if available, otherwise initialize
34
- model_path = "xray_model.pth"
35
- if os.path.exists(model_path):
36
- model.load_state_dict(torch.load(model_path))
37
- logger.info(f"Loaded model from {model_path}")
38
- else:
39
- logger.info("No pre-trained model found. Initializing with random weights. Training required.")
40
 
41
  # Define image preprocessing function
42
  def preprocess_image(image):
@@ -161,6 +160,6 @@ def create_interface():
161
  xray_tab = gr.Interface(fn=predict_xray, inputs=gr.Image(label="Upload X-ray", type="pil"), outputs=[gr.HTML(), gr.HTML(), gr.HTML()])
162
  report_tab = gr.Interface(fn=analyze_report, inputs=gr.File(label="Upload Patient Report (PDF)", file_count="single"), outputs=gr.Textbox(label="Report Summary", interactive=False))
163
 
164
- gr.TabbedInterface([xray_tab, report_tab], tab_names=["X-ray Analysis", "Report Analysis"]).launch(share=True)
165
 
166
  demo = create_interface()
 
1
  import gradio as gr
 
2
  import torch
3
  from torchvision import models, transforms
 
 
4
  import os
5
+ import time
6
 
7
+ # Set up logging (optional for debugging)
8
+ import logging
9
  logging.basicConfig(level=logging.DEBUG)
10
  logger = logging.getLogger(__name__)
11
 
 
19
  "Appendicitis", "Gallstones", "Kidney Stones", "Infections", "Abdominal Aortic Aneurysm", "Diverticulitis"
20
  ]
21
 
22
+ # Function to load the model efficiently
23
+ def load_model():
24
+ model_path = "/mnt/data/densenet121-a639ec97.pth" # Set the model path
25
+ if os.path.exists(model_path):
26
+ model = models.densenet121()
27
+ model.load_state_dict(torch.load(model_path)) # Load from cached path
28
+ model.eval() # Set to evaluation mode
29
+ logger.info("Loaded model from cache.")
30
+ else:
31
+ model = models.densenet121(weights="IMAGENET1K_V1") # If not cached, download model
32
+ torch.save(model.state_dict(), model_path) # Cache the model locally
33
+ model.eval()
34
+ logger.info("Downloaded and cached the model.")
35
+ return model
36
 
37
+ # Load the model at the beginning (this will take time but only happens once)
38
+ model = load_model()
 
 
 
 
 
39
 
40
  # Define image preprocessing function
41
  def preprocess_image(image):
 
160
  xray_tab = gr.Interface(fn=predict_xray, inputs=gr.Image(label="Upload X-ray", type="pil"), outputs=[gr.HTML(), gr.HTML(), gr.HTML()])
161
  report_tab = gr.Interface(fn=analyze_report, inputs=gr.File(label="Upload Patient Report (PDF)", file_count="single"), outputs=gr.Textbox(label="Report Summary", interactive=False))
162
 
163
+ gr.TabbedInterface([xray_tab, report_tab], tab_names=["X-ray Analysis", "Report Analysis"]).launch(share=False)
164
 
165
  demo = create_interface()