# NNNFT / utils.py
# Uploaded by zwenqi ("Upload 8 files", commit 2e4db5f).
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
def matplot(plotter, x, y):
    """Append data to a live line plot in figure 1 and refresh it.

    Args:
        plotter: a line artist (e.g. a matplotlib ``Line2D``) exposing
            ``get_xdata``/``set_xdata``/``get_ydata``/``set_ydata`` and ``axes``.
        x, y: value(s) appended to the line's x and y data.
    """
    plt.figure(1)
    extended_x = np.append(plotter.get_xdata(), x)
    extended_y = np.append(plotter.get_ydata(), y)
    plotter.set_xdata(extended_x)
    plotter.set_ydata(extended_y)
    # Recompute data limits so the new points are visible.
    axis = plotter.axes
    axis.relim()
    axis.autoscale()
    plt.draw()
    plt.pause(0.01)
def plotresult(model, test_data):
    """Plot input, prediction and target for one random test sample in figure 2.

    A random item ``(x, y, _)`` is taken from ``test_data``, ``x`` is run
    through ``model`` (without gradients) and a 3x2 grid is drawn:
    row 1 = input ``x``, row 2 = model output, row 3 = target ``y``;
    the left/right columns are channels 0 and 1 (presumably real/imaginary
    parts — TODO confirm against the dataset).
    """
    device = next(model.parameters()).device
    sample_idx = int(np.random.rand() * len(test_data))
    x, y, _ = test_data[sample_idx]
    # Add a leading batch dimension of size 1.
    x = x.to(device).reshape(1, x.shape[0], x.shape[1])
    y = y.reshape(1, y.shape[0], y.shape[1])
    with torch.no_grad():
        y_ = model(x).cpu()
    x = x.cpu()
    panels = (x[0][0], x[0][1], y_[0][0], y_[0][1], y[0][0], y[0][1])
    plt.figure(2)
    plt.clf()
    for position, trace in enumerate(panels, start=1):
        plt.subplot(3, 2, position)
        plt.plot(trace.numpy(), '.-', markersize=0.5, linewidth=0.1)
    plt.draw()
    plt.pause(0.01)
def plotresulti(model, test_data):
    """Like ``plotresult`` but for the inverse direction (figure 2).

    Dataset items are unpacked as ``(y, x, _)``: the item's SECOND element is
    fed to ``model`` and the first is used as the reference. The 3x2 grid
    shows input (row 1), model output (row 2) and reference (row 3) for
    channels 0 and 1.
    """
    device = next(model.parameters()).device
    sample_idx = int(np.random.rand() * len(test_data))
    y, x, _ = test_data[sample_idx]
    # Add a leading batch dimension of size 1.
    x = x.to(device).reshape(1, x.shape[0], x.shape[1])
    y = y.reshape(1, y.shape[0], y.shape[1])
    with torch.no_grad():
        y_ = model(x).cpu()
    x = x.cpu()
    panels = (x[0][0], x[0][1], y_[0][0], y_[0][1], y[0][0], y[0][1])
    plt.figure(2)
    plt.clf()
    for position, trace in enumerate(panels, start=1):
        plt.subplot(3, 2, position)
        plt.plot(trace.numpy(), '.-', markersize=0.5, linewidth=0.1)
    plt.draw()
    plt.pause(0.01)
def plotresulti2(model1, model2, test_data):
    """Plot the cascade ``model2(model1(x))`` for one random test sample (figure 2).

    Dataset items are unpacked as ``(y, x, _)``. ``x`` is passed through
    ``model1`` and the result through ``model2``. Row 1 shows ``model1``'s
    output (the intermediate signal), row 2 the final output and row 3 the
    reference ``y``; columns are channels 0 and 1. Inputs are placed on
    ``model1``'s device (``model2`` presumably lives there too — TODO confirm).
    """
    device = next(model1.parameters()).device
    sample_idx = int(np.random.rand() * len(test_data))
    y, x, _ = test_data[sample_idx]
    # Add a leading batch dimension of size 1.
    x = x.to(device).reshape(1, x.shape[0], x.shape[1])
    y = y.reshape(1, y.shape[0], y.shape[1])
    with torch.no_grad():
        x = model1(x)
        y_ = model2(x).cpu()
    x = x.cpu()
    panels = (x[0][0], x[0][1], y_[0][0], y_[0][1], y[0][0], y[0][1])
    plt.figure(2)
    plt.clf()
    for position, trace in enumerate(panels, start=1):
        plt.subplot(3, 2, position)
        plt.plot(trace.numpy(), '.-', markersize=0.5, linewidth=0.1)
    plt.draw()
    plt.pause(0.01)
