danielhshi8224 commited on
Commit
5c455e2
·
1 Parent(s): 129de28

hf host model

Browse files
Files changed (1) hide show
  1. app.py +19 -16
app.py CHANGED
@@ -6,6 +6,7 @@ import os
6
 
7
  # Get model path (Windows compatible)
8
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 
9
 
10
  # Try different possible filenames
11
  possible_names = ['ConvNextmodel.pth', 'convnextmodel.pth', 'ConvNext_model.pth']
@@ -33,27 +34,29 @@ SPECIES_CATEGORIES = [
33
  ]
34
 
35
  # Load model
36
- print(f"Loading model from: {model_path}")
37
- model = AutoModelForImageClassification.from_pretrained(
38
- 'facebook/convnext-tiny-224',
39
- num_labels=7,
40
- ignore_mismatched_sizes=True
41
- )
 
 
42
 
43
  # Load weights
44
- checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
45
- if isinstance(checkpoint, dict):
46
- if 'model' in checkpoint:
47
- checkpoint = checkpoint['model']
48
- elif 'state_dict' in checkpoint:
49
- checkpoint = checkpoint['state_dict']
50
 
51
- model.load_state_dict(checkpoint, strict=False)
52
- model.eval()
53
 
54
  # Load processor
55
- processor = AutoImageProcessor.from_pretrained('facebook/convnext-tiny-224')
56
- print("✓ Model loaded successfully!")
57
 
58
  def classify_image(image):
59
  """
 
6
 
7
  # Get model path (Windows compatible)
8
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
9
+ MODEL_ID = "dshi01/my-benthic-classifier"
10
 
11
  # Try different possible filenames
12
  possible_names = ['ConvNextmodel.pth', 'convnextmodel.pth', 'ConvNext_model.pth']
 
34
  ]
35
 
36
  # Load model
37
+ print(f"Loading model from: {MODEL_ID}")
38
+ # model = AutoModelForImageClassification.from_pretrained(
39
+ # 'facebook/convnext-tiny-224',
40
+ # num_labels=7,
41
+ # ignore_mismatched_sizes=True
42
+ # )
43
+ processor=AutoImageProcessor.from_pretrained(MODEL_ID)
44
+ model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
45
 
46
  # Load weights
47
+ # checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
48
+ # if isinstance(checkpoint, dict):
49
+ # if 'model' in checkpoint:
50
+ # checkpoint = checkpoint['model']
51
+ # elif 'state_dict' in checkpoint:
52
+ # checkpoint = checkpoint['state_dict']
53
 
54
+ # model.load_state_dict(checkpoint, strict=False)
55
+ # model.eval()
56
 
57
  # Load processor
58
+ # processor = AutoImageProcessor.from_pretrained('facebook/convnext-tiny-224')
59
+ # print("✓ Model loaded successfully!")
60
 
61
  def classify_image(image):
62
  """