| import gradio as gr |
| import requests |
| import torch |
| import torch.nn as nn |
|
|
| import timm |
|
|
# Load the ResNet-18 checkpoint from the Hugging Face Hub and switch it to
# training mode. nn.Module.train() returns the module itself, so the call is
# chained onto create_model(). Training mode is what lets forward passes made
# by greet() update the BatchNorm running statistics that greet() reports when
# it is called without an image (presumably the point of this app).
model = timm.create_model(
    "hf_hub:nateraw/resnet18-random",
    pretrained=True,
).train()
|
|
| import os |
|
|
def print_bn(module=None):
    """Collect running statistics and momentum of every BatchNorm2d layer.

    Despite its name, this function prints nothing; it returns a flat list.
    It visits layers in ``module.modules()`` order. For each layer whose type
    is exactly ``nn.BatchNorm2d``, it appends all ``running_mean`` entries,
    then all ``running_var`` entries, then the layer's ``momentum``.

    Args:
        module: Network to inspect. Defaults to the module-level ``model``,
            so existing zero-argument callers behave as before.

    Returns:
        list: Python floats. ``momentum`` is appended exactly as stored on
        the layer, so it is ``None`` for a layer built with ``momentum=None``.
    """
    if module is None:
        module = model

    bn_data = []
    for m in module.modules():
        # Exact type match (not isinstance) on purpose: subclasses of
        # nn.BatchNorm2d are not collected, which keeps the output layout
        # identical to the original implementation.
        if type(m) is not nn.BatchNorm2d:
            continue
        # Layers built with track_running_stats=False have no running buffers
        # (running_mean is None); skip them instead of raising AttributeError.
        if m.running_mean is None:
            continue
        # detach().cpu() also works when the model lives on a non-CPU device,
        # where .numpy() would raise.
        bn_data.extend(m.running_mean.detach().cpu().tolist())
        bn_data.extend(m.running_var.detach().cpu().tolist())
        bn_data.append(m.momentum)
    return bn_data
|
|
def greet(image):
    """Gradio handler: report BatchNorm statistics or feed an image to the model.

    If ``image`` is ``None``, return every value collected by ``print_bn()``
    as a comma-separated string with 10 decimal places each.

    Otherwise, scale the image by 1/255, add a batch dimension, move the last
    axis to the channel position, and run it through the module-level
    ``model``. That model is put in train mode at load time, so the forward
    pass updates the BatchNorm running statistics. A fixed acknowledgement
    string is then returned; the model output itself is not used.

    Args:
        image: ``None``, or an array-like image. It is assumed to be laid out
            as (height, width, channels) with values in 0..255, which the
            permute and the /255 below imply. TODO confirm against the
            Gradio input configuration.

    Returns:
        str: The formatted BatchNorm statistics, or ``"Hello world!"``.
    """
    if image is None:
        bn_data = print_bn()
        return ','.join(f'{x:.10f}' for x in bn_data)

    print(type(image))
    image = torch.tensor(image).float()
    print(image.min(), image.max())
    image = image / 255.0
    image = image.unsqueeze(0)
    print(image.shape)
    # (1, H, W, C) -> (1, C, H, W), the channel-first layout the model expects.
    image = torch.permute(image, [0, 3, 1, 2])

    # Only the side effect of this forward pass matters: in train mode,
    # BatchNorm layers update their running statistics. no_grad() avoids
    # building an autograd graph that would never be used; running-stat
    # updates still happen under no_grad().
    with torch.no_grad():
        model(image)

    return "Hello world!"
|
|
|
|
|
|
# Build the Gradio UI: one image-upload input (with shape=(224, 224) passed to
# Gradio) wired to greet(), whose string return value is displayed as text.
# Submitting without an image makes greet() return the BatchNorm statistics.
image = gr.inputs.Image(
    label="Upload a photo for beautify",
    shape=(224, 224),
)
iface = gr.Interface(
    fn=greet,
    inputs=image,
    outputs="text",
)

# Start the web server (runs at import time, as in the original script).
iface.launch()