Hayloo9838 commited on
Commit
ac11708
·
verified ·
1 Parent(s): cff6d10

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -15
app.py CHANGED
@@ -1,29 +1,29 @@
1
- from fastapi import FastAPI, UploadFile
2
- from fastapi.responses import JSONResponse
3
  from PIL import Image
4
  import torch
5
- from torchvision import transforms
6
  from transformers import ViTFeatureExtractor, ViTForImageClassification
7
- import io
8
-
9
- app = FastAPI()
10
 
11
  # Load model + processor
12
  model = ViTForImageClassification.from_pretrained("Falconsai/nsfw_image_detection")
13
  processor = ViTFeatureExtractor.from_pretrained("Falconsai/nsfw_image_detection")
14
 
15
- transform = transforms.Compose([
16
- transforms.Resize((224, 224)),
17
- transforms.ToTensor()
18
- ])
19
-
20
- @app.post("/predict/")
21
- async def predict(file: UploadFile):
22
- image = Image.open(io.BytesIO(await file.read())).convert("RGB")
23
  inputs = processor(images=image, return_tensors="pt")
24
  with torch.no_grad():
25
  outputs = model(**inputs)
26
  logits = outputs.logits
27
  predicted_class_idx = logits.argmax(-1).item()
28
  label = model.config.id2label[predicted_class_idx]
29
- return JSONResponse({"label": label})
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
 
2
  from PIL import Image
3
  import torch
 
4
  from transformers import ViTFeatureExtractor, ViTForImageClassification
 
 
 
5
 
6
  # Load model + processor
7
  model = ViTForImageClassification.from_pretrained("Falconsai/nsfw_image_detection")
8
  processor = ViTFeatureExtractor.from_pretrained("Falconsai/nsfw_image_detection")
9
 
10
+ def predict(image):
11
+ # Process image
 
 
 
 
 
 
12
  inputs = processor(images=image, return_tensors="pt")
13
  with torch.no_grad():
14
  outputs = model(**inputs)
15
  logits = outputs.logits
16
  predicted_class_idx = logits.argmax(-1).item()
17
  label = model.config.id2label[predicted_class_idx]
18
+ return label
19
+
20
+ # Gradio Interface
21
+ demo = gr.Interface(
22
+ fn=predict,
23
+ inputs=gr.Image(type="pil", label="Upload Image"),
24
+ outputs=gr.Label(label="Prediction"),
25
+ title="NSFW Image Detection (ViT)",
26
+ description="Upload an image to classify if it's NSFW or SFW."
27
+ )
28
+
29
+ demo.launch()