Upload model
Browse files
- config.json +2 -2
- modeling_isnet.py +34 -25
config.json
CHANGED
|
@@ -1,10 +1,10 @@
|
|
| 1 |
{
|
| 2 |
"architectures": [
|
| 3 |
-
"ISNet"
|
| 4 |
],
|
| 5 |
"auto_map": {
|
| 6 |
"AutoConfig": "configuration_isnet.ISNetConfig",
|
| 7 |
-
"AutoModel": "modeling_isnet.ISNet"
|
| 8 |
},
|
| 9 |
"in_channels": 3,
|
| 10 |
"out_channels": 1,
|
|
|
|
| 1 |
{
|
| 2 |
"architectures": [
|
| 3 |
+
"ISNetModel"
|
| 4 |
],
|
| 5 |
"auto_map": {
|
| 6 |
"AutoConfig": "configuration_isnet.ISNetConfig",
|
| 7 |
+
"AutoModel": "modeling_isnet.ISNetModel"
|
| 8 |
},
|
| 9 |
"in_channels": 3,
|
| 10 |
"out_channels": 1,
|
modeling_isnet.py
CHANGED
|
@@ -1,15 +1,34 @@
|
|
| 1 |
import logging
|
| 2 |
-
from typing import Optional, Tuple
|
|
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.nn as nn
|
| 6 |
import torch.nn.functional as F
|
| 7 |
from transformers import PreTrainedModel
|
|
|
|
| 8 |
|
| 9 |
from .configuration_isnet import ISNetConfig
|
| 10 |
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
bce_loss = nn.BCELoss(size_average=True)
|
| 14 |
|
| 15 |
|
|
@@ -540,7 +559,7 @@ class ISNetGTEncoder(nn.Module):
|
|
| 540 |
return activated, hidden_states
|
| 541 |
|
| 542 |
|
| 543 |
-
class ISNet(PreTrainedModel):
|
| 544 |
config_class = ISNetConfig
|
| 545 |
|
| 546 |
def __init__(self, config: ISNetConfig) -> None:
|
|
@@ -582,7 +601,7 @@ class ISNet(PreTrainedModel):
|
|
| 582 |
|
| 583 |
# self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
|
| 584 |
|
| 585 |
-
def compute_loss_kl(self, preds, targets, dfs, fs, mode="MSE"):
|
| 586 |
# return muti_loss_fusion(preds,targets)
|
| 587 |
return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)
|
| 588 |
|
|
@@ -591,25 +610,8 @@ class ISNet(PreTrainedModel):
|
|
| 591 |
return muti_loss_fusion(preds, targets)
|
| 592 |
|
| 593 |
def forward(
|
| 594 |
-
self, pixel_values: torch.Tensor
|
| 595 |
-
) -> Tuple[
|
| 596 |
-
Tuple[
|
| 597 |
-
torch.Tensor,
|
| 598 |
-
torch.Tensor,
|
| 599 |
-
torch.Tensor,
|
| 600 |
-
torch.Tensor,
|
| 601 |
-
torch.Tensor,
|
| 602 |
-
torch.Tensor,
|
| 603 |
-
],
|
| 604 |
-
Tuple[
|
| 605 |
-
torch.Tensor,
|
| 606 |
-
torch.Tensor,
|
| 607 |
-
torch.Tensor,
|
| 608 |
-
torch.Tensor,
|
| 609 |
-
torch.Tensor,
|
| 610 |
-
torch.Tensor,
|
| 611 |
-
],
|
| 612 |
-
]:
|
| 613 |
x = pixel_values
|
| 614 |
hx = x
|
| 615 |
|
|
@@ -692,17 +694,24 @@ class ISNet(PreTrainedModel):
|
|
| 692 |
hx5d,
|
| 693 |
hx6,
|
| 694 |
)
|
| 695 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 696 |
|
| 697 |
|
| 698 |
def convert_from_checkpoint(
|
| 699 |
repo_id: str, filename: str, config: Optional[ISNetConfig] = None
|
| 700 |
-
) -> ISNet:
|
| 701 |
from huggingface_hub import hf_hub_download
|
| 702 |
|
| 703 |
checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
| 704 |
config = config or ISNetConfig()
|
| 705 |
-
model = ISNet(config)
|
| 706 |
|
| 707 |
logger.info(f"Loading checkpoint from {checkpoint_path}")
|
| 708 |
state_dict = torch.load(checkpoint_path)
|
|
|
|
| 1 |
import logging
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Literal, Optional, Tuple, Union
|
| 4 |
|
| 5 |
import torch
|
| 6 |
import torch.nn as nn
|
| 7 |
import torch.nn.functional as F
|
| 8 |
from transformers import PreTrainedModel
|
| 9 |
+
from transformers.utils import ModelOutput
|
| 10 |
|
| 11 |
from .configuration_isnet import ISNetConfig
|
| 12 |
|
| 13 |
logger = logging.getLogger(__name__)
|
| 14 |
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class ISNetStageOutput(ModelOutput):
|
| 18 |
+
d1: torch.Tensor
|
| 19 |
+
d2: Optional[torch.Tensor] = None
|
| 20 |
+
d3: Optional[torch.Tensor] = None
|
| 21 |
+
d4: Optional[torch.Tensor] = None
|
| 22 |
+
d5: Optional[torch.Tensor] = None
|
| 23 |
+
d6: Optional[torch.Tensor] = None
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass
|
| 27 |
+
class ISNetModelOutput(ModelOutput):
|
| 28 |
+
activated: ISNetStageOutput
|
| 29 |
+
hidden_states: Optional[ISNetStageOutput] = None
|
| 30 |
+
|
| 31 |
+
|
| 32 |
bce_loss = nn.BCELoss(size_average=True)
|
| 33 |
|
| 34 |
|
|
|
|
| 559 |
return activated, hidden_states
|
| 560 |
|
| 561 |
|
| 562 |
+
class ISNetModel(PreTrainedModel):
|
| 563 |
config_class = ISNetConfig
|
| 564 |
|
| 565 |
def __init__(self, config: ISNetConfig) -> None:
|
|
|
|
| 601 |
|
| 602 |
# self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
|
| 603 |
|
| 604 |
+
def compute_loss_kl(self, preds, targets, dfs, fs, mode: LossMode = "MSE"):
|
| 605 |
# return muti_loss_fusion(preds,targets)
|
| 606 |
return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)
|
| 607 |
|
|
|
|
| 610 |
return muti_loss_fusion(preds, targets)
|
| 611 |
|
| 612 |
def forward(
|
| 613 |
+
self, pixel_values: torch.Tensor, return_dict: Optional[bool] = None
|
| 614 |
+
) -> Union[Tuple, ISNetModelOutput]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 615 |
x = pixel_values
|
| 616 |
hx = x
|
| 617 |
|
|
|
|
| 694 |
hx5d,
|
| 695 |
hx6,
|
| 696 |
)
|
| 697 |
+
|
| 698 |
+
if not return_dict:
|
| 699 |
+
return activated, hidden_states
|
| 700 |
+
|
| 701 |
+
return ISNetModelOutput(
|
| 702 |
+
activated=ISNetStageOutput(*activated),
|
| 703 |
+
hidden_states=ISNetStageOutput(*hidden_states),
|
| 704 |
+
)
|
| 705 |
|
| 706 |
|
| 707 |
def convert_from_checkpoint(
|
| 708 |
repo_id: str, filename: str, config: Optional[ISNetConfig] = None
|
| 709 |
+
) -> ISNetModel:
|
| 710 |
from huggingface_hub import hf_hub_download
|
| 711 |
|
| 712 |
checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
| 713 |
config = config or ISNetConfig()
|
| 714 |
+
model = ISNetModel(config)
|
| 715 |
|
| 716 |
logger.info(f"Loading checkpoint from {checkpoint_path}")
|
| 717 |
state_dict = torch.load(checkpoint_path)
|