def plotresult_(model, x0):
    """Transform a raw signal, run ``model`` on it and plot both (figure 3).

    The input is built as ``-conj(complex64(fftshift(ifft(x0)) * c))`` where
    ``c[k] = exp(1j*pi*k) * 2e5`` (an alternating +/-1 sign times 2e5; the
    meaning of the 2e5 scale is not visible here). Its real and imaginary
    parts form the two input channels.

    Returns:
        (x, y): the transformed input and the model output, each as a complex
        1-D array (channel 0 + 1j * channel 1).
    """
    device = next(model.parameters()).device
    sign_scale = np.exp(1j * np.pi * np.arange(len(x0))) * 2e5
    spectrum = np.fft.fftshift(np.fft.ifft(x0)) * sign_scale
    xc = -np.conj(np.complex64(spectrum))
    x = torch.tensor(np.c_[np.real(xc), np.imag(xc)].T).to(device)
    # Add a leading batch dimension of size 1.
    x = x.reshape(1, x.shape[0], x.shape[1])
    with torch.no_grad():
        y = model(x).cpu()
    x = x.cpu()
    plt.figure(3)
    plt.clf()
    for position, trace in enumerate((x[0][0], x[0][1], y[0][0], y[0][1]), start=1):
        plt.subplot(2, 2, position)
        plt.plot(trace.numpy(), '.-', markersize=0.5, linewidth=0.1)
    plt.draw()
    plt.pause(0.01)
    x_channels = x[0].numpy()
    y_channels = y[0].numpy()
    return x_channels[0] + 1j * x_channels[1], y_channels[0] + 1j * y_channels[1]
def plotresulti_(model, x0):
    """Run ``model`` on a complex signal and plot input and output (figure 3).

    ``x0`` is cast to complex64 and split into real (channel 0) and imaginary
    (channel 1) parts. No transform is applied, unlike ``plotresult_``.

    Returns:
        The model output as a complex 1-D array (channel 0 + 1j * channel 1).
    """
    device = next(model.parameters()).device
    xc = np.complex64(x0)
    x = torch.tensor(np.c_[np.real(xc), np.imag(xc)].T).to(device)
    # Add a leading batch dimension of size 1.
    x = x.reshape(1, x.shape[0], x.shape[1])
    with torch.no_grad():
        y = model(x).cpu()
    x = x.cpu()
    plt.figure(3)
    plt.clf()
    for position, trace in enumerate((x[0][0], x[0][1], y[0][0], y[0][1]), start=1):
        plt.subplot(2, 2, position)
        plt.plot(trace.numpy(), '.-', markersize=0.5, linewidth=0.1)
    plt.draw()
    plt.pause(0.01)
    y_channels = y[0].numpy()
    return y_channels[0] + 1j * y_channels[1]
