yaya36095 commited on
Commit
a6174ea
·
verified ·
1 Parent(s): 732be30

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +86 -73
handler.py CHANGED
@@ -1,95 +1,108 @@
1
- from transformers import ViTFeatureExtractor, ViTForImageClassification
2
- from PIL import Image
3
  import torch
4
- import base64
5
- import io
 
 
6
 
7
  class EndpointHandler:
8
  def __init__(self, model_dir):
9
  """
10
- تهيئة النموذج ومعالج الميزات
11
-
12
- المعلمات:
13
- model_dir: مسار مجلد النموذج
14
  """
15
- try:
16
- # تحميل النموذج ومعالج الميزات
17
- self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_dir)
18
- self.model = ViTForImageClassification.from_pretrained(model_dir)
19
-
20
- # نقل النموذج إلى وحدة المعالجة المناسبة
21
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
- self.model.to(self.device)
23
-
24
- # وضع النموذج في وضع التقييم
25
- self.model.eval()
26
-
27
- print(f"تم تحميل النموذج بنجاح على جهاز {self.device}")
28
- except Exception as e:
29
- print(f"خطأ في تحميل النموذج: {str(e)}")
30
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  def __call__(self, data):
33
  """
34
- معالجة البيانات المدخلة وإجراء التنبؤ
35
-
36
- المعلمات:
37
- data: بيانات الصورة المشفرة بـ base64 أو كائن الصورة
38
-
39
- العائد:
40
- dict: نتائج التنبؤ
41
  """
42
  try:
43
- # التحقق من نوع البيانات المدخلة
44
- if isinstance(data, dict) and "image" in data:
45
- # إذا كانت البيانات مشفرة بـ base64
46
- image_data = data["image"]
47
- if isinstance(image_data, str) and image_data.startswith("data:image"):
48
- # إزالة بادئة URL للبيانات
49
- image_data = image_data.split(",")[1]
 
 
 
 
 
50
 
51
- # فك تشفير البيانات
52
- image_bytes = base64.b64decode(image_data)
53
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
54
- elif isinstance(data, bytes):
55
- # إذا كانت البيانات ثنائية
56
- image = Image.open(io.BytesIO(data)).convert("RGB")
57
  else:
58
- return {"error": "تنسيق البيانات غير مدعوم"}
 
59
 
60
- # معالجة الصورة
61
- inputs = self.feature_extractor(images=image, return_tensors="pt")
62
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
63
 
64
- # إجراء التنبؤ
65
  with torch.no_grad():
66
- outputs = self.model(**inputs)
67
- logits = outputs.logits
68
- probabilities = torch.nn.functional.softmax(logits, dim=-1)
69
-
70
- # الحصول على التصنيف ونسبة الثقة
71
- predicted_class_idx = probabilities.argmax().item()
72
- confidence = probabilities[0][predicted_class_idx].item()
73
-
74
- # تحويل الفهرس إلى تسمية
75
- id2label = self.model.config.id2label
76
- predicted_class = id2label[predicted_class_idx]
77
 
78
- # إعداد النتائج
79
- results = {
80
- "prediction": predicted_class,
81
- "confidence": float(confidence),
82
- "is_fake": predicted_class == "fake",
 
 
83
  "probabilities": {
84
- label: float(prob)
85
- for label, prob in zip(id2label.values(), probabilities[0].cpu().numpy())
86
  }
87
  }
88
 
89
- return results
90
-
91
  except Exception as e:
92
- # معالجة الأخطاء
93
- error_message = str(e)
94
- print(f"خطأ في معالجة الصورة: {error_message}")
95
- return {"error": error_message}
 
1
+ import os
 
2
  import torch
3
+ import torch.nn as nn
4
+ from torchvision import models, transforms
5
+ from PIL import Image
6
+ import json
7
 
8
  class EndpointHandler:
9
  def __init__(self, model_dir):
10
  """
11
+ Initialize the model for AI image detection
 
 
 
12
  """
13
+ # Set device
14
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
+ # Load model
17
+ self.model = self._load_model(model_dir)
18
+
19
+ # Define transforms
20
+ self.transform = transforms.Compose([
21
+ transforms.Resize((224, 224)),
22
+ transforms.ToTensor(),
23
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
24
+ ])
25
+
26
+ # Class names
27
+ self.classes = ["Real Image", "AI-Generated Image"]
28
+
29
+ def _load_model(self, model_dir):
30
+ # Create model architecture
31
+ model = models.efficientnet_v2_s(weights=None)
32
+
33
+ # Recreate classifier exactly as in training
34
+ model.classifier = nn.Sequential(
35
+ nn.Linear(model.classifier[1].in_features, 1024),
36
+ nn.ReLU(),
37
+ nn.Dropout(p=0.3),
38
+ nn.Linear(1024, 512),
39
+ nn.ReLU(),
40
+ nn.Dropout(p=0.3),
41
+ nn.Linear(512, 2)
42
+ )
43
+
44
+ # Load state dict - find the pth file in the directory
45
+ model_path = os.path.join(model_dir, "best_model_improved.pth")
46
+
47
+ if os.path.exists(model_path):
48
+ print(f"Loading model from {model_path}")
49
+ model.load_state_dict(torch.load(model_path, map_location=self.device))
50
+ model.to(self.device)
51
+ model.eval()
52
+ return model
53
+ else:
54
+ raise FileNotFoundError(f"Model file not found at {model_path}. Files in directory: {os.listdir(model_dir)}")
55
 
56
  def __call__(self, data):
57
  """
58
+ Run prediction on the input data
 
 
 
 
 
 
59
  """
60
  try:
61
+ # Parse request data
62
+ if isinstance(data, dict) and "inputs" in data:
63
+ # API format
64
+ input_data = data["inputs"]
65
+ else:
66
+ # Direct image
67
+ input_data = data
68
+
69
+ # Process image
70
+ if isinstance(input_data, str): # Base64 string
71
+ import base64
72
+ from io import BytesIO
73
 
74
+ # Decode base64 image
75
+ image_bytes = base64.b64decode(input_data.split(",")[-1])
76
+ image = Image.open(BytesIO(image_bytes)).convert("RGB")
77
+ elif hasattr(input_data, "read"): # File-like object
78
+ image = Image.open(input_data).convert("RGB")
 
79
  else:
80
+ # Assume PIL Image
81
+ image = input_data
82
 
83
+ # Preprocess image
84
+ image_tensor = self.transform(image).unsqueeze(0).to(self.device)
 
85
 
86
+ # Make prediction
87
  with torch.no_grad():
88
+ outputs = self.model(image_tensor)
89
+ probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
90
+ prediction = torch.argmax(probabilities).item()
 
 
 
 
 
 
 
 
91
 
92
+ # Format results
93
+ real_prob = probabilities[0].item() * 100
94
+ ai_prob = probabilities[1].item() * 100
95
+
96
+ result = {
97
+ "prediction": self.classes[prediction],
98
+ "confidence": float(probabilities[prediction].item()),
99
  "probabilities": {
100
+ "Real Image": float(real_prob),
101
+ "AI-Generated Image": float(ai_prob)
102
  }
103
  }
104
 
105
+ return result
106
+
107
  except Exception as e:
108
+ return {"error": str(e)}