Nightfury16 commited on
Commit
20ec8a2
·
1 Parent(s): 5492c02

fixed convnextv2 bug

Browse files
Files changed (3) hide show
  1. Dockerfile +3 -4
  2. app.py +8 -4
  3. main.py +8 -6
Dockerfile CHANGED
@@ -1,15 +1,14 @@
1
  FROM python:3.9-slim
2
 
3
- ENV TRANSFORMERS_CACHE=/data/.cache
 
 
4
 
5
  WORKDIR /code
6
 
7
  COPY requirements.txt .
8
-
9
  RUN pip install --no-cache-dir -r requirements.txt
10
-
11
  COPY . .
12
 
13
  EXPOSE 7860
14
-
15
  CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
 
1
  FROM python:3.9-slim
2
 
3
+ ENV TRANSFORMERS_CACHE=/data/.cache/transformers
4
+ ENV HF_HOME=/data/.cache/huggingface
5
+ ENV MPLCONFIGDIR=/data/.cache/matplotlib
6
 
7
  WORKDIR /code
8
 
9
  COPY requirements.txt .
 
10
  RUN pip install --no-cache-dir -r requirements.txt
 
11
  COPY . .
12
 
13
  EXPOSE 7860
 
14
  CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
app.py CHANGED
@@ -1,20 +1,25 @@
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import yaml
4
  from torchvision import models, transforms
5
  from PIL import Image
6
  import gradio as gr
7
- import os
8
  from transformers import ConvNextV2ForImageClassification
9
  from typing import Dict, Tuple
10
 
11
  MODEL_CHECKPOINTS = {
12
- "ConvNeXt Tiny (Best)": "checkpoints/convnext_v2_tiny_best.pth",
13
  "EfficientNet-B0": "checkpoints/effnet_b0_best.pth",
14
  "EfficientNet-B3": "checkpoints/effnet_b3_best.pth",
15
  "Vision Transformer B-16": "checkpoints/vit_b_16_best.pth"
16
  }
17
- DEFAULT_MODEL_NAME = "ConvNeXt Tiny (Best)"
18
 
19
  MODELS: Dict[str, Tuple[nn.Module, Dict[int, str]]] = {}
20
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -48,7 +53,6 @@ def get_model(model_name: str, num_classes: int) -> nn.Module:
48
  return model
49
 
50
  def load_checkpoint(checkpoint_path: str, device: torch.device) -> Tuple[nn.Module, Dict[int, str]]:
51
- # (Same load_checkpoint function as in main.py)
52
  if not os.path.exists(checkpoint_path):
53
  raise FileNotFoundError(f"Checkpoint file not found at: {checkpoint_path}")
54
  checkpoint = torch.load(checkpoint_path, map_location=device)
 
1
+ import os
2
+
3
+ os.environ['TRANSFORMERS_CACHE'] = '/data/.cache/transformers'
4
+ os.environ['HF_HOME'] = '/data/.cache/huggingface'
5
+ os.environ['MPLCONFIGDIR'] = '/data/.cache/matplotlib'
6
+
7
  import torch
8
  import torch.nn as nn
9
  import yaml
10
  from torchvision import models, transforms
11
  from PIL import Image
12
  import gradio as gr
 
13
  from transformers import ConvNextV2ForImageClassification
14
  from typing import Dict, Tuple
15
 
16
  MODEL_CHECKPOINTS = {
17
+ "ConvNeXt tiny (Best)": "checkpoints/convnext_v2_tiny_best.pth",
18
  "EfficientNet-B0": "checkpoints/effnet_b0_best.pth",
19
  "EfficientNet-B3": "checkpoints/effnet_b3_best.pth",
20
  "Vision Transformer B-16": "checkpoints/vit_b_16_best.pth"
21
  }
22
+ DEFAULT_MODEL_NAME = "ConvNeXt tiny (Best)"
23
 
24
  MODELS: Dict[str, Tuple[nn.Module, Dict[int, str]]] = {}
25
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
53
  return model
54
 
55
  def load_checkpoint(checkpoint_path: str, device: torch.device) -> Tuple[nn.Module, Dict[int, str]]:
 
56
  if not os.path.exists(checkpoint_path):
