ProfRom committed on
Commit
faa5acc
·
verified ·
1 Parent(s): d2ca300

Stetson - Final submission

Browse files
Files changed (1) hide show
  1. app.py +74 -105
app.py CHANGED
@@ -1,135 +1,104 @@
1
- import gradio as gr
2
  import torch
3
  from transformers import pipeline
 
 
4
 
 
 
5
 
6
- PROJECT_LABELS = ["person", "skis", "cell phone", "spoon", "stop sign"]
7
-
8
-
9
def get_device():
    """Return the device index used when building the pipelines.

    Returns:
        0 when ``torch.cuda.is_available()`` is true, otherwise -1.
    """
    if torch.cuda.is_available():
        return 0
    return -1
11
-
12
-
13
- DEVICE = get_device()
14
-
15
- detector = pipeline(
16
- task="object-detection",
17
- model="facebook/detr-resnet-50",
18
- device=DEVICE,
19
  )
20
 
21
- captioner = pipeline(task="image-text-to-text", model="Salesforce/blip-image-captioning-base", device=DEVICE)
 
 
 
 
 
22
 
23
- zero_shot = pipeline(
24
- task="zero-shot-classification",
25
- model="typeform/distilbert-base-uncased-mnli",
26
- device=DEVICE,
 
27
  )
28
 
29
 
30
def normalize_label(text):
    """Canonicalise a label string.

    The input is coerced to ``str``, stripped of surrounding whitespace and
    lower-cased. A small set of known synonyms is then folded onto one
    canonical spelling; any other value is returned in its cleaned form.

    Args:
        text: Any value; it is converted with ``str()`` before cleaning.

    Returns:
        The canonical label, or the cleaned input when no alias matches.
    """
    cleaned = str(text).strip().lower()

    if cleaned in ("cellphone", "mobile phone", "phone"):
        return "cell phone"
    if cleaned in ("human", "people"):
        return "person"
    if cleaned == "stopsign":
        return "stop sign"
    return cleaned
41
-
42
-
43
def run_object_detection(image):
    """Detect objects in ``image`` and flatten the result into table rows.

    Only the first ten detections returned by ``detector`` are kept.

    Args:
        image: The image handed straight to ``detector``.

    Returns:
        A list of rows ``[label, score, xmin, ymin, xmax, ymax]``. The label
        goes through ``normalize_label``, the score is rounded to four
        decimals and each box coordinate to one. Missing fields fall back to
        an empty label or ``0.0``.
    """

    def _as_row(found):
        bbox = found.get("box", {})
        row = [
            normalize_label(found.get("label", "")),
            round(float(found.get("score", 0.0)), 4),
        ]
        # Coordinates are appended in the fixed column order of the table.
        row.extend(
            round(float(bbox.get(corner, 0.0)), 1)
            for corner in ("xmin", "ymin", "xmax", "ymax")
        )
        return row

    return [_as_row(found) for found in detector(image)[:10]]
61
-
62
-
63
def run_captioning(image):
    """Generate a caption for ``image`` with the ``captioner`` pipeline.

    The pipeline is called with an empty text prompt.

    Returns:
        The ``generated_text`` of the first result with surrounding
        whitespace removed, or ``""`` when the pipeline returns nothing.
    """
    outputs = captioner(image, "")
    return outputs[0]["generated_text"].strip() if outputs else ""
68
-
69
-
70
def classify_caption(caption):
    """Score ``caption`` against ``PROJECT_LABELS`` with the zero-shot model.

    Args:
        caption: The caption text; a falsy value short-circuits the call.

    Returns:
        A ``(summary, rows)`` pair. ``summary`` is the first returned label
        with its score formatted to four decimals, and ``rows`` is a list of
        ``[label, score]`` for every candidate in the order the model
        returned them. For a falsy caption the pair is
        ``("No caption generated.", [])``.
    """
    if not caption:
        return "No caption generated.", []

    prediction = zero_shot(caption, PROJECT_LABELS)

    best_label = normalize_label(prediction["labels"][0])
    best_score = float(prediction["scores"][0])

    table = []
    for name, value in zip(prediction["labels"], prediction["scores"]):
        table.append([normalize_label(name), round(float(value), 4)])

    return f"{best_label} ({best_score:.4f})", table
 
 
84
 
 
 
 
 
 
 
 
85
 
86
def analyze_image(image):
    """Run detection, captioning and caption classification on one image.

    Args:
        image: The uploaded image, or ``None`` when nothing was provided.

    Returns:
        A four-element tuple ``(detection_rows, caption, top_label, score_rows)``.
        For a ``None`` image this is ``([], "", "", [])``.
    """
    if image is None:
        return [], "", "", []

    rows = run_object_detection(image)
    text = run_captioning(image)
    # classify_caption yields (top_label, score_rows); splice both into the result.
    return (rows, text, *classify_caption(text))
95
 
 
 
96
 
97
- with gr.Blocks(title="Multimodal Image Analysis") as demo:
98
- gr.Markdown("# Multimodal Image Analysis")
99
  gr.Markdown(
100
- "Upload an image to run object detection, image captioning, "
101
- "and zero-shot classification of the generated caption."
 
 
 
 
 
102
  )
103
 
104
  with gr.Row():
105
- image_input = gr.Image(type="pil", label="Upload Image")
106
-
107
- with gr.Column():
108
- detection_output = gr.Dataframe(
109
- headers=["label", "score", "xmin", "ymin", "xmax", "ymax"],
110
- label="Object Detection Results",
111
- interactive=False,
 
 
 
 
 
112
  )