def GetEnergy(dataU, dt=1.0/96e9, P0=2.0/1.1e-3, chunk=4096):
    """Return the scaled energy ``P0 * integral |u(t)|^2 dt`` of every column.

    Each column of ``dataU`` is one sampled pulse. The integral uses the
    trapezoidal rule on a uniform time grid of spacing ``dt``. Columns are
    processed in vectorised chunks instead of one Python iteration per pulse.

    Args:
        dataU: 2-D array (time samples x pulses), real or complex. Any number
            of time samples is accepted (previously fixed to 2**11).
        dt: sample spacing of the time grid (default ``1/96e9``).
        P0: scale factor applied to the integral (default ``2.0/1.1e-3``).
        chunk: positive number of columns integrated per step; bounds the
            size of temporary arrays for large datasets.

    Returns:
        1-D float ndarray of length ``dataU.shape[1]``.
    """
    dataU = np.asarray(dataU)
    nt = dataU.shape[0]
    # Time axis centred on zero; only the spacing affects the integral.
    t = np.arange(nt) * dt - dt * nt / 2
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    ncol = dataU.shape[1]
    E = np.zeros(ncol)
    for start in range(0, ncol, chunk):
        block = dataU[:, start:start + chunk]
        # |u|^2 == real(u * conj(u)), computed without a complex temporary.
        power = block.real ** 2 + block.imag ** 2
        E[start:start + chunk] = trapezoid(power, t, axis=0)
    return E * P0
def GetPulseByEnergyRange(dataset, energy_low, energy_high):
    """Return the columns of ``dataset.u`` whose energy is in ``[energy_low, energy_high)``.

    Energies are computed with ``GetEnergy``.
    """
    energies = GetEnergy(dataset.u)
    keep = np.logical_and(energies >= energy_low, energies < energy_high)
    return dataset.u[:, keep]
def GetPulseByEnergyLogRange(dataset, energy_low, energy_high):
    """Return the columns of ``dataset.u`` whose log10 energy is in ``[energy_low, energy_high)``.

    The bounds are compared against ``log10(GetEnergy(dataset.u))``.
    """
    log_energies = np.log10(GetEnergy(dataset.u))
    keep = np.logical_and(log_energies >= energy_low, log_energies < energy_high)
    return dataset.u[:, keep]
def ScanEnergyRange(dataset, nrange=100):
    """Count pulses of ``dataset.u`` in ``nrange`` equal-width energy bins.

    Bins span ``[Emin, Emax]`` of ``GetEnergy(dataset.u)``. Every bin is
    half-open ``[lo, hi)`` except the last, which is closed so the
    maximum-energy pulse is counted. The original code dropped it.

    Returns:
        np.ndarray of length ``nrange`` with the pulse count per bin.
    """
    E = GetEnergy(dataset.u)
    Emin, Emax = np.min(E), np.max(E)
    dE = (Emax - Emin) / nrange
    nE = []
    for ii in tqdm(range(nrange), desc='ScanEnergyRange'):
        # Edges are computed from Emin rather than by accumulating dE, so
        # adjacent bins share bit-identical edges and do not drift.
        E1 = Emin + ii * dE
        if ii == nrange - 1:
            in_bin = E >= E1
        else:
            in_bin = (E >= E1) & (E < Emin + (ii + 1) * dE)
        nE.append(int(np.count_nonzero(in_bin)))
    return np.array(nE)
def ScanEnergyRangeLog(dataset, nrange=100):
    """Count pulses of ``dataset.u`` in ``nrange`` equal-width log10-energy bins.

    Prints ``Emin Emax dE`` of the log10 energies. Every bin is half-open
    ``[lo, hi)`` except the last, which is closed so the maximum is counted.
    NOTE: a zero-energy pulse gives log10 = -inf, which breaks the binning.

    Returns:
        np.ndarray of length ``nrange`` with the pulse count per bin.
    """
    E = np.log10(GetEnergy(dataset.u))
    Emin, Emax = np.min(E), np.max(E)
    dE = (Emax - Emin) / nrange
    print(Emin, Emax, dE)
    nE = []
    for ii in tqdm(range(nrange), desc='ScanEnergyRangeLog'):
        # Edges computed from Emin (no accumulated float drift).
        E1 = Emin + ii * dE
        if ii == nrange - 1:
            in_bin = E >= E1
        else:
            in_bin = (E >= E1) & (E < Emin + (ii + 1) * dE)
        nE.append(int(np.count_nonzero(in_bin)))
    return np.array(nE)
