GSMK committed on
Commit
dc4b134
·
verified ·
1 Parent(s): 78eb65a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -40
app.py CHANGED
@@ -1,26 +1,45 @@
1
- import gradio as gr
2
  import torch
3
  from PIL import Image
4
  from transformers import CLIPModel, CLIPProcessor
5
  from transformers import BlipProcessor, BlipForConditionalGeneration
6
- from datasets import load_dataset
 
 
 
 
 
 
 
# CPU-only inference; no GPU assumed in this Space.
device = "cpu"

print("Loading models...")

# Zero-shot classifier (CLIP) and image captioner (BLIP) checkpoints.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
23
- # DATASET LABELS
24
  DATASETS = {
25
  "medical": ["pneumonia", "Normal"],
26
  "skin_cancer": ["Normal Skin", "eczema", "Melanoma", "psoriasis"],
@@ -35,11 +54,33 @@ templates = {
35
  "agriculture": "a close-up leaf showing signs of {}"
36
  }
37
 
38
- def analyze(image, dataset):
39
 
40
- labels = DATASETS[dataset]
 
 
 
 
 
 
 
 
 
 
41
 
42
- text_queries = [templates[dataset].format(l) for l in labels]
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  inputs = clip_processor(
45
  text=text_queries,
@@ -49,40 +90,38 @@ def analyze(image, dataset):
49
  )
50
 
51
  with torch.no_grad():
 
52
  outputs = clip_model(**inputs)
 
53
  probs = outputs.logits_per_image.softmax(dim=1)
54
 
55
  conf, idx = torch.max(probs, dim=1)
56
 
57
- detected_class = labels[idx.item()]
58
 
59
- # BLIP caption generation
60
- blip_inputs = blip_processor(images=image, return_tensors="pt")
61
 
62
- with torch.no_grad():
63
- ids = blip_model.generate(**blip_inputs)
64
 
65
- caption = blip_processor.decode(ids[0], skip_special_tokens=True)
 
 
 
66
 
67
- return detected_class, float(conf), caption
 
 
 
68
 
 
69
 
70
# Gradio front end: image + dataset selector in, prediction triple out.
image_input = gr.Image(type="pil", label="Upload Image")
dataset_input = gr.Dropdown(choices=list(DATASETS.keys()), label="Dataset Type")

interface = gr.Interface(
    fn=analyze,
    inputs=[image_input, dataset_input],
    outputs=[
        gr.Text(label="Predicted Class"),
        gr.Number(label="Confidence"),
        gr.Textbox(label="Description"),
    ],
    title="AI Image Diagnostic System",
    description="CLIP + BLIP based AI diagnostic model",
)

interface.launch()
 
1
+ import streamlit as st
2
  import torch
3
  from PIL import Image
4
  from transformers import CLIPModel, CLIPProcessor
5
  from transformers import BlipProcessor, BlipForConditionalGeneration
6
+
7
# Page chrome: wide layout so the two result columns sit side by side.
st.set_page_config(page_title="AI Image Diagnostic System", layout="wide")

st.title("🔬 AI Image Diagnostic System")
st.write("CLIP + BLIP based AI diagnostic platform")

# CPU-only inference; no GPU assumed in this Space.
device = "cpu"
16
 
17
@st.cache_resource
def load_models():
    """Load the CLIP and BLIP models/processors once per server process.

    st.cache_resource keeps the returned objects alive across Streamlit
    reruns, so the expensive from_pretrained downloads happen only once.

    Returns:
        Tuple of (clip_model, clip_processor, blip_processor, blip_model).
    """
    clip_id = "openai/clip-vit-base-patch32"
    blip_id = "Salesforce/blip-image-captioning-base"

    clip_model = CLIPModel.from_pretrained(clip_id)
    clip_processor = CLIPProcessor.from_pretrained(clip_id)
    blip_processor = BlipProcessor.from_pretrained(blip_id)
    blip_model = BlipForConditionalGeneration.from_pretrained(blip_id)

    return clip_model, clip_processor, blip_processor, blip_model


clip_model, clip_processor, blip_processor, blip_model = load_models()
41
 
 
 
 
42
 
 
43
  DATASETS = {
44
  "medical": ["pneumonia", "Normal"],
45
  "skin_cancer": ["Normal Skin", "eczema", "Melanoma", "psoriasis"],
 
54
  "agriculture": "a close-up leaf showing signs of {}"
55
  }
56
 
 
57
 
58
st.sidebar.header("Settings")

# The dataset choice selects both the label set and the prompt template.
dataset_key = st.sidebar.selectbox("Select Dataset Type", list(DATASETS.keys()))

# Main-pane uploader; only common raster formats are accepted.
uploaded_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
69
 
70
# Run the diagnostic pipeline only once an image has been provided.
if uploaded_file:

    # Streamlit hands us a file-like object; normalize to RGB so both
    # CLIP and BLIP receive 3-channel input.
    image = Image.open(uploaded_file).convert("RGB")

    col1, col2 = st.columns(2)

    with col1:
        # FIX: use_column_width is deprecated in Streamlit in favor of
        # use_container_width (same visual behavior, supported API).
        st.image(image, caption="Uploaded Image", use_container_width=True)

    labels = DATASETS[dataset_key]

    # Build one CLIP text prompt per candidate label via the
    # dataset-specific template.
    text_queries = [templates[dataset_key].format(l) for l in labels]
84
 
85
  inputs = clip_processor(
86
  text=text_queries,
 
90
  )
91
 
92
  with torch.no_grad():
93
+
94
  outputs = clip_model(**inputs)
95
+
96
  probs = outputs.logits_per_image.softmax(dim=1)
97
 
98
  conf, idx = torch.max(probs, dim=1)
99
 
100
+ predicted_class = labels[idx.item()]
101
 
102
+ with col2:
 
103
 
104
+ st.success(f"Prediction: {predicted_class}")
 
105
 
106
+ st.metric(
107
+ label="Confidence",
108
+ value=f"{conf.item():.2%}"
109
+ )
110
 
111
+ blip_inputs = blip_processor(
112
+ images=image,
113
+ return_tensors="pt"
114
+ )
115
 
116
+ with torch.no_grad():
117
 
118
+ caption_ids = blip_model.generate(**blip_inputs)
119
+
120
+ caption = blip_processor.decode(
121
+ caption_ids[0],
122
+ skip_special_tokens=True
123
+ )
124
+
125
+ st.subheader("Generated Description")
 
 
 
 
 
 
 
 
 
126
 
127
+ st.write(caption)