Jared commited on
Commit ·
26b9bd6
1
Parent(s): 7194535
Fix Python usage examples with working code
Browse files
README.md
CHANGED
|
@@ -65,10 +65,11 @@ pip install open-clip-torch torch pillow
|
|
| 65 |
### Python Usage
|
| 66 |
|
| 67 |
```python
|
|
|
|
| 68 |
from calorie_clip import CalorieCLIP
|
| 69 |
|
| 70 |
-
# Load model
|
| 71 |
-
model = CalorieCLIP.from_pretrained("
|
| 72 |
|
| 73 |
# Predict calories
|
| 74 |
calories = model.predict("food_photo.jpg")
|
|
@@ -79,6 +80,42 @@ images = ["breakfast.jpg", "lunch.jpg", "dinner.jpg"]
|
|
| 79 |
results = model.predict_batch(images)
|
| 80 |
```
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
### Command Line
|
| 83 |
|
| 84 |
```bash
|
|
|
|
| 65 |
### Python Usage
|
| 66 |
|
| 67 |
```python
|
| 68 |
+
# Clone or download this repo first, then:
|
| 69 |
from calorie_clip import CalorieCLIP
|
| 70 |
|
| 71 |
+
# Load model from local directory
|
| 72 |
+
model = CalorieCLIP.from_pretrained(".")
|
| 73 |
|
| 74 |
# Predict calories
|
| 75 |
calories = model.predict("food_photo.jpg")
|
|
|
|
| 80 |
results = model.predict_batch(images)
|
| 81 |
```
|
| 82 |
|
| 83 |
+
### Direct Usage (no wrapper)
|
| 84 |
+
|
| 85 |
+
```python
|
| 86 |
+
import torch
|
| 87 |
+
import open_clip
|
| 88 |
+
from PIL import Image
|
| 89 |
+
|
| 90 |
+
# Load CLIP
|
| 91 |
+
clip, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai')
|
| 92 |
+
checkpoint = torch.load('calorie_clip.pt', map_location='cpu', weights_only=False)
|
| 93 |
+
clip.load_state_dict(checkpoint['clip_state'], strict=False)
|
| 94 |
+
|
| 95 |
+
# Load regression head
|
| 96 |
+
import torch.nn as nn
|
| 97 |
+
class RegressionHead(nn.Module):
|
| 98 |
+
def __init__(self):
|
| 99 |
+
super().__init__()
|
| 100 |
+
self.net = nn.Sequential(
|
| 101 |
+
nn.Linear(512, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.4),
|
| 102 |
+
nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3),
|
| 103 |
+
nn.Linear(256, 64), nn.ReLU(), nn.Linear(64, 1)
|
| 104 |
+
)
|
| 105 |
+
def forward(self, x): return self.net(x)
|
| 106 |
+
|
| 107 |
+
head = RegressionHead()
|
| 108 |
+
head.load_state_dict(checkpoint['regressor_state'])
|
| 109 |
+
clip.eval(); head.eval()
|
| 110 |
+
|
| 111 |
+
# Predict
|
| 112 |
+
img = preprocess(Image.open('food.jpg')).unsqueeze(0)
|
| 113 |
+
with torch.no_grad():
|
| 114 |
+
features = clip.encode_image(img)
|
| 115 |
+
calories = head(features).item()
|
| 116 |
+
print(f"{calories:.0f} calories")
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
### Command Line
|
| 120 |
|
| 121 |
```bash
|