danielhshi8224 committed on
Commit
54cd5d9
·
1 Parent(s): e897d8b

add application file

Browse files
Files changed (1) hide show
  1. app.py +109 -0
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
4
+ from PIL import Image
5
+ import os
6
+
7
# Resolve the checkpoint location relative to this script (works on any OS).
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# The checkpoint may have been exported under a few different filenames;
# probe each candidate until one exists on disk.
possible_names = ['ConvNextmodel.pth', 'convnextmodel.pth', 'ConvNext_model.pth']

model_path = None
for candidate in possible_names:
    candidate_path = os.path.join(BASE_DIR, candidate)
    if os.path.exists(candidate_path):
        model_path = candidate_path
        print(f"✓ Found model: {candidate}")
        break

# Fail fast at startup if no checkpoint variant was found.
if model_path is None:
    raise FileNotFoundError(f"Could not find model file. Tried: {possible_names}")
23
+
24
# The 7 benthic species classes; index i labels logit/probability column i.
# NOTE(review): order is assumed to match the training label order — confirm
# against the training pipeline.
SPECIES_CATEGORIES = 'Eel Scallop Crab Flatfish Roundfish Skate Whelk'.split()
34
+
35
# Load model: ConvNeXT-Tiny backbone with a fresh 7-class head. The
# pretrained 1000-class ImageNet head is discarded; ignore_mismatched_sizes
# silences the shape mismatch that replacing the head causes.
print(f"Loading model from: {model_path}")
model = AutoModelForImageClassification.from_pretrained(
    'facebook/convnext-tiny-224',
    num_labels=7,
    ignore_mismatched_sizes=True
)

# Load weights
# NOTE(review): weights_only=False unpickles arbitrary Python objects — only
# acceptable because the checkpoint ships with this app; never use it on
# untrusted files.
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
# Training frameworks often nest the state dict under 'model' or
# 'state_dict'; unwrap those wrappers if present.
if isinstance(checkpoint, dict):
    if 'model' in checkpoint:
        checkpoint = checkpoint['model']
    elif 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']

# strict=False tolerates missing/unexpected keys, so a partially matching
# (or wrong) checkpoint loads without error — any load problems are silent.
model.load_state_dict(checkpoint, strict=False)
model.eval()  # inference mode: disables dropout and batch-norm updates

# Load processor: supplies the resize/normalize pipeline matching the
# ConvNeXT-Tiny backbone's pretraining.
processor = AutoImageProcessor.from_pretrained('facebook/convnext-tiny-224')
print("✓ Model loaded successfully!")
57
+
58
def classify_image(image):
    """
    Classify a benthic species image.

    Args:
        image: PIL Image or numpy array (as supplied by the Gradio input).

    Returns:
        dict: species name -> softmax confidence score, one entry per class,
              in the format gr.Label expects.
    """
    # Normalize input to a 3-channel PIL image. Converting unconditionally
    # guards against grayscale / RGBA uploads: the original only converted
    # on the numpy branch, so a non-RGB PIL image slipped through unchanged.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert('RGB')

    # Preprocess: resize + normalize into the model's expected pixel tensor.
    inputs = processor(images=image, return_tensors="pt")

    # Predict without tracking gradients (inference only).
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probabilities = torch.nn.functional.softmax(logits, dim=1)

    # Map each class probability to its species name for gr.Label.
    return {
        SPECIES_CATEGORIES[idx]: float(prob)
        for idx, prob in enumerate(probabilities[0])
    }
87
+
88
+ # Create Gradio interface
89
+ demo = gr.Interface(
90
+ fn=classify_image,
91
+ inputs=gr.Image(type="pil", label="Upload Underwater Image"),
92
+ outputs=gr.Label(num_top_classes=7, label="Species Classification"),
93
+ title="🌊 BenthicAI - Benthic Species Classifier",
94
+ description="Upload an image of a benthic organism to classify it into one of 7 species categories. Built with ConvNeXT transformer model.",
95
+ examples=[
96
+ [os.path.join("examples", "eel.jpg")],
97
+ [os.path.join("examples", "scallop.jpg")],
98
+ [os.path.join("examples", "crab.jpg")],
99
+ ] if os.path.exists("examples") else None,
100
+ theme=gr.themes.Soft(),
101
+ allow_flagging="never"
102
+ )
103
+
104
+ if __name__ == "__main__":
105
+ demo.launch(
106
+ server_name="0.0.0.0",
107
+ server_port=7860,
108
+ share=True # Set to True to get a public URL
109
+ )