MASSJ77 commited on
Commit
411c593
·
verified ·
1 Parent(s): 4792eff

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ import requests
4
+ from PIL import Image
5
+ from fastapi import FastAPI, Header, HTTPException
6
+ from supabase import create_client
7
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
8
+ import os
9
+
10
+ app = FastAPI()
11
+
12
+ # ===============================
13
+ # सुरक्षा (API Key Protection)
14
+ # ===============================
15
+ API_TOKEN = os.getenv("API_TOKEN")
16
+
17
+ def verify_key(x_api_key: str):
18
+ if API_TOKEN and x_api_key != API_TOKEN:
19
+ raise HTTPException(status_code=403, detail="Unauthorized")
20
+
21
+ # ===============================
22
+ # Supabase
23
+ # ===============================
24
+ SUPABASE_URL = os.getenv("SUPABASE_URL")
25
+ SUPABASE_KEY = os.getenv("SUPABASE_KEY")
26
+
27
+ supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
28
+
29
+ # ===============================
30
+ # Load Model (only once)
31
+ # ===============================
32
+ model_name = "AdamCodd/vit-base-nsfw-detector"
33
+
34
+ processor = AutoImageProcessor.from_pretrained(model_name)
35
+ model = AutoModelForImageClassification.from_pretrained(model_name)
36
+
37
+ # ===============================
38
+ # Image check
39
+ # ===============================
40
+ def check_image(url):
41
+ try:
42
+ image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
43
+ inputs = processor(images=image, return_tensors="pt")
44
+
45
+ with torch.no_grad():
46
+ outputs = model(**inputs)
47
+
48
+ probs = torch.softmax(outputs.logits, dim=1)
49
+ return "explicit" if probs[0][1] > 0.5 else "safe"
50
+ except:
51
+ return "safe"
52
+
53
+ # ===============================
54
+ # Video check
55
+ # ===============================
56
+ def check_video(url, frame_sample_rate=30):
57
+ try:
58
+ cap = cv2.VideoCapture(url)
59
+ frame_count = 0
60
+
61
+ while cap.isOpened():
62
+ ret, frame = cap.read()
63
+ if not ret:
64
+ break
65
+
66
+ if frame_count % frame_sample_rate == 0:
67
+ img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
68
+ inputs = processor(images=img, return_tensors="pt")
69
+
70
+ with torch.no_grad():
71
+ outputs = model(**inputs)
72
+
73
+ probs = torch.softmax(outputs.logits, dim=1)
74
+
75
+ if probs[0][1] > 0.5:
76
+ cap.release()
77
+ return "explicit"
78
+
79
+ frame_count += 1
80
+
81
+ cap.release()
82
+ return "safe"
83
+
84
+ except:
85
+ return "safe"
86
+
87
+ # ===============================
88
+ # 🔥 MAIN ENDPOINT (like /recommend/all)
89
+ # ===============================
90
+ @app.get("/moderate/all")
91
+ async def moderate_all(x_api_key: str = Header(None)):
92
+
93
+ verify_key(x_api_key)
94
+
95
+ results = []
96
+
97
+ # Images
98
+ img_posts = supabase.table("posts").select("*").execute().data
99
+ for post in img_posts:
100
+ result = check_image(post["image_url"])
101
+
102
+ results.append({
103
+ "id": post["id"],
104
+ "type": "image",
105
+ "url": post["image_url"],
106
+ "result": result
107
+ })
108
+
109
+ # Videos
110
+ vid_posts = supabase.table("trendz").select("*").execute().data
111
+ for post in vid_posts:
112
+ result = check_video(post["video_url"])
113
+
114
+ results.append({
115
+ "id": post["id"],
116
+ "type": "video",
117
+ "url": post["video_url"],
118
+ "result": result
119
+ })
120
+
121
+ return {
122
+ "status": "success",
123
+ "count": len(results),
124
+ "data": results
125
+ }