fcakyon commited on
Commit
28a8385
·
verified ·
1 Parent(s): fca1361

refactor: simplify utils.py

Browse files
Files changed (1) hide show
  1. utils.py +34 -149
utils.py CHANGED
@@ -1,7 +1,5 @@
1
- import os
2
- import json
3
  import gradio as gr
4
- from typing import Any, Dict, Tuple
5
  from urllib.request import urlopen, Request
6
  from io import BytesIO
7
  from PIL import Image
@@ -10,157 +8,44 @@ from functools import lru_cache
10
  _MODEL_CACHE: Dict[str, Any] = {}
11
 
12
  EXAMPLE_ITEMS = [
13
- (
14
- "https://assets.clevelandclinic.org/transform/LargeFeatureImage/cd71f4bd-81d4-45d8-a450-74df78e4477a/Apples-184940975-770x533-1_jpg",
15
- "viddexa/nsfw-mini",
16
- "Apples (mini)",
17
- ),
18
- (
19
- "https://img.freepik.com/free-photo/breast-screening-is-very-important-every-woman_329181-14953.jpg",
20
- "viddexa/nsfw-nano",
21
- "Breast screening (nano)",
22
- ),
23
- (
24
- "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSbRwt56NYsiHwrT8oS-igzgeEzp7p3Jbe2dw&s",
25
- "viddexa/nsfw-mini",
26
- "Thumbnail (mini)",
27
- ),
28
- (
29
- "https://img.freepik.com/premium-photo/portrait-beautiful-young-woman_1048944-5548042.jpg",
30
- "viddexa/nsfw-nano",
31
- "Portrait (nano)",
32
- ),
33
  ]
34
 
35
 
36
  @lru_cache(maxsize=32)
37
- def _download_image_bytes(image_url: str) -> bytes:
38
- """Download image bytes from URL with caching."""
39
- req = Request(image_url, headers={"User-Agent": "viddexa-gradio-demo/1.0"})
40
  with urlopen(req, timeout=20) as resp:
41
- return resp.read()
42
 
43
 
44
- def _load_model(model_id: str, token: str | None = None) -> Any:
45
- """Load a model and cache it."""
46
- if model_id in _MODEL_CACHE:
47
- return _MODEL_CACHE[model_id]
48
- try:
49
  from moderators.auto_model import AutoModerator
50
- model = AutoModerator.from_pretrained(model_id, token=token, use_fast=True)
51
- _MODEL_CACHE[model_id] = model
52
- return model
53
- except Exception as e:
54
- error_msg = f"Failed to load model: {model_id}. Error: {e}"
55
- if "401" in str(e):
56
- error_msg += "\n\nThis model may be private. Please ensure you have provided a valid Hugging Face token if required."
57
- raise gr.Error(error_msg)
58
-
59
-
60
- def _get_image_input(image_path: str | None, image_url: str | None) -> Image.Image:
61
- """Get image data from either an uploaded file path or a URL."""
62
- if image_url:
63
- try:
64
- data = _download_image_bytes(image_url)
65
- img = Image.open(BytesIO(data))
66
- return img.convert("RGB")
67
- except Exception as fetch_err:
68
- raise gr.Error(f"Could not download or open the image from the URL: {fetch_err}")
69
- elif image_path:
70
- img = Image.open(image_path)
71
- return img.convert("RGB")
72
- else:
73
- raise gr.Error("Please upload an image or provide an image URL.")
74
-
75
-
76
- def _format_results(results: list) -> Tuple[str, Dict[str, float], str, Dict]:
77
- """Format the model output for the Gradio interface."""
78
- if not results or "classifications" not in results[0]:
79
- return "<div class='verdict-card'>No classifications found.</div>", {}, "No classifications found.", {}
80
-
81
- classifications = results[0]["classifications"]
82
-
83
- label_output: Dict[str, float]
84
- if isinstance(classifications, dict):
85
- label_output = {str(k): float(v) for k, v in classifications.items()}
86
- else:
87
- try:
88
- label_output = {str(item['label']): float(item['score']) for item in classifications}
89
- except Exception:
90
- label_output = {}
91
-
92
- scores = {label.lower(): score for label, score in label_output.items()}
93
- nsfw_score = scores.get("nsfw", 0.0)
94
-
95
- if nsfw_score > 0.7:
96
- verdict_text = "HIGH RISK: NSFW"
97
- verdict_class = "verdict-nsfw"
98
- elif nsfw_score > 0.2:
99
- verdict_text = "MEDIUM RISK: SENSITIVE"
100
- verdict_class = "verdict-sensitive"
101
- else:
102
- verdict_text = "LOW RISK: SAFE"
103
- verdict_class = "verdict-safe"
104
-
105
- verdict_html = f"<div class='verdict-card {verdict_class}'>{verdict_text}</div>"
106
-
107
- markdown_output = "### All Scores\n---\n"
108
- for label, score in sorted(label_output.items(), key=lambda kv: kv[1], reverse=True):
109
- markdown_output += f"- **{label.capitalize()}**: {score:.4f}\n"
110
-
111
- return verdict_html, label_output, markdown_output, results[0]
112
-
113
-
114
- def analyze_image(image_path: str | None, image_url: str | None, model_choice: str,
115
- token: str | None = None, progress=gr.Progress(track_tqdm=True)):
116
- """Main inference function for the Gradio interface."""
117
- progress(0, desc="Initializing Analysis...")
118
- progress(0.2, desc="Processing Image...")
119
- input_image = _get_image_input(image_path, image_url)
120
- progress(0.5, desc=f"Loading Model: {os.path.basename(model_choice)}...")
121
- model = _load_model(model_choice, token)
122
- progress(0.8, desc="Running Inference...")
123
- results = model(input_image)
124
-
125
- json_results = [
126
- {"classifications": getattr(r, "classifications", r)}
127
- for r in results
128
- ]
129
- json_results = json.loads(json.dumps(json_results, ensure_ascii=False))
130
-
131
- progress(1, desc="Complete!")
132
- return _format_results(json_results)
133
-
134
-
135
- def analyze_image_with_status(image_path: str | None, image_url: str | None, model_choice: str,
136
- token: str | None = None, progress=gr.Progress(track_tqdm=True)):
137
- """Run analysis and return results with user-friendly status string."""
138
- verdict_html, label_scores, md_scores, json_obj = analyze_image(image_path, image_url, model_choice, token, progress)
139
- if image_url:
140
- status = f"Last analysed URL: {image_url}"
141
- elif image_path:
142
- status = "Last analysed uploaded image."
143
- else:
144
- status = "Last analysed: —"
145
- return verdict_html, label_scores, md_scores, json_obj, status
146
-
147
-
148
- def run_example_by_index(evt: gr.SelectData, token: str | None = None):
149
- """Handle gallery selection: run analysis for the selected example and update inputs."""
150
- try:
151
- idx = int(getattr(evt, "index", 0))
152
- except Exception:
153
- idx = 0
154
- idx = max(0, min(idx, len(EXAMPLE_ITEMS) - 1))
155
- url, model, caption = EXAMPLE_ITEMS[idx]
156
- verdict_html, label_scores, md_scores, json_obj = analyze_image(None, url, model, token)
157
- status = f"Last analysed example: {caption}"
158
- return (
159
- verdict_html,
160
- label_scores,
161
- md_scores,
162
- json_obj,
163
- gr.update(value=model),
164
- gr.update(value=url),
165
- status,
166
- )
 
 
 