def SeperatePulseByEnergyRanges(dataset, nrange):
    """Split the pulses of ``dataset`` into ``nrange`` equal-width energy bins.

    Energies come from ``GetEnergy(dataset.u)``. Bins are half-open
    ``[lo, hi)`` except the last, which is closed so the highest-energy pulse
    is kept.

    Returns:
        ``(U, Q, (E, Emin, Emax, dE))`` where ``U[k]``/``Q[k]`` are the columns
        of ``dataset.u``/``dataset.q`` falling in bin k and ``E[k]`` their
        energies.
    """
    e = GetEnergy(dataset.u)
    Emin, Emax = np.min(e), np.max(e)
    dE = (Emax - Emin) / nrange
    U, Q, E = [], [], []
    for ii in tqdm(range(nrange), desc='SeperatePulseByEnergyRanges'):
        # Edges computed from Emin (no accumulated float drift).
        lo = Emin + ii * dE
        if ii == nrange - 1:
            mask = e >= lo
        else:
            mask = (e >= lo) & (e < Emin + (ii + 1) * dE)
        U.append(dataset.u[:, mask])
        Q.append(dataset.q[:, mask])
        E.append(e[mask])
    return U, Q, (E, Emin, Emax, dE)
def SeperatePulseByEnergyLogRanges(dataset, nrange):
    """Split the pulses of ``dataset`` into ``nrange`` equal-width log10-energy bins.

    Uses ``log10(GetEnergy(dataset.u))``. Bins are half-open ``[lo, hi)``
    except the last, which is closed so the highest-energy pulse is kept.
    NOTE: a zero-energy pulse gives -inf, which breaks the binning.

    Returns:
        ``(U, Q, (E, Emin, Emax, dE))`` where ``U[k]``/``Q[k]`` are the columns
        of ``dataset.u``/``dataset.q`` in bin k and ``E[k]`` their log10
        energies.
    """
    e = np.log10(GetEnergy(dataset.u))
    Emin, Emax = np.min(e), np.max(e)
    dE = (Emax - Emin) / nrange
    U, Q, E = [], [], []
    for ii in tqdm(range(nrange), desc='SeperatePulseByEnergyLogRanges'):
        # Edges computed from Emin (no accumulated float drift).
        lo = Emin + ii * dE
        if ii == nrange - 1:
            mask = e >= lo
        else:
            mask = (e >= lo) & (e < Emin + (ii + 1) * dE)
        U.append(dataset.u[:, mask])
        Q.append(dataset.q[:, mask])
        E.append(e[mask])
    return U, Q, (E, Emin, Emax, dE)
def SeperatePulseByEnergyRanges_(u, q, nrange):
    """Split paired arrays ``u``/``q`` into ``nrange`` equal-width energy bins.

    Same as ``SeperatePulseByEnergyRanges`` but takes the arrays directly.
    Energies come from ``GetEnergy(u)``. Bins are half-open ``[lo, hi)``
    except the last, which is closed so the highest-energy pulse is kept.

    Returns:
        ``(U, Q, (E, Emin, Emax, dE))`` with ``U[k]``/``Q[k]`` the columns of
        ``u``/``q`` in bin k and ``E[k]`` their energies.
    """
    e = GetEnergy(u)
    Emin, Emax = np.min(e), np.max(e)
    dE = (Emax - Emin) / nrange
    U, Q, E = [], [], []
    for ii in tqdm(range(nrange), desc='SeperatePulseByEnergyRanges'):
        # Edges computed from Emin (no accumulated float drift).
        lo = Emin + ii * dE
        if ii == nrange - 1:
            mask = e >= lo
        else:
            mask = (e >= lo) & (e < Emin + (ii + 1) * dE)
        U.append(u[:, mask])
        Q.append(q[:, mask])
        E.append(e[mask])
    return U, Q, (E, Emin, Emax, dE)
