CrypticMonkey3 committed on
Commit
1d764ab
·
verified ·
1 Parent(s): 48fb957

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -13
app.py CHANGED
@@ -1,9 +1,10 @@
1
  import cv2
 
2
  from PIL import Image
3
  from transformers import AutoImageProcessor, AutoModelForImageClassification
4
  import torch
5
  import gradio as gr
6
- from spaces import GPU # import this to use the decorator
7
 
8
  # Load model and processor
9
  model_name = "dima806/ai_vs_real_image_detection"
@@ -15,22 +16,57 @@ model.eval()
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  model = model.to(device)
17
 
18
- @GPU # This activates GPU on Hugging Face ZeroGPU Spaces
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def analyze_video(video):
20
  cap = cv2.VideoCapture(video)
21
- frame_num = 0
22
- frame_interval = 60
 
 
 
 
 
 
23
  frames_to_process = []
 
 
24
 
25
  while cap.isOpened():
26
  ret, frame = cap.read()
27
- if not ret:
28
  break
29
 
30
- if frame_num % frame_interval == 0:
31
- rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
 
32
  pil_image = Image.fromarray(rgb)
33
  frames_to_process.append((frame_num, pil_image))
 
34
 
35
  frame_num += 1
36
 
@@ -42,24 +78,48 @@ def analyze_video(video):
42
  frame_numbers, pil_images = zip(*frames_to_process)
43
 
44
  inputs = processor(images=pil_images, return_tensors="pt")
45
- inputs = {k: v.to(device) for k, v in inputs.items()} # ✅ Move inputs to GPU if available
46
 
47
  with torch.no_grad():
48
  outputs = model(**inputs)
49
- predictions = torch.argmax(outputs.logits, dim=1).tolist()
 
 
 
 
 
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  results = []
52
- for frame_idx, pred in zip(frame_numbers, predictions):
53
  label = model.config.id2label[pred]
54
- results.append(f"Frame {frame_idx}: {label}")
 
 
 
 
 
55
 
56
  return "\n".join(results)
57
 
58
- # Gradio UI (note: NO gpu=True here)
59
  gr.Interface(
60
  fn=analyze_video,
61
  inputs=gr.Video(label="Upload a video"),
62
  outputs=gr.Textbox(label="Detection Results"),
63
  title="AI Frame Detector",
64
- description="Detects whether frames in a video are AI-generated or real."
65
  ).launch()
 
1
  import cv2
2
+ import numpy as np
3
  from PIL import Image
4
  from transformers import AutoImageProcessor, AutoModelForImageClassification
5
  import torch
6
  import gradio as gr
7
+ from spaces import GPU # Hugging Face ZeroGPU decorator
8
 
9
  # Load model and processor
10
  model_name = "dima806/ai_vs_real_image_detection"
 
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
  model = model.to(device)
18
 
19
+ # === Preprocessing Function ===
20
+ def preprocess_frame(frame: np.ndarray) -> np.ndarray:
21
+ # Resize for consistency
22
+ frame = cv2.resize(frame, (512, 512), interpolation=cv2.INTER_AREA)
23
+
24
+ # Denoise slightly
25
+ frame = cv2.GaussianBlur(frame, (3, 3), 0)
26
+
27
+ # Sharpen
28
+ kernel_sharpen = np.array([[0, -1, 0],
29
+ [-1, 5, -1],
30
+ [0, -1, 0]])
31
+ frame = cv2.filter2D(frame, -1, kernel_sharpen)
32
+
33
+ # Brightness/contrast adjustment
34
+ frame = cv2.convertScaleAbs(frame, alpha=1.1, beta=10)
35
+
36
+ # Histogram equalization
37
+ yuv = cv2.cvtColor(frame, cv2.COLOR_BGR2YUV)
38
+ yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0])
39
+ frame = cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR)
40
+
41
+ return frame
42
+
43
+ # === Main Video Analysis Function ===
44
+ @GPU
45
  def analyze_video(video):
46
  cap = cv2.VideoCapture(video)
47
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
48
+ max_frames_to_process = 10
49
+
50
+ if total_frames == 0:
51
+ return "Could not read video frames."
52
+
53
+ # Evenly sample 10 frames
54
+ frame_indices = [int(i * total_frames / max_frames_to_process) for i in range(max_frames_to_process)]
55
  frames_to_process = []
56
+ current_index = 0
57
+ frame_num = 0
58
 
59
  while cap.isOpened():
60
  ret, frame = cap.read()
61
+ if not ret or current_index >= len(frame_indices):
62
  break
63
 
64
+ if frame_num == frame_indices[current_index]:
65
+ processed = preprocess_frame(frame)
66
+ rgb = cv2.cvtColor(processed, cv2.COLOR_BGR2RGB)
67
  pil_image = Image.fromarray(rgb)
68
  frames_to_process.append((frame_num, pil_image))
69
+ current_index += 1
70
 
71
  frame_num += 1
72
 
 
78
  frame_numbers, pil_images = zip(*frames_to_process)
79
 
80
  inputs = processor(images=pil_images, return_tensors="pt")
81
+ inputs = {k: v.to(device) for k, v in inputs.items()}
82
 
83
  with torch.no_grad():
84
  outputs = model(**inputs)
85
+ probs = torch.nn.functional.softmax(outputs.logits, dim=1)
86
+ predictions = torch.argmax(probs, dim=1).tolist()
87
+
88
+ # Find index for "AI" class
89
+ ai_class_index = [k for k, v in model.config.id2label.items() if "AI" in v.upper()][0]
90
+ ai_scores = probs[:, ai_class_index].tolist()
91
 
92
+ # === Smart Verdict Logic ===
93
+ ai_conf_threshold = 0.7 # Must be at least 70% confident to count
94
+ min_ai_frames = 3 # Require 3 or more confident AI frames
95
+ avg_score = sum(ai_scores) / len(ai_scores)
96
+ confident_ai_frames = sum(1 for score in ai_scores if score >= ai_conf_threshold)
97
+
98
+ if confident_ai_frames >= min_ai_frames:
99
+ verdict = "Video is likely AI-generated."
100
+ elif avg_score >= 0.5:
101
+ verdict = "Uncertain — some signs of AI generation but not enough evidence."
102
+ else:
103
+ verdict = "Video is likely REAL."
104
+
105
+ # === Output Formatting ===
106
  results = []
107
+ for frame_idx, pred, ai_score in zip(frame_numbers, predictions, ai_scores):
108
  label = model.config.id2label[pred]
109
+ confidence = ai_score * 100
110
+ results.append(f"Frame {frame_idx}: {label} (AI Confidence: {confidence:.1f}%)")
111
+
112
+ results.append(f"\nConfident AI-flagged frames: {confident_ai_frames}")
113
+ results.append(f"Average AI Confidence: {avg_score * 100:.1f}%")
114
+ results.append(f"Final Verdict: {verdict}")
115
 
116
  return "\n".join(results)
117
 
118
+ # === Gradio UI ===
119
  gr.Interface(
120
  fn=analyze_video,
121
  inputs=gr.Video(label="Upload a video"),
122
  outputs=gr.Textbox(label="Detection Results"),
123
  title="AI Frame Detector",
124
+ description="Analyzes 10 frames from a video with preprocessing to detect if it's AI-generated or real."
125
  ).launch()