1
  import gradio as gr
2
+ from typing import Any, Dict
3
  from urllib.request import urlopen, Request
4
  from io import BytesIO
5
  from PIL import Image
 
8
  _MODEL_CACHE: Dict[str, Any] = {}
9
 
10
  EXAMPLE_ITEMS = [
11
+ ("https://assets.clevelandclinic.org/transform/LargeFeatureImage/cd71f4bd-81d4-45d8-a450-74df78e4477a/Apples-184940975-770x533-1_jpg", "viddexa/nsfw-detection-mini"),
12
+ ("https://img.freepik.com/free-photo/breast-screening-is-very-important-every-woman_329181-14953.jpg", "viddexa/nsfw-detection-nano"),
13
+ ("https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSbRwt56NYsiHwrT8oS-igzgeEzp7p3Jbe2dw&s", "viddexa/nsfw-detection-mini"),
14
+ ("https://img.freepik.com/premium-photo/portrait-beautiful-young-woman_1048944-5548042.jpg", "viddexa/nsfw-detection-nano"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  ]
16
 
17
 
18
  @lru_cache(maxsize=32)
19
+ def download_image(url: str) -> Image.Image:
20
+ """Download and return PIL Image from URL."""
21
+ req = Request(url, headers={"User-Agent": "viddexa-gradio-demo/1.0"})
22
  with urlopen(req, timeout=20) as resp:
23
+ return Image.open(BytesIO(resp.read())).convert("RGB")
24
 
25
 
26
+ def load_model(model_id: str, token: str | None = None) -> Any:
27
+ """Load model with caching."""
28
+ if model_id not in _MODEL_CACHE:
 
 
29
  from moderators.auto_model import AutoModerator
30
+ _MODEL_CACHE[model_id] = AutoModerator.from_pretrained(model_id, token=token, use_fast=True)
31
+ return _MODEL_CACHE[model_id]
32
+
33
+
34
+ def analyze(image_path: str | None, image_url: str | None, model_id: str, token: str | None = None):
35
+ """Run inference and return classification scores."""
36
+ if not image_url and not image_path:
37
+ raise gr.Error("Provide an image or URL")
38
+
39
+ img = download_image(image_url) if image_url else Image.open(image_path).convert("RGB")
40
+ model = load_model(model_id, token)
41
+ results = model(img)
42
+
43
+ classifications = results[0].classifications if hasattr(results[0], "classifications") else results[0]["classifications"]
44
+ return {str(k): float(v) for k, v in (classifications.items() if isinstance(classifications, dict) else [(c["label"], c["score"]) for c in classifications])}
45
+
46
+
47
+ def run_example(evt: gr.SelectData, token: str | None = None):
48
+ """Handle example selection."""
49
+ idx = max(0, min(int(getattr(evt, "index", 0)), len(EXAMPLE_ITEMS) - 1))
50
+ url, model = EXAMPLE_ITEMS[idx]
51
+ return analyze(None, url, model, token), gr.update(value=model), gr.update(value=url)