programmingghost commited on
Commit
bdbd54d
·
verified ·
1 Parent(s): 4032b46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -5
app.py CHANGED
@@ -3,45 +3,62 @@ from transformers import AutoModelForImageClassification, ViTImageProcessor
3
  from PIL import Image
4
  import torch
5
 
 
6
  # Load model once (global)
 
7
  model_id = "jacoballessio/ai-image-detect-distilled"
8
 
9
  processor = ViTImageProcessor.from_pretrained(model_id)
10
  model = AutoModelForImageClassification.from_pretrained(
11
  model_id,
12
- torch_dtype=torch.float32
 
13
  )
14
- model.eval()
15
 
 
16
  device = "cpu"
17
  model.to(device)
18
 
19
 
 
 
 
20
  def predict(image: Image.Image):
21
  if image is None:
22
- return "Please upload an image", ""
23
 
24
  # Preprocess
25
  inputs = processor(image, return_tensors="pt").to(device)
26
 
 
27
  with torch.no_grad():
28
  outputs = model(**inputs)
29
 
30
- # Softmax
31
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
32
  confidence = probs.max().item()
33
  predicted_label = model.config.id2label[probs.argmax().item()]
34
 
 
 
 
 
 
 
 
 
35
  # Result text
36
  if predicted_label.lower() == "fake":
37
  result = f"⚠️ AI-GENERATED\nConfidence: {confidence:.3f}"
38
  else:
39
  result = f"✅ REAL IMAGE\nConfidence: {confidence:.3f}"
40
 
41
- return result, probs.squeeze().tolist()
42
 
43
 
 
44
  # UI
 
45
  app = gr.Interface(
46
  fn=predict,
47
  inputs=gr.Image(type="pil", label="Upload Image"),
@@ -53,5 +70,9 @@ app = gr.Interface(
53
  description="Upload an image to check if it's AI-generated or real."
54
  )
55
 
 
 
 
 
56
  if __name__ == "__main__":
57
  app.launch(server_name="0.0.0.0", server_port=7860)
 
3
  from PIL import Image
4
  import torch
5
 
6
+ # -------------------------------
7
  # Load model once (global)
8
+ # -------------------------------
9
  model_id = "jacoballessio/ai-image-detect-distilled"
10
 
11
  processor = ViTImageProcessor.from_pretrained(model_id)
12
  model = AutoModelForImageClassification.from_pretrained(
13
  model_id,
14
+ dtype=torch.float32,
15
+ low_cpu_mem_usage=True
16
  )
 
17
 
18
+ model.eval()
19
  device = "cpu"
20
  model.to(device)
21
 
22
 
23
+ # -------------------------------
24
+ # Prediction function
25
+ # -------------------------------
26
  def predict(image: Image.Image):
27
  if image is None:
28
+ return "Please upload an image", None
29
 
30
  # Preprocess
31
  inputs = processor(image, return_tensors="pt").to(device)
32
 
33
+ # Inference
34
  with torch.no_grad():
35
  outputs = model(**inputs)
36
 
37
+ # Probabilities
38
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
39
  confidence = probs.max().item()
40
  predicted_label = model.config.id2label[probs.argmax().item()]
41
 
42
+ # Convert to dict for Gradio Label
43
+ labels = model.config.id2label
44
+ scores = probs.squeeze().tolist()
45
+
46
+ confidence_dict = {
47
+ labels[i]: float(scores[i]) for i in range(len(scores))
48
+ }
49
+
50
  # Result text
51
  if predicted_label.lower() == "fake":
52
  result = f"⚠️ AI-GENERATED\nConfidence: {confidence:.3f}"
53
  else:
54
  result = f"✅ REAL IMAGE\nConfidence: {confidence:.3f}"
55
 
56
+ return result, confidence_dict
57
 
58
 
59
+ # -------------------------------
60
  # UI
61
+ # -------------------------------
62
  app = gr.Interface(
63
  fn=predict,
64
  inputs=gr.Image(type="pil", label="Upload Image"),
 
70
  description="Upload an image to check if it's AI-generated or real."
71
  )
72
 
73
+
74
+ # -------------------------------
75
+ # Run app
76
+ # -------------------------------
77
  if __name__ == "__main__":
78
  app.launch(server_name="0.0.0.0", server_port=7860)