CrypticMonkey3 commited on
Commit
48fb957
·
verified ·
1 Parent(s): 10432dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -1,8 +1,9 @@
1
- import gradio as gr
2
  import cv2
3
  from PIL import Image
4
  from transformers import AutoImageProcessor, AutoModelForImageClassification
5
  import torch
 
 
6
 
7
  # Load model and processor
8
  model_name = "dima806/ai_vs_real_image_detection"
@@ -10,11 +11,15 @@ processor = AutoImageProcessor.from_pretrained(model_name)
10
  model = AutoModelForImageClassification.from_pretrained(model_name)
11
  model.eval()
12
 
 
 
 
 
 
13
  def analyze_video(video):
14
  cap = cv2.VideoCapture(video)
15
  frame_num = 0
16
  frame_interval = 60
17
- batch_size = 12
18
  frames_to_process = []
19
 
20
  while cap.isOpened():
@@ -34,16 +39,15 @@ def analyze_video(video):
34
  if not frames_to_process:
35
  return "No frames extracted."
36
 
37
- # Extract just the images for the processor
38
  frame_numbers, pil_images = zip(*frames_to_process)
39
 
40
- # Batch inference
41
  inputs = processor(images=pil_images, return_tensors="pt")
 
 
42
  with torch.no_grad():
43
  outputs = model(**inputs)
44
  predictions = torch.argmax(outputs.logits, dim=1).tolist()
45
 
46
- # Convert predictions to labels
47
  results = []
48
  for frame_idx, pred in zip(frame_numbers, predictions):
49
  label = model.config.id2label[pred]
@@ -51,12 +55,11 @@ def analyze_video(video):
51
 
52
  return "\n".join(results)
53
 
54
- # Gradio UI
55
  gr.Interface(
56
  fn=analyze_video,
57
  inputs=gr.Video(label="Upload a video"),
58
  outputs=gr.Textbox(label="Detection Results"),
59
  title="AI Frame Detector",
60
- description="Detects whether frames in a video are AI-generated or real.",
61
- gpu=True
62
  ).launch()
 
 
1
  import cv2
2
  from PIL import Image
3
  from transformers import AutoImageProcessor, AutoModelForImageClassification
4
  import torch
5
+ import gradio as gr
6
+ from spaces import GPU # ✅ import this to use the decorator
7
 
8
  # Load model and processor
9
  model_name = "dima806/ai_vs_real_image_detection"
 
11
  model = AutoModelForImageClassification.from_pretrained(model_name)
12
  model.eval()
13
 
14
+ # Use GPU if available
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ model = model.to(device)
17
+
18
+ @GPU # ✅ This activates GPU on Hugging Face ZeroGPU Spaces
19
  def analyze_video(video):
20
  cap = cv2.VideoCapture(video)
21
  frame_num = 0
22
  frame_interval = 60
 
23
  frames_to_process = []
24
 
25
  while cap.isOpened():
 
39
  if not frames_to_process:
40
  return "No frames extracted."
41
 
 
42
  frame_numbers, pil_images = zip(*frames_to_process)
43
 
 
44
  inputs = processor(images=pil_images, return_tensors="pt")
45
+ inputs = {k: v.to(device) for k, v in inputs.items()} # ✅ Move inputs to GPU if available
46
+
47
  with torch.no_grad():
48
  outputs = model(**inputs)
49
  predictions = torch.argmax(outputs.logits, dim=1).tolist()
50
 
 
51
  results = []
52
  for frame_idx, pred in zip(frame_numbers, predictions):
53
  label = model.config.id2label[pred]
 
55
 
56
  return "\n".join(results)
57
 
58
+ # Gradio UI (note: NO gpu=True here)
59
  gr.Interface(
60
  fn=analyze_video,
61
  inputs=gr.Video(label="Upload a video"),
62
  outputs=gr.Textbox(label="Detection Results"),
63
  title="AI Frame Detector",
64
+ description="Detects whether frames in a video are AI-generated or real."
 
65
  ).launch()