NeelTA commited on
Commit
22ecc00
ยท
1 Parent(s): 1366e36

url feature added

Browse files
Files changed (2) hide show
  1. app.py +55 -29
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,43 +1,69 @@
1
- import gradio as gr
2
  import numpy as np
3
- import urllib
 
 
4
  from tensorflow.keras.preprocessing import image
5
  from tensorflow.keras.models import load_model
6
 
7
- # Load model (skip optimizer state)
8
- model = load_model('my_model.h5', compile=False)
9
 
10
- # Prediction function
11
- def classify_image(img):
12
- # Resize for model
13
- img = image.array_to_img(img).resize((224, 224))
14
  img = image.img_to_array(img)
15
  img = np.expand_dims(img, axis=0)
16
  img = img / 255.0
17
-
18
  prediction = model.predict(img)[0]
19
- result = {
20
- 'CART': float(prediction[0]),
21
- 'NSFW': float(prediction[1]),
22
- 'SFW': float(prediction[2])
23
  }
24
- return result
25
-
26
- # Example images
27
- examples = [f"example{i}.jpg" for i in range(1, 9)]
28
-
29
- # Gradio interface
30
- app = gr.Interface(
31
- fn=classify_image,
32
- inputs=gr.Image(
33
- image_mode="RGB",
34
- label="๐Ÿ“‚ Drag and drop or click to upload an image"
35
- ),
36
- outputs=gr.Label(num_top_classes=3),
37
- allow_flagging="never",
 
 
 
 
 
 
 
 
 
 
 
 
38
  examples=examples,
39
- title="Simple NSFW/SFW/CART Classifier"
 
 
40
  )
41
 
42
- app.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
43
 
 
 
 
 
 
 
1
+ import gradio as gr
2
  import numpy as np
3
+ import urllib.request
4
+ from PIL import Image
5
+ from io import BytesIO
6
  from tensorflow.keras.preprocessing import image
7
  from tensorflow.keras.models import load_model
8
 
9
+ # Load the model
10
+ model = load_model("my_model.h5", compile=False)
11
 
12
+ # Common prediction function
13
+ def classify_pil_image(pil_img):
14
+ img = pil_img.resize((224, 224))
 
15
  img = image.img_to_array(img)
16
  img = np.expand_dims(img, axis=0)
17
  img = img / 255.0
 
18
  prediction = model.predict(img)[0]
19
+ return {
20
+ "CART": float(prediction[0]),
21
+ "NSFW": float(prediction[1]),
22
+ "SFW": float(prediction[2])
23
  }
24
+
25
+ # From file input (or example)
26
+ def classify_uploaded_image(file):
27
+ try:
28
+ pil_img = Image.fromarray(file).convert("RGB")
29
+ return classify_pil_image(pil_img)
30
+ except Exception as e:
31
+ return {"error": f"Upload error: {str(e)}"}
32
+
33
+ # From URL input
34
+ def classify_from_url(url):
35
+ try:
36
+ response = urllib.request.urlopen(url)
37
+ img = Image.open(BytesIO(response.read())).convert("RGB")
38
+ return classify_pil_image(img)
39
+ except Exception as e:
40
+ return {"error": f"URL error: {str(e)}"}
41
+
42
+ # Example images for file-based interface
43
+ examples = [[f"example{i}.jpg"] for i in range(1, 9)]
44
+
45
+ # Upload tab (classic layout with examples)
46
+ upload_interface = gr.Interface(
47
+ fn=classify_uploaded_image,
48
+ inputs=gr.Image(type="numpy", label="Upload or drag an image"),
49
+ outputs=gr.Label(num_top_classes=3, label="Prediction"),
50
  examples=examples,
51
+ title="Simple NSFW/SFW/CART Classifier",
52
+ allow_flagging="never",
53
+ cache_examples=False
54
  )
55
 
56
+ # URL tab (simple textbox interface)
57
+ url_interface = gr.Interface(
58
+ fn=classify_from_url,
59
+ inputs=gr.Textbox(label="Paste Image URL"),
60
+ outputs=gr.Label(num_top_classes=3, label="Prediction"),
61
+ allow_flagging="never",
62
+ cache_examples=False
63
+ )
64
 
65
+ # Tabs wrapper to combine them
66
+ gr.TabbedInterface(
67
+ [upload_interface, url_interface],
68
+ tab_names=["๐Ÿ“ค Upload Image", "๐ŸŒ Image URL"]
69
+ ).launch()
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- tensorflow==2.18.1 # latest stable from TensorFlow GitHub :contentReference[oaicite:1]{index=1}
2
  opencv-python-headless==4.11.0.86 # latest from PyPI :contentReference[oaicite:2]{index=2}
3
  gradio==5.34.0 # just released today on PyPI :contentReference[oaicite:3]{index=3}
4
  numpy>=2.0.2 # latest major release as of June 7, 2025 :contentReference[oaicite:4]{index=4}
 
1
+ tensorflow>=2.10 # latest stable from TensorFlow GitHub :contentReference[oaicite:1]{index=1}
2
  opencv-python-headless==4.11.0.86 # latest from PyPI :contentReference[oaicite:2]{index=2}
3
  gradio==5.34.0 # just released today on PyPI :contentReference[oaicite:3]{index=3}
4
  numpy>=2.0.2 # latest major release as of June 7, 2025 :contentReference[oaicite:4]{index=4}