113
- caption_output = gr.Textbox(label="Generated Caption")
114
- top_label_output = gr.Textbox(label="Caption-Based Class Prediction")
115
- zero_shot_output = gr.Dataframe(
116
- headers=["candidate label", "score"],
117
- label="Zero-Shot Classification Scores",
118
- interactive=False,
119
  )
120
 
121
- analyze_button = gr.Button("Analyze Image", variant="primary")
122
- analyze_button.click(
123
  fn=analyze_image,
124
- inputs=image_input,
125
- outputs=[
126
- detection_output,
127
- caption_output,
128
- top_label_output,
129
- zero_shot_output,
130
- ],
131
  )
132
 
 
 
 
 
 
 
 
 
 
133
 
134
  if __name__ == "__main__":
135
  demo.launch()
 
 
1
  import torch
2
  from transformers import pipeline
3
+ from PIL import Image
4
+ import gradio as gr
5
 
6
+ # ── Device: GPU if available, otherwise CPU ───────────────────────────────────
7
+ device = 0 if torch.cuda.is_available() else -1
8
 
9
+ # ── Pipeline 1: Image Captioning (BLIP) ──────────────────────────────────────
10
+ captioner = pipeline(
11
+ "image-to-text",
12
+ model="Salesforce/blip-image-captioning-base",
13
+ device=device,
 
 
 
 
 
 
 
 
14
  )
15
 
16
+ # ── Pipeline 2: Image Classification (ViT) ───────────────────────────────────
17
+ classifier = pipeline(
18
+ "image-classification",
19
+ model="google/vit-base-patch16-224",
20
+ device=device,
21
+ )
22
 
23
+ # ── Pipeline 3: Sentiment Analysis on Caption (DistilBERT) ───────────────────
24
+ sentiment_analyzer = pipeline(
25
+ "sentiment-analysis",
26
+ model="distilbert-base-uncased-finetuned-sst-2-english",
27
+ device=device,
28
  )
29
 
30
 
31
def analyze_image(image: Image.Image):
    """Run all three pipelines on one image and return formatted results.

    Args:
        image: The uploaded image, or ``None`` when nothing has been
            uploaded yet.

    Returns:
        A ``(caption, top5_text, sentiment_text)`` tuple of strings.
        When ``image`` is ``None`` the first element is a prompt to upload an
        image and the other two are empty. When the captioner produces no
        text, the caption is empty and the sentiment slot says so instead of
        reporting a sentiment for an empty string.
    """
    if image is None:
        return "Upload an image to begin.", "", ""

    # Pipeline 1 - BLIP caption. An empty result list would make
    # ``caption_result[0]`` raise IndexError, so fall back to "".
    caption_result = captioner(image)
    caption = caption_result[0]["generated_text"].strip() if caption_result else ""

    # Pipeline 2 - ViT top-5 classifications, one numbered line per label.
    top5_text = "\n".join(
        f"{rank}. {r['label'].replace('_', ' ').title()}: {r['score']:.2%}"
        for rank, r in enumerate(classifier(image)[:5], start=1)
    )

    # Pipeline 3 - DistilBERT sentiment on the caption. Skipped when there is
    # no caption: a sentiment score for an empty string would be meaningless.
    if not caption:
        return "", top5_text, "No caption generated."

    sent = sentiment_analyzer(caption)[0]
    sentiment_text = f"{sent['label'].capitalize()} (confidence: {sent['score']:.2%})"

    return caption, top5_text, sentiment_text
 
 
53
 
 
54
 
55
+ # ── Gradio UI (Blocks for layout control) ────────────────────────────────────
56
+ with gr.Blocks(title="Multimodal AI Image Analyzer") as demo:
57
 
 
 
58
  gr.Markdown(
59
+ """
60
+ # Multimodal AI Image Analyzer
61
+ Upload any image to run three AI pipelines simultaneously:
62
+ - **BLIP** generates a natural-language caption (computer vision → text)
63
+ - **ViT** classifies the image content from 1,000 ImageNet categories
64
+ - **DistilBERT** analyzes the sentiment of the generated caption (NLP)
65
+ """
66
  )
67
 
68
  with gr.Row():
69
+ with gr.Column(scale=1):
70
+ img_input = gr.Image(type="pil", label="Input Image")
71
+ analyze_btn = gr.Button("Analyze", variant="primary")
72
+
73
+ with gr.Column(scale=1):
74
+ caption_out = gr.Textbox(
75
+ label="Pipeline 1 — BLIP Caption (image-to-text)",
76
+ lines=3,
77
+ )
78
+ cls_out = gr.Textbox(
79
+ label="Pipeline 2 — ViT Top-5 Classifications (image-classification)",
80
+ lines=6,
81
  )
82
+ sentiment_out = gr.Textbox(
83
+ label="Pipeline 3 — DistilBERT Caption Sentiment (sentiment-analysis)",
84
+ lines=2,
 
 
 
85
  )
86
 
87
+ analyze_btn.click(
 
88
  fn=analyze_image,
89
+ inputs=img_input,
90
+ outputs=[caption_out, cls_out, sentiment_out],
 
 
 
 
 
91
  )
92
 
93
+ gr.Markdown(
94
+ """
95
+ ---
96
+ **Models used:**
97
+ [Salesforce/blip-image-captioning-base](https://huggingface.co/Salesforce/blip-image-captioning-base) ·
98
+ [google/vit-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224) ·
99
+ [distilbert-base-uncased-finetuned-sst-2-english](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english)
100
+ """
101
+ )
102
 
103
  if __name__ == "__main__":
104
  demo.launch()