yaya36095 commited on
Commit
2520146
·
verified ·
1 Parent(s): 83cae53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -56
app.py CHANGED
@@ -6,10 +6,11 @@ import gradio as gr
6
  import os
7
  import sys
8
 
9
- # Check if running on Hugging Face
10
- IS_HUGGINGFACE = os.environ.get("SPACE_ID") is not None
 
11
 
12
- # Set up device - important to set it right for each environment
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  print(f"Using device: {device}")
15
 
@@ -20,9 +21,8 @@ transform = transforms.Compose([
20
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
21
  ])
22
 
23
- # Load the model
24
  def load_model():
25
- print("Loading model...")
26
  # Create model architecture
27
  model = models.efficientnet_v2_s(weights=None)
28
 
@@ -37,46 +37,42 @@ def load_model():
37
  nn.Linear(512, 2)
38
  )
39
 
40
- # Load the state dict - handle both local and HF paths
41
- try:
42
- model_paths = [
43
- 'best_model_improved.pth', # Local path
44
- '/repository/best_model_improved.pth', # Hugging Face path
45
- os.path.join(os.path.dirname(os.path.abspath(__file__)), 'best_model_improved.pth') # Absolute path
46
- ]
47
-
48
- model_loaded = False
49
- for model_path in model_paths:
50
- if os.path.exists(model_path):
51
- print(f"Loading model from: {model_path}")
 
 
 
52
  model.load_state_dict(torch.load(model_path, map_location=device))
53
  model_loaded = True
54
  break
55
-
56
- if not model_loaded:
57
- print(f"Model not found in any of the expected locations: {model_paths}")
58
- print(f"Current directory contents: {os.listdir('.')}")
59
- if os.path.exists('/repository'):
60
- print(f"Repository directory contents: {os.listdir('/repository')}")
61
- raise FileNotFoundError("Model file not found")
62
-
63
- model.to(device)
64
- model.eval()
65
- print("Model loaded successfully")
66
- return model
67
- except Exception as e:
68
- print(f"Error loading model: {e}")
69
- raise
70
 
71
  # Global model variable
72
  model = None
73
 
74
- # Inference function
75
  def predict_image(img):
76
  global model
77
 
78
  if img is None:
79
- return {"error": "No image provided"}, "Error: No image provided", "Please upload an image"
80
 
81
  try:
82
  # Load model if not already loaded
@@ -105,14 +101,17 @@ def predict_image(img):
105
  # Determine classification
106
  classification = "Real Image" if prediction == 0 else "AI-Generated Image"
107
  confidence = real_prob if prediction == 0 else ai_prob
 
108
 
109
- return result, classification, f"Confidence: {confidence:.2f}%"
110
 
111
  except Exception as e:
 
112
  print(f"Error during prediction: {e}")
113
- return {"error": str(e)}, f"Error: {str(e)}", "Failed to analyze image"
 
114
 
115
- # Define Gradio interface
116
  def create_interface():
117
  with gr.Blocks(title="AI Image Detector", theme=gr.themes.Soft()) as interface:
118
  gr.Markdown("# AI Image Detector")
@@ -124,15 +123,14 @@ def create_interface():
124
  analyze_btn = gr.Button("Analyze Image", variant="primary")
125
 
126
  with gr.Column():
127
- result_label = gr.Label(label="Prediction")
128
- classification = gr.Textbox(label="Classification")
129
- confidence = gr.Textbox(label="Confidence")
130
 
131
  # Set up the click event
132
  analyze_btn.click(
133
  fn=predict_image,
134
  inputs=input_image,
135
- outputs=[result_label, classification, confidence]
136
  )
137
 
138
  gr.Markdown("### How It Works")
@@ -147,18 +145,5 @@ def create_interface():
147
 
148
  # Launch the interface
149
  if __name__ == "__main__":
150
- try:
151
- print("Starting AI Image Detector application...")
152
- interface = create_interface()
153
-
154
- # Different launch options based on environment
155
- if IS_HUGGINGFACE:
156
- print("Running on Hugging Face, launching with share=False")
157
- interface.launch(share=False)
158
- else:
159
- print("Running locally, launching with share=True")
160
- interface.launch(share=True)
161
-
162
- except Exception as e:
163
- print(f"Error starting application: {e}")
164
- sys.exit(1)
 
6
  import os
7
  import sys
8
 
9
+ print("Starting AI Image Detector...")
10
+ print(f"Working directory: {os.getcwd()}")
11
+ print(f"Files in directory: {os.listdir('.')}")
12
 
13
+ # Set up device
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
  print(f"Using device: {device}")
16
 
 
21
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
22
  ])
23
 
 
24
  def load_model():
25
+ print("Creating model architecture...")
26
  # Create model architecture
27
  model = models.efficientnet_v2_s(weights=None)
28
 
 
37
  nn.Linear(512, 2)
38
  )
39
 
40
+ # Try to load from multiple possible locations
41
+ possible_paths = [
42
+ "best_model_improved.pth",
43
+ "pytorch_model.bin",
44
+ "/repository/best_model_improved.pth",
45
+ "/repository/pytorch_model.bin",
46
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), "best_model_improved.pth"),
47
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), "pytorch_model.bin")
48
+ ]
49
+
50
+ model_loaded = False
51
+ for model_path in possible_paths:
52
+ if os.path.exists(model_path):
53
+ print(f"Loading model from: {model_path}")
54
+ try:
55
  model.load_state_dict(torch.load(model_path, map_location=device))
56
  model_loaded = True
57
  break
58
+ except Exception as e:
59
+ print(f"Error loading from {model_path}: {e}")
60
+
61
+ if not model_loaded:
62
+ print("WARNING: Could not load model weights. Using untrained model.")
63
+
64
+ model.to(device)
65
+ model.eval()
66
+ return model
 
 
 
 
 
 
67
 
68
  # Global model variable
69
  model = None
70
 
 
71
  def predict_image(img):
72
  global model
73
 
74
  if img is None:
75
+ return {"Error": "No image provided"}, "Error: Please upload an image"
76
 
77
  try:
78
  # Load model if not already loaded
 
101
  # Determine classification
102
  classification = "Real Image" if prediction == 0 else "AI-Generated Image"
103
  confidence = real_prob if prediction == 0 else ai_prob
104
+ confidence_text = f"Confidence: {confidence:.2f}%"
105
 
106
+ return result, classification + " - " + confidence_text
107
 
108
  except Exception as e:
109
+ import traceback
110
  print(f"Error during prediction: {e}")
111
+ traceback.print_exc()
112
+ return {"error": str(e)}, f"Error: {str(e)}"
113
 
114
+ # Define Gradio interface - simplified for Hugging Face
115
  def create_interface():
116
  with gr.Blocks(title="AI Image Detector", theme=gr.themes.Soft()) as interface:
117
  gr.Markdown("# AI Image Detector")
 
123
  analyze_btn = gr.Button("Analyze Image", variant="primary")
124
 
125
  with gr.Column():
126
+ result_label = gr.Label(label="Prediction Probabilities")
127
+ classification = gr.Textbox(label="Classification Result")
 
128
 
129
  # Set up the click event
130
  analyze_btn.click(
131
  fn=predict_image,
132
  inputs=input_image,
133
+ outputs=[result_label, classification]
134
  )
135
 
136
  gr.Markdown("### How It Works")
 
145
 
146
  # Launch the interface
147
  if __name__ == "__main__":
148
+ interface = create_interface()
149
+ interface.launch()