def SeperatePulseByEnergyLogRanges_(u, q, nrange):
    """Split paired arrays ``u``/``q`` into ``nrange`` equal-width log10-energy bins.

    Uses ``log10(GetEnergy(u))``. Bins are half-open ``[lo, hi)`` except the
    last, which is closed so the highest-energy pulse is kept.
    NOTE: a zero-energy pulse gives -inf, which breaks the binning.

    Returns:
        ``(U, Q, (E, Emin, Emax, dE))`` with ``U[k]``/``Q[k]`` the columns of
        ``u``/``q`` in bin k and ``E[k]`` their log10 energies.
    """
    e = np.log10(GetEnergy(u))
    Emin, Emax = np.min(e), np.max(e)
    dE = (Emax - Emin) / nrange
    U, Q, E = [], [], []
    for ii in tqdm(range(nrange), desc='SeperatePulseByEnergyLogRanges'):
        # Edges computed from Emin (no accumulated float drift).
        lo = Emin + ii * dE
        if ii == nrange - 1:
            mask = e >= lo
        else:
            mask = (e >= lo) & (e < Emin + (ii + 1) * dE)
        U.append(u[:, mask])
        Q.append(q[:, mask])
        E.append(e[mask])
    return U, Q, (E, Emin, Emax, dE)
def prepareQ(q, normalization = True):
    """Pack a 1-D (complex) signal into a ``(1, 2, N)`` tensor of real/imag parts.

    Args:
        q: sequence of length N, real or complex. It is NOT modified.
        normalization: if True, divide by ``max(|q|)`` first.

    Returns:
        ``(Q, qnorm)``. ``Q[0, 0]`` is the real part and ``Q[0, 1]`` the
        imaginary part of ``q / qnorm``. ``qnorm`` is the divisor used: 1.0
        when normalization is off or ``q`` is all zeros.
    """
    qtemp = np.asarray(q)
    qnorm = 1.0
    if normalization:
        peak = np.max(np.abs(qtemp))
        # Guard against 0/0 -> NaN for an all-zero signal.
        if peak > 0:
            qnorm = peak
    # Out-of-place division. The previous ``qtemp /= qnorm`` overwrote the
    # caller's array (e.g. a column view ``q[:, ii]`` of a dataset) and failed
    # for integer arrays.
    qtemp = qtemp / qnorm
    Q = torch.tensor(np.c_[np.real(qtemp), np.imag(qtemp)].T)
    # Add a leading batch dimension of size 1.
    Q = Q.reshape((1,) + Q.shape)
    return Q, qnorm
def prepareU(u, normalization = True):
    """Transform signal ``u`` and pack it as a model tensor via ``prepareQ``.

    The transform is ``-conj(complex64(fftshift(ifft(u)) * c))`` with
    ``c[k] = exp(1j*pi*k) * 2e5`` (alternating sign times 2e5; the meaning of
    the 2e5 scale is not visible here).

    Returns:
        Whatever ``prepareQ(transformed, normalization)`` returns.
    """
    shifted = np.fft.fftshift(np.fft.ifft(u))
    sign_scale = np.exp(1j * np.pi * np.arange(len(u))) * 2e5
    transformed = -np.conj(np.complex64(shifted * sign_scale))
    return prepareQ(transformed, normalization)
def valid_(model, u, q, loss_fn, transformU, transformQ, normalization = True, forward=False):
    """Evaluate ``model`` pulse by pulse on the columns of ``u`` and ``q``.

    Each column is converted with ``transformU(u[:, ii], normalization)`` /
    ``transformQ(q[:, ii], normalization)``, which must return
    ``(tensor, norm)``. The tensors are moved to the model's device and scored
    with ``loss_fn(prediction, target)``. By default ``u`` is the input and
    ``q`` the target; ``forward=True`` swaps them. The model is left in
    eval mode.

    Returns:
        ``(mean_loss, losses)``: the mean and a float ndarray of length
        ``u.shape[1]`` with each pulse's loss (``loss_fn`` must yield a
        scalar).
    """
    model.eval()
    device = next(model.parameters()).device
    # torch.cuda.amp.autocast() is deprecated and only warns (then does
    # nothing) for non-CUDA models; enable CUDA autocast explicitly instead.
    use_amp = device.type == 'cuda'
    losses = np.zeros(u.shape[1])
    with torch.no_grad():
        for ii in tqdm(range(u.shape[1]), desc='Valid'):
            x, _ = transformU(u[:, ii], normalization)
            y, _ = transformQ(q[:, ii], normalization)
            x, y = x.to(device), y.to(device)
            if forward:
                y, x = x, y
            with torch.autocast(device_type='cuda', enabled=use_amp):
                y_hat = model(x)
            losses[ii] = loss_fn(y_hat, y).detach().cpu().numpy()
    loss = np.mean(losses)
    return loss, losses
