artursultanov commited on
Commit
4ee11bb
·
verified ·
1 Parent(s): ce696e6

Update README.md

Browse files

chore(): update the model usage example

Files changed (1) hide show
  1. README.md +39 -7
README.md CHANGED
@@ -35,19 +35,51 @@ You can load and run this model **directly in PyTorch** **without** installing `
35
 
36
  ```python
37
  import torch
 
 
 
38
 
39
- # 1. Load the model
40
- model = torch.jit.load("cosmoformer_traced.pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  model.eval()
42
 
43
- # 2. Inference
44
- # Suppose you have a 3-channel image tensor (1, 3, 224, 224)
45
- dummy_input = torch.randn(1, 3, 224, 224)
 
 
 
 
 
 
 
 
 
46
 
 
47
  with torch.no_grad():
48
- outputs = model(dummy_input)
 
49
 
50
- print(outputs.shape) # e.g., [1, num_classes]
 
51
  ```
52
 
53
  ```
 
35
 
36
  ```python
37
  import torch
38
+ import torchvision.transforms.v2 as v2
39
+ from huggingface_hub import hf_hub_download
40
+ from PIL import Image
41
 
42
+ label_mapping = {
43
+ 0: 'barred_spiral',
44
+ 1: 'edge_on_disk',
45
+ 2: 'featured_without_bar_or_spiral',
46
+ 3: 'irregular',
47
+ 4: 'smooth_cigar',
48
+ 5: 'smooth_inbetween',
49
+ 6: 'smooth_round',
50
+ 7: 'unbarred_spiral'
51
+ }
52
+
53
+ # 1. Define the path to the hugging face repo
54
+ ts_path = hf_hub_download(
55
+ repo_id="artursultanov/cosmoformer-model",
56
+ filename="cosmoformer_traced_cpu.pt"
57
+ )
58
+
59
+ # 2. Load the model from the hugging face repo
60
+ model = torch.jit.load(ts_path, map_location="cpu")
61
  model.eval()
62
 
63
+ # 3. Define image transform to match model's internal representation
64
+ transform = v2.Compose([
65
+ v2.Resize((224, 224)),
66
+ v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
67
+ ])
68
+
69
+ # 4. Load the image
70
+ image_path = "test_image.jpg"
71
+ image = Image.open(image_path).convert("RGB")
72
+
73
+ tensor = transform(image) # shape [3, 224, 224]
74
+ tensor = tensor.unsqueeze(0) shape [1, 3, 224, 224]
75
 
76
+ # 5. Inference
77
  with torch.no_grad():
78
+ output = model(tensor)
79
+ predicted_idx = torch.argmax(output, dim=1).item()
80
 
81
+ predicted_label = label_mapping[predicted_idx]
82
+ print("Predicted class:", predicted_label)
83
  ```
84
 
85
  ```