孙振宇 commited on
Commit
d3f7ab4
·
1 Parent(s): 13f3f6d

Fix: add mmcv fallback in flow_comp.py, remove mmcv from Dockerfile

Browse files
Dockerfile CHANGED
@@ -27,9 +27,7 @@ COPY model_version.json .
27
  RUN mkdir -p resources/checkpoint output working_dir logs data frontend/dist/assets
28
 
29
  # Install Python dependencies
30
- RUN pip install --no-cache-dir setuptools && \
31
- pip install --no-cache-dir --no-build-isolation mmcv && \
32
- uv pip install --system --no-cache \
33
  aiofiles \
34
  aiosqlite \
35
  diffusers \
 
27
  RUN mkdir -p resources/checkpoint output working_dir logs data frontend/dist/assets
28
 
29
  # Install Python dependencies
30
+ RUN uv pip install --system --no-cache \
 
 
31
  aiofiles \
32
  aiosqlite \
33
  diffusers \
sorawm/models/model/modules/flow_comp.py CHANGED
@@ -2,8 +2,34 @@ import numpy as np
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
- from mmcv.cnn import ConvModule
6
- from mmcv.runner import load_checkpoint
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  from sorawm.configs import PHY_NET_CHECKPOINT_PATH, PHY_NET_CHECKPOINT_REMOTE_URL
9
  from sorawm.utils.download_utils import ensure_model_downloaded
 
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
+
6
+ try:
7
+ from mmcv.cnn import ConvModule
8
+ from mmcv.runner import load_checkpoint
9
+ except ImportError:
10
+ from loguru import logger
11
+ logger.warning("mmcv is not available in flow_comp, using fallback implementation")
12
+
13
+ class ConvModule(nn.Module):
14
+ """Minimal fallback for mmcv.cnn.ConvModule (Conv2d + optional activation)."""
15
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
16
+ padding=0, norm_cfg=None, act_cfg=dict(type="ReLU")):
17
+ super().__init__()
18
+ layers = [nn.Conv2d(in_channels, out_channels, kernel_size,
19
+ stride=stride, padding=padding, bias=True)]
20
+ if act_cfg is not None:
21
+ act_type = act_cfg.get("type", "ReLU")
22
+ layers.append(nn.ReLU(inplace=True) if act_type == "ReLU" else nn.Identity())
23
+ self.block = nn.Sequential(*layers)
24
+
25
+ def forward(self, x):
26
+ return self.block(x)
27
+
28
+ def load_checkpoint(model, filename, strict=False, **kwargs):
29
+ state = torch.load(filename, map_location="cpu")
30
+ if "state_dict" in state:
31
+ state = state["state_dict"]
32
+ model.load_state_dict(state, strict=strict)
33
 
34
  from sorawm.configs import PHY_NET_CHECKPOINT_PATH, PHY_NET_CHECKPOINT_REMOTE_URL
35
  from sorawm.utils.download_utils import ensure_model_downloaded