Add phone model (beta), allow models to use different architectures
Browse files- app.py +13 -2
- models/Telephone_SR.th +3 -0
- processAudio.py +5 -5
- src/models/modelFactory.py +3 -3
app.py
CHANGED
|
@@ -18,7 +18,8 @@ with gr.Blocks(theme=gr.themes.Default().set(body_background_fill="#CCEEFF")) as
|
|
| 18 |
modelSelect = gr.Dropdown(
|
| 19 |
[
|
| 20 |
["FM Radio Super Resolution","FM_Radio_SR.th"],
|
| 21 |
-
["AM Radio Super Resolution (Beta)","AM_Radio_SR.th"]
|
|
|
|
| 22 |
],
|
| 23 |
label="Select Model:",
|
| 24 |
value="FM_Radio_SR.th",
|
|
@@ -66,9 +67,19 @@ with gr.Blocks(theme=gr.themes.Default().set(body_background_fill="#CCEEFF")) as
|
|
| 66 |
lrAudio = torch.tensor(audioData[1].copy().astype(np.float32)/32768).transpose(0,1)
|
| 67 |
if audioData[0] != 44100:
|
| 68 |
lrAudio = resample(lrAudio, audioData[0], 44100)
|
| 69 |
-
|
|
|
|
| 70 |
hrAudio=hrAudio / max(hrAudio.abs().max().item(), 1)
|
| 71 |
outAudio=(hrAudio*32767).numpy().astype(np.int16).transpose(1,0)
|
| 72 |
return tuple([44100, outAudio])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
layout.launch()
|
|
|
|
| 18 |
modelSelect = gr.Dropdown(
|
| 19 |
[
|
| 20 |
["FM Radio Super Resolution","FM_Radio_SR.th"],
|
| 21 |
+
["AM Radio Super Resolution (Beta)","AM_Radio_SR.th"],
|
| 22 |
+
["Telephone Super Resolution (Beta)","Telephone_SR.th"]
|
| 23 |
],
|
| 24 |
label="Select Model:",
|
| 25 |
value="FM_Radio_SR.th",
|
|
|
|
| 67 |
lrAudio = torch.tensor(audioData[1].copy().astype(np.float32)/32768).transpose(0,1)
|
| 68 |
if audioData[0] != 44100:
|
| 69 |
lrAudio = resample(lrAudio, audioData[0], 44100)
|
| 70 |
+
model_name, experiment_file = getModelInfo(model)
|
| 71 |
+
hrAudio=upscaleAudio(lrAudio, model, model_name=model_name, experiment_file=experiment_file)
|
| 72 |
hrAudio=hrAudio / max(hrAudio.abs().max().item(), 1)
|
| 73 |
outAudio=(hrAudio*32767).numpy().astype(np.int16).transpose(1,0)
|
| 74 |
return tuple([44100, outAudio])
|
| 75 |
+
|
| 76 |
+
def getModelInfo(modelFilename: str):
    """Map a model checkpoint filename to its architecture and experiment config.

    Args:
        modelFilename: checkpoint file name as selected in the UI dropdown,
            e.g. "FM_Radio_SR.th".

    Returns:
        A (model_name, experiment_file) tuple: the architecture key understood
        by modelFactory.get_model and the YAML experiment file under
        conf/experiment/ used to instantiate it.
    """
    # All models currently share the "aero" architecture and the same
    # experiment config; entries are listed per checkpoint so they can
    # diverge later without touching callers.
    modelInfo = {
        "FM_Radio_SR.th": ("aero", "aero_441-441_512_256.yaml"),
        "AM_Radio_SR.th": ("aero", "aero_441-441_512_256.yaml"),
        "Telephone_SR.th": ("aero", "aero_441-441_512_256.yaml"),
    }
    # Unknown filenames fall back to the default aero config, matching the
    # original if-chain's final unconditional return.
    return modelInfo.get(modelFilename, ("aero", "aero_441-441_512_256.yaml"))
|
| 84 |
|
| 85 |
layout.launch()
|
models/Telephone_SR.th
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b59e32fccaf83c7e038b8c5e894eeebbce272e9ab00db6b20d45b2fba6e911ca
|
| 3 |
+
size 136533968
|
processAudio.py
CHANGED
|
@@ -20,9 +20,9 @@ SEGMENT_DURATION_SEC = 5
|
|
| 20 |
SEGMENT_OVERLAP_RATIO = 0.25
|
| 21 |
SERIALIZE_KEY_STATE = 'state'
|
| 22 |
|
| 23 |
-
def _load_model(
|
| 24 |
-
checkpoint_file = Path(
|
| 25 |
-
model = modelFactory.get_model(model_name)['generator']
|
| 26 |
package = torch.load(checkpoint_file, 'cpu')
|
| 27 |
if 'state' in package.keys(): #raw model file
|
| 28 |
logger.info(bold(f'Loading model {model_name} from file.'))
|
|
@@ -35,9 +35,9 @@ def crossfade_and_blend(out_clip, in_clip, segment_overlap_samples):
|
|
| 35 |
fade_in = torchaudio.transforms.Fade(segment_overlap_samples, 0)
|
| 36 |
return fade_out(out_clip) + fade_in(in_clip)
|
| 37 |
|
| 38 |
-
def upscaleAudio(lr_sig, checkpoint_file: str, sr=44100, hr_sr=44100, model_name="aero", progress=Progress()):
|
| 39 |
|
| 40 |
-
model = _load_model(checkpoint_file,model_name)
|
| 41 |
device = torch.device('cpu')
|
| 42 |
if torch.cuda.is_available():
|
| 43 |
device = torch.device('cuda')
|
|
|
|
| 20 |
SEGMENT_OVERLAP_RATIO = 0.25
|
| 21 |
SERIALIZE_KEY_STATE = 'state'
|
| 22 |
|
| 23 |
+
def _load_model(checkpoint_filename="FM_Radio_SR.th",model_name="aero",experiment_file="aero_441-441_512_256.yaml"):
|
| 24 |
+
checkpoint_file = Path("models/" + checkpoint_filename)
|
| 25 |
+
model = modelFactory.get_model(model_name,experiment_file)['generator']
|
| 26 |
package = torch.load(checkpoint_file, 'cpu')
|
| 27 |
if 'state' in package.keys(): #raw model file
|
| 28 |
logger.info(bold(f'Loading model {model_name} from file.'))
|
|
|
|
| 35 |
fade_in = torchaudio.transforms.Fade(segment_overlap_samples, 0)
|
| 36 |
return fade_out(out_clip) + fade_in(in_clip)
|
| 37 |
|
| 38 |
+
def upscaleAudio(lr_sig, checkpoint_file: str, sr=44100, hr_sr=44100, model_name="aero", experiment_file="aero_441-441_512_256.yaml", progress=Progress()):
|
| 39 |
|
| 40 |
+
model = _load_model(checkpoint_file,model_name,experiment_file)
|
| 41 |
device = torch.device('cpu')
|
| 42 |
if torch.cuda.is_available():
|
| 43 |
device = torch.device('cuda')
|
src/models/modelFactory.py
CHANGED
|
@@ -2,12 +2,12 @@ from src.models.aero import Aero
|
|
| 2 |
from src.models.seanet import Seanet
|
| 3 |
from yaml import safe_load
|
| 4 |
|
| 5 |
-
def get_model(model_name="aero"):
|
| 6 |
if model_name == 'aero':
|
| 7 |
-
with open("conf/experiment/
|
| 8 |
generator = Aero(**safe_load(f)["aero"])
|
| 9 |
elif model_name == 'seanet':
|
| 10 |
-
with open("conf/experiment/
|
| 11 |
generator = Seanet(**safe_load(f)["seanet"])
|
| 12 |
|
| 13 |
models = {'generator': generator}
|
|
|
|
| 2 |
from src.models.seanet import Seanet
|
| 3 |
from yaml import safe_load
|
| 4 |
|
| 5 |
+
def get_model(model_name="aero", experiment_file="aero_441-441_512_256.yaml"):
|
| 6 |
if model_name == 'aero':
|
| 7 |
+
with open("conf/experiment/" + experiment_file) as f:
|
| 8 |
generator = Aero(**safe_load(f)["aero"])
|
| 9 |
elif model_name == 'seanet':
|
| 10 |
+
with open("conf/experiment/" + experiment_file) as f:
|
| 11 |
generator = Seanet(**safe_load(f)["seanet"])
|
| 12 |
|
| 13 |
models = {'generator': generator}
|