MateuszLis commited on
Commit
e4c6687
·
verified ·
1 Parent(s): 862f506

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +40 -15
README.md CHANGED
@@ -22,25 +22,50 @@ license: cc-by-nc-2.0
22
  ## 🖼️ Quick Usage
23
 
24
  ```python
25
- from transformers import AutoFeatureExtractor, AutoModelForImageClassification
26
- from PIL import Image
 
27
  import torch
 
 
 
28
 
29
- # Load model & feature extractor
30
- extractor = AutoFeatureExtractor.from_pretrained("your-username/RistoNet")
31
- model = AutoModelForImageClassification.from_pretrained("your-username/RistoNet")
 
 
32
 
33
- # Load an image
34
- image = Image.open("my_dish.jpg")
 
35
 
36
- # Preprocess
37
- inputs = extractor(images=image, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- # Predict
40
  with torch.no_grad():
41
- outputs = model(**inputs)
42
- predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
 
 
 
 
43
 
44
- # Top class
45
- predicted_class = predictions.argmax().item()
46
- print(f"Predicted class: {predicted_class}")
 
22
  ## 🖼️ Quick Usage
23
 
24
  ```python
25
+ ## 🔎 Inference Example
26
+
27
+ ```python
28
  import torch
29
+ from PIL import Image
30
+ from torchvision import models, transforms
31
+ from huggingface_hub import hf_hub_download
32
 
33
+ # ------------------------------
34
+ # 1. Load model
35
+ # ------------------------------
36
+ MODEL_REPO = "Orkidee/RistoNet"
37
+ MODEL_FILE = "ristonet.pth"
38
 
39
+ model = models.efficientnet_b0(weights=None) # no pretrained weights
40
+ num_features = model.classifier[1].in_features
41
+ model.classifier[1] = torch.nn.Linear(num_features, 2) # 2 classes
42
 
43
+ # Download weights from Hub and load
44
+ model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
45
+ model.load_state_dict(torch.load(model_path, map_location="cpu"))
46
+ model.eval()
47
+
48
+ # ------------------------------
49
+ # 2. Define preprocessing
50
+ # ------------------------------
51
+ transform = transforms.Compose([
52
+ transforms.Resize((224, 224)),
53
+ transforms.ToTensor(),
54
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
55
+ std=[0.229, 0.224, 0.225])
56
+ ])
57
+
58
+ # ------------------------------
59
+ # 3. Run inference on an image
60
+ # ------------------------------
61
+ image = Image.open("my_food.jpg").convert("RGB")
62
+ input_tensor = transform(image).unsqueeze(0) # add batch dim
63
 
 
64
  with torch.no_grad():
65
+ outputs = model(input_tensor)
66
+ probs = torch.nn.functional.softmax(outputs, dim=1)
67
+ predicted_class = probs.argmax().item()
68
+
69
+ print("Predicted class:", predicted_class)
70
+ print("Probabilities:", probs.numpy())
71