Spaces:
Running
on
Zero
Running
on
Zero
| # -*- encoding: utf-8 -*- | |
| ''' | |
| @File : sr_group.py | |
| @Time : 2022/04/02 01:17:21 | |
| @Author : Ming Ding | |
| @Contact : dm18@mails.tsinghua.edu.cn | |
| ''' | |
| # here put the import lib | |
| import os | |
| import sys | |
| import math | |
| import random | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from SwissArmyTransformer.resources import auto_create | |
| from .direct_sr import DirectSuperResolution | |
| from .iterative_sr import IterativeSuperResolution | |
| class SRGroup: | |
| def __init__(self, args, home_path=None,): | |
| dsr_path = auto_create('cogview2-dsr', path=home_path) | |
| itersr_path = auto_create('cogview2-itersr', path=home_path) | |
| dsr = DirectSuperResolution(args, dsr_path) | |
| itersr = IterativeSuperResolution(args, itersr_path, shared_transformer=dsr.model.transformer) | |
| self.dsr = dsr | |
| self.itersr = itersr | |
| def sr_base(self, img_tokens, txt_tokens): | |
| assert img_tokens.shape[-1] == 400 and len(img_tokens.shape) == 2 | |
| batch_size = img_tokens.shape[0] | |
| txt_len = txt_tokens.shape[-1] | |
| if len(txt_tokens.shape) == 1: | |
| txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len) | |
| sred_tokens = self.dsr(txt_tokens, img_tokens) | |
| iter_tokens = self.itersr(txt_tokens, sred_tokens[:, -3600:].clone()) | |
| return iter_tokens[-batch_size:] | |
| # def sr_patch(self, img_tokens, txt_tokens): | |
| # assert img_tokens.shape[-1] == 3600 and len(img_tokens.shape) == 2 | |
| # batch_size = img_tokens.shape[0] * 9 | |
| # txt_len = txt_tokens.shape[-1] | |
| # if len(txt_tokens.shape) == 1: | |
| # txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len) | |
| # img_tokens = img_tokens.view(img_tokens.shape[0], 3, 20, 3, 20).permute(0, 1, 3, 2, 4).reshape(batch_size, 400) | |
| # iter_tokens = self.sr_base(img_tokens, txt_tokens) | |
| # return iter_tokens |