Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import requests | |
| import torch | |
| import torch.nn as nn | |
| import torchvision | |
| import torchvision.transforms as T | |
| from transformers import AutoFeatureExtractor, ResNetForImageClassification | |
| import timm | |
# Single Hugging Face checkpoint shared by the feature extractor and the classifier.
_RESNET_CHECKPOINT = "microsoft/resnet-101"

# NOTE(review): feature_extractor is not referenced by any other visible code in
# this file; confirm whether it is still needed.
feature_extractor = AutoFeatureExtractor.from_pretrained(_RESNET_CHECKPOINT)

# nn.Module.eval() returns the module itself, so inference mode is chained here.
# In eval mode BatchNorm layers normalise with their running statistics, which
# are the buffers the rest of this file reads and overwrites.
model = ResNetForImageClassification.from_pretrained(_RESNET_CHECKPOINT).eval()
| import os | |
def print_bn(*, net=None):
    """Collect the running statistics of every BatchNorm2d layer as a flat list.

    For each ``nn.BatchNorm2d`` yielded by ``net.modules()`` (in traversal
    order) this appends the layer's ``running_mean`` values, then its
    ``running_var`` values, then its ``momentum``. The total length of the
    list is printed as a debug aid.

    Args:
        net: Module to inspect. Defaults to the module-level ``model`` so
            existing zero-argument callers keep working.

    Returns:
        list: Python floats. Each layer contributes its stored ``momentum``
        as-is. PyTorch allows ``momentum=None``, and such a value would then
        appear in the list.
    """
    if net is None:
        net = model
    bn_data = []
    for m in net.modules():
        if not isinstance(m, nn.BatchNorm2d):
            continue
        # Layers built with track_running_stats=False have no running buffers
        # (they are None); skip them instead of crashing on attribute access.
        if m.running_mean is None or m.running_var is None:
            continue
        # Tensor.tolist() works on any device; .numpy() requires a CPU tensor.
        bn_data.extend(m.running_mean.tolist())
        bn_data.extend(m.running_var.tolist())
        bn_data.append(m.momentum)
    print(len(bn_data))
    return bn_data
def update_bn(image, *, net=None, size=(90, 90)):
    """Overwrite BatchNorm2d running means with pixel values from ``image``.

    The image is resized with ``T.Resize(size)`` and flattened. Its values are
    then copied, in order, into the ``running_mean`` buffer of each
    ``nn.BatchNorm2d`` layer of ``net`` (traversal order of ``net.modules()``)
    until the pixels run out. The last layer reached may be only partially
    overwritten. ``running_var`` and ``momentum`` are not touched.

    Args:
        image: Tensor accepted by ``T.Resize``. As called from ``greet`` it is
            a float batch of one in (1, C, H, W) layout, scaled by 1/255.
        net: Module whose layers are modified in place. Defaults to the
            module-level ``model``.
        size: Target (height, width) for the resize. It controls how many
            values are written.

    Returns:
        None. ``net`` is mutated in place.
    """
    if net is None:
        net = model
    flat = T.Resize(size)(image).reshape(-1)
    total = flat.shape[0]
    cursor_im = 0
    # Buffers are plain state; no autograd history is wanted for these writes.
    with torch.no_grad():
        for m in net.modules():
            # Skip non-BN modules and BN layers built without running stats.
            if not isinstance(m, nn.BatchNorm2d) or m.running_mean is None:
                continue
            width = m.running_mean.shape[0]
            remaining = total - cursor_im
            if width < remaining:
                # copy_ writes into the layer's own buffer (casting dtype/device
                # as needed). The original rebound the buffer to a view of the
                # image tensor instead, which aliased its storage.
                m.running_mean.copy_(flat[cursor_im:cursor_im + width])
                # Log the range actually written, before advancing the cursor.
                print(cursor_im, ':', cursor_im + width)
                cursor_im += width
            else:
                # Only `remaining` pixels are left (possibly exactly `width`):
                # fill that prefix of this layer and stop.
                m.running_mean[:remaining].copy_(flat[cursor_im:])
                break
    return
def greet(image):
    """Gradio callback: dump BatchNorm statistics, or write an image into them.

    Args:
        image: The uploaded picture as delivered by the Gradio Image input, or
            ``None`` when nothing was uploaded. It is assumed to be an
            (H, W, C) array of 0-255 values; TODO confirm against the Gradio
            version in use.

    Returns:
        str: With no image, the values from ``print_bn()`` as comma-separated
        numbers with two decimals. Otherwise the placeholder "Hello world!",
        after the image has been written into the model's BatchNorm running
        means by ``update_bn``.
    """
    if image is None:
        bn_data = print_bn()
        return ','.join(f'{x:.2f}' for x in bn_data)

    print(type(image))
    image = torch.tensor(image).float()
    print(image.min(), image.max())
    # Scale by 1/255, add a batch axis, and move the last (channel) axis to
    # position 1: (H, W, C) -> (1, C, H, W), the layout pixel_values expects.
    image = (image / 255.0).unsqueeze(0).permute(0, 3, 1, 2)
    update_bn(image)
    print(image.shape)
    # The forward pass only checks that the modified model still runs; its
    # output is unused, so do not build an autograd graph for it.
    with torch.no_grad():
        model(pixel_values=image)
    return "Hello world!"
# Upload widget feeding greet(); Gradio is asked for 224x224 images.
image = gr.inputs.Image(
    shape=(224, 224),
    label="Upload a photo for beauty",
)

# NOTE(review): out_image is constructed but never passed to the Interface
# below (outputs is plain text).
out_image = gr.inputs.Image(label='Yes, it becomes better.')

# greet() returns a string, hence the 'text' output component.
iface = gr.Interface(
    inputs=image,
    outputs='text',
    fn=greet,
)
iface.launch()