def valid(model, dataset, loss_fn, batch_size, forward=False):
    """Evaluate ``model`` over ``dataset`` in batches without gradients.

    Dataset items are unpacked as ``(x, y, _)`` with ``x`` the input and ``y``
    the target; ``forward=True`` swaps them. The model is left in eval mode.

    Returns:
        ``(rms, losses)``. ``losses`` concatenates the per-batch outputs of
        ``loss_fn``, which must be at least 1-D (e.g. per-sample losses).
        ``rms`` is their root mean square.
    """
    model.eval()
    device = next(model.parameters()).device
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    batch_losses = []
    progress = tqdm(loader, disable=False)
    with torch.no_grad():
        for inputs, targets, _ in progress:
            progress.set_description("Valid")
            inputs = inputs.to(device)
            targets = targets.to(device)
            if forward:
                inputs, targets = targets, inputs
            prediction = model(inputs)
            batch_losses.append(loss_fn(prediction, targets).detach().cpu().numpy())
    per_sample = np.concatenate(batch_losses)
    rms = np.sqrt(np.mean(per_sample ** 2))
    return rms, per_sample
def valid2(model1, model2, dataset, loss_fn, batch_size, forward=False):
    """Evaluate the cascade ``model2(model1(x))`` over ``dataset``.

    Dataset items are unpacked as ``(y, x, _)``: the second element is fed to
    ``model1`` and the first is the target. ``forward=True`` swaps them. Both
    models are left in eval mode. Tensors go to ``model1``'s device
    (``model2`` is assumed to live there too — TODO confirm).

    Returns:
        ``(rms, losses)``. ``losses`` concatenates the per-batch outputs of
        ``loss_fn``, which must be at least 1-D (e.g. per-sample losses).
        ``rms`` is their root mean square.
    """
    model1.eval()
    model2.eval()
    device = next(model1.parameters()).device
    # torch.cuda.amp.autocast() is deprecated and only warns (then does
    # nothing) for non-CUDA models; enable CUDA autocast explicitly instead.
    use_amp = device.type == 'cuda'
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    losses = []
    pbar = tqdm(dataloader, disable=False)
    # One no_grad scope around the whole loop, as in ``valid``.
    with torch.no_grad():
        for y, x, _ in pbar:
            pbar.set_description("Valid")
            x, y = x.to(device), y.to(device)
            if forward:
                y, x = x, y
            with torch.autocast(device_type='cuda', enabled=use_amp):
                x_hat = model1(x)
                y_hat = model2(x_hat)
            losses.append(loss_fn(y_hat, y).detach().cpu().numpy())
    losses = np.concatenate(losses)
    loss = np.sqrt(np.mean(losses**2))
    return loss, losses
def LossVsEnergy_(model, u, q, loss_fn, nrange=100):
    """Mean validation loss of ``model`` for each of ``nrange`` energy bins of ``(u, q)``.

    Pulses are binned with ``SeperatePulseByEnergyRanges_``. Each bin is
    scored with ``valid_`` using ``prepareU``/``prepareQ`` as transforms.

    Returns:
        ``(loss, E)``. ``loss[k]`` is the mean loss of bin k, or NaN for an
        empty bin. ``E`` is the bin info tuple from
        ``SeperatePulseByEnergyRanges_``.
    """
    U, Q, E = SeperatePulseByEnergyRanges_(u, q, nrange)
    loss = []
    for ii in tqdm(range(len(U)), desc='LossVsEnergy'):
        if U[ii].shape[1] == 0:
            # Empty bin: NaN directly instead of a mean-of-empty warning.
            loss.append(np.nan)
            continue
        # valid_ has the (u, q, loss_fn, transformU, transformQ) signature;
        # the DataLoader-based valid() does not, and raised a TypeError here.
        loss.append(valid_(model, U[ii], Q[ii], loss_fn, prepareU, prepareQ)[0])
    return loss, E