57
  raise FileNotFoundError(f"Checkpoint file not found at: {checkpoint_path}")
58
  checkpoint = torch.load(checkpoint_path, map_location=device)
main.py CHANGED
@@ -1,10 +1,15 @@
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import yaml
4
  from torchvision import models, transforms
5
  from PIL import Image
6
  import gradio as gr
7
- import os
8
  import base64
9
  import io
10
  import time
@@ -16,14 +21,13 @@ from fastapi.middleware.cors import CORSMiddleware
16
  from pydantic import BaseModel
17
  from transformers import ConvNextV2ForImageClassification
18
 
19
-
20
  MODEL_CHECKPOINTS = {
21
- "ConvNeXt Tiny (Best)": "checkpoints/convnext_v2_tiny_best.pth",
22
  "EfficientNet-B0": "checkpoints/effnet_b0_best.pth",
23
  "EfficientNet-B3": "checkpoints/effnet_b3_best.pth",
24
  "Vision Transformer B-16": "checkpoints/vit_b_16_best.pth"
25
  }
26
- DEFAULT_MODEL_NAME = "ConvNeXt Tiny (Best)"
27
 
28
  GPU_MODELS: Dict[str, Tuple[nn.Module, Dict[int, str]]] = {}
29
  CPU_MODELS: Dict[str, Tuple[nn.Module, Dict[int, str]]] = {}
@@ -79,11 +83,9 @@ gpu_device = torch.device("cuda") if torch.cuda.is_available() else None
79
  for display_name, ckpt_path in MODEL_CHECKPOINTS.items():
80
  if os.path.exists(ckpt_path):
81
  print(f"Loading '{display_name}'...")
82
- # Load for CPU (always)
83
  cpu_model, idx_to_class = load_checkpoint(ckpt_path, cpu_device)
84
  CPU_MODELS[display_name] = (cpu_model, idx_to_class)
85
  print(f" > Loaded '{display_name}' for CPU.")
86
- # Load for GPU if available
87
  if gpu_device:
88
  gpu_model, _ = load_checkpoint(ckpt_path, gpu_device)
89
  GPU_MODELS[display_name] = (gpu_model, idx_to_class)
 
1
+ import os
2
+
3
+ os.environ['TRANSFORMERS_CACHE'] = '/data/.cache/transformers'
4
+ os.environ['HF_HOME'] = '/data/.cache/huggingface'
5
+ os.environ['MPLCONFIGDIR'] = '/data/.cache/matplotlib'
6
+
7
  import torch
8
  import torch.nn as nn
9
  import yaml
10
  from torchvision import models, transforms
11
  from PIL import Image
12
  import gradio as gr
 
13
  import base64
14
  import io
15
  import time
 
21
  from pydantic import BaseModel
22
  from transformers import ConvNextV2ForImageClassification
23
 
 
24
  MODEL_CHECKPOINTS = {
25
+ "ConvNeXt tiny (Best)": "checkpoints/convnext_v2_tiny_best.pth",
26
  "EfficientNet-B0": "checkpoints/effnet_b0_best.pth",
27
  "EfficientNet-B3": "checkpoints/effnet_b3_best.pth",
28
  "Vision Transformer B-16": "checkpoints/vit_b_16_best.pth"
29
  }
30
+ DEFAULT_MODEL_NAME = "ConvNeXt tiny (Best)"
31
 
32
  GPU_MODELS: Dict[str, Tuple[nn.Module, Dict[int, str]]] = {}
33
  CPU_MODELS: Dict[str, Tuple[nn.Module, Dict[int, str]]] = {}
 
83
  for display_name, ckpt_path in MODEL_CHECKPOINTS.items():
84
  if os.path.exists(ckpt_path):
85
  print(f"Loading '{display_name}'...")
 
86
  cpu_model, idx_to_class = load_checkpoint(ckpt_path, cpu_device)
87
  CPU_MODELS[display_name] = (cpu_model, idx_to_class)
88
  print(f" > Loaded '{display_name}' for CPU.")
 
89
  if gpu_device:
90
  gpu_model, _ = load_checkpoint(ckpt_path, gpu_device)
91
  GPU_MODELS[display_name] = (gpu_model, idx_to_class)