def LossVsEnergy(model, dataset, loss_fn, nrange=100):
    """Mean validation loss of ``model`` for each of ``nrange`` energy bins of ``dataset``.

    Pulses are binned with ``SeperatePulseByEnergyRanges``. Each bin is
    scored with ``valid_`` using ``prepareU``/``prepareQ`` as transforms.

    Returns:
        ``(loss, E)``. ``loss[k]`` is the mean loss of bin k, or NaN for an
        empty bin. ``E`` is the bin info tuple from
        ``SeperatePulseByEnergyRanges``.
    """
    U, Q, E = SeperatePulseByEnergyRanges(dataset, nrange)
    loss = []
    for ii in tqdm(range(len(U)), desc='LossVsEnergy'):
        if U[ii].shape[1] == 0:
            # Empty bin: NaN directly instead of a mean-of-empty warning.
            loss.append(np.nan)
            continue
        # valid_ has the (u, q, loss_fn, transformU, transformQ) signature;
        # the DataLoader-based valid() does not, and raised a TypeError here.
        loss.append(valid_(model, U[ii], Q[ii], loss_fn, prepareU, prepareQ)[0])
    return loss, E
def LossVsEnergyLog_(model, u, q, loss_fn, nrange=100):
    """Mean validation loss of ``model`` for each of ``nrange`` log10-energy bins of ``(u, q)``.

    Pulses are binned with ``SeperatePulseByEnergyLogRanges_``. Each bin is
    scored with ``valid_`` using ``prepareU``/``prepareQ`` as transforms.

    Returns:
        ``(loss, E)``. ``loss[k]`` is the mean loss of bin k, or NaN for an
        empty bin. ``E`` is the bin info tuple from
        ``SeperatePulseByEnergyLogRanges_``.
    """
    U, Q, E = SeperatePulseByEnergyLogRanges_(u, q, nrange)
    loss = []
    for ii in tqdm(range(len(U)), desc='LossVsEnergyLog'):
        if U[ii].shape[1] == 0:
            # Empty bin: NaN directly instead of a mean-of-empty warning.
            loss.append(np.nan)
            continue
        # valid_ has the (u, q, loss_fn, transformU, transformQ) signature;
        # the DataLoader-based valid() does not, and raised a TypeError here.
        loss.append(valid_(model, U[ii], Q[ii], loss_fn, prepareU, prepareQ)[0])
    return loss, E
def LossVsEnergyLog(model, dataset, loss_fn, nrange=100):
    """Mean validation loss of ``model`` for each of ``nrange`` log10-energy bins of ``dataset``.

    Pulses are binned with ``SeperatePulseByEnergyLogRanges``. Each bin is
    scored with ``valid_`` using ``prepareU``/``prepareQ`` as transforms.

    Returns:
        ``(loss, E)``. ``loss[k]`` is the mean loss of bin k, or NaN for an
        empty bin. ``E`` is the bin info tuple from
        ``SeperatePulseByEnergyLogRanges``.
    """
    U, Q, E = SeperatePulseByEnergyLogRanges(dataset, nrange)
    loss = []
    for ii in tqdm(range(len(U)), desc='LossVsEnergyLog'):
        if U[ii].shape[1] == 0:
            # Empty bin: NaN directly instead of a mean-of-empty warning.
            loss.append(np.nan)
            continue
        # valid_ has the (u, q, loss_fn, transformU, transformQ) signature;
        # the DataLoader-based valid() does not, and raised a TypeError here.
        loss.append(valid_(model, U[ii], Q[ii], loss_fn, prepareU, prepareQ)[0])
    return loss, E