| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| from tqdm import tqdm | |
| from torch.utils.data import DataLoader | |
def matplot(plotter, x, y):
    """Append new data to an existing line object and refresh figure 1.

    ``plotter`` must provide ``get_xdata``/``set_xdata``,
    ``get_ydata``/``set_ydata`` and an ``axes`` attribute (presumably a
    matplotlib ``Line2D`` -- confirm against the caller). ``x`` and ``y`` are
    appended to the data the line already holds, the axes limits are
    recomputed, and the figure is redrawn.
    """
    plt.figure(1)

    extended_x = np.append(plotter.get_xdata(), x)
    plotter.set_xdata(extended_x)
    extended_y = np.append(plotter.get_ydata(), y)
    plotter.set_ydata(extended_y)

    # Recompute the data limits and rescale so the new points are visible.
    axes = plotter.axes
    axes.relim()
    axes.autoscale()

    plt.draw()
    plt.pause(0.01)
def plotresult(model, test_data):
    """Plot input, prediction and target for one random item of ``test_data``.

    An item is picked at random; its first field is used as the model input
    and its second field as the target (the third field is ignored). The
    input is run through ``model`` without gradient tracking and figure 2 is
    redrawn as a 3x2 grid: channels 0 and 1 of the input (top row), of the
    model output (middle row) and of the target (bottom row).
    """
    device = next(model.parameters()).device
    index = int(np.random.rand() * len(test_data))
    x, y, _ = test_data[index]

    # Add a leading batch dimension of size 1 to both tensors.
    x = x.to(device).reshape(1, x.shape[0], x.shape[1])
    y = y.reshape(1, y.shape[0], y.shape[1])

    with torch.no_grad():
        y_ = model(x).cpu()
    x = x.cpu()

    plt.figure(2)
    plt.clf()
    traces = (x[0][0], x[0][1], y_[0][0], y_[0][1], y[0][0], y[0][1])
    for position, trace in enumerate(traces, start=1):
        plt.subplot(3, 2, position)
        plt.plot(trace.numpy(), '.-', markersize=0.5, linewidth=0.1)
    plt.draw()
    plt.pause(0.01)
def plotresulti(model, test_data):
    """Plot one random item of ``test_data`` with input and target swapped.

    An item is picked at random; its *second* field is used as the model
    input and its *first* field as the target (the third field is ignored).
    The input is run through ``model`` without gradient tracking and figure 2
    is redrawn as a 3x2 grid: channels 0 and 1 of the input (top row), of the
    model output (middle row) and of the target (bottom row).
    """
    device = next(model.parameters()).device
    index = int(np.random.rand() * len(test_data))
    y, x, _ = test_data[index]

    # Add a leading batch dimension of size 1 to both tensors.
    x = x.to(device).reshape(1, x.shape[0], x.shape[1])
    y = y.reshape(1, y.shape[0], y.shape[1])

    with torch.no_grad():
        y_ = model(x).cpu()
    x = x.cpu()

    plt.figure(2)
    plt.clf()
    traces = (x[0][0], x[0][1], y_[0][0], y_[0][1], y[0][0], y[0][1])
    for position, trace in enumerate(traces, start=1):
        plt.subplot(3, 2, position)
        plt.plot(trace.numpy(), '.-', markersize=0.5, linewidth=0.1)
    plt.draw()
    plt.pause(0.01)
def plotresulti2(model1, model2, test_data):
    """Plot the result of chaining two models on one random test item.

    An item is picked at random; its second field is fed to ``model1`` and
    the output of ``model1`` is fed to ``model2``. The item's first field is
    the reference. Figure 2 is redrawn as a 3x2 grid: channels 0 and 1 of the
    ``model1`` output (top row), of the ``model2`` output (middle row) and of
    the reference (bottom row). The device is taken from ``model1``.
    """
    device = next(model1.parameters()).device
    index = int(np.random.rand() * len(test_data))
    y, x, _ = test_data[index]

    # Add a leading batch dimension of size 1 to both tensors.
    x = x.to(device).reshape(1, x.shape[0], x.shape[1])
    y = y.reshape(1, y.shape[0], y.shape[1])

    with torch.no_grad():
        # From here on ``x`` holds the intermediate output of model1.
        x = model1(x)
        y_ = model2(x).cpu()
    x = x.cpu()

    plt.figure(2)
    plt.clf()
    traces = (x[0][0], x[0][1], y_[0][0], y_[0][1], y[0][0], y[0][1])
    for position, trace in enumerate(traces, start=1):
        plt.subplot(3, 2, position)
        plt.plot(trace.numpy(), '.-', markersize=0.5, linewidth=0.1)
    plt.draw()
    plt.pause(0.01)
def plotresult_(model, x0):
    """Transform a raw 1-D signal, run it through ``model`` and plot both.

    ``x0`` is inverse-FFT'd, fft-shifted, multiplied by
    ``exp(1j*pi*n) * 2e5``, cast to complex64, conjugated and negated. Real
    and imaginary parts become two channels of a ``(1, 2, len(x0))`` tensor
    that is fed to ``model`` without gradient tracking. Figure 3 shows the
    two input channels (top row) and output channels 0 and 1 (bottom row).

    Returns:
        ``(x, y)``: the transformed input and the model output, each
        recombined into a complex NumPy array (channel 0 + 1j * channel 1).
    """
    device = next(model.parameters()).device

    phase = np.exp(1j * np.pi * np.arange(len(x0))) * 2e5
    shifted = np.fft.fftshift(np.fft.ifft(x0))
    signal = -np.conj(np.complex64(shifted * phase))

    channels = np.c_[np.real(signal), np.imag(signal)].T
    x = torch.tensor(channels).to(device)
    x = x.reshape(1, x.shape[0], x.shape[1])

    with torch.no_grad():
        y = model(x).cpu()
    x = x.cpu()

    plt.figure(3)
    plt.clf()
    traces = (x[0][0], x[0][1], y[0][0], y[0][1])
    for position, trace in enumerate(traces, start=1):
        plt.subplot(2, 2, position)
        plt.plot(trace.numpy(), '.-', markersize=0.5, linewidth=0.1)
    plt.draw()
    plt.pause(0.01)

    x_complex = x[0][0].numpy() + 1j * x[0][1].numpy()
    y_complex = y[0][0].numpy() + 1j * y[0][1].numpy()
    return x_complex, y_complex
def plotresulti_(model, x0):
    """Run a complex 1-D signal through ``model`` and plot input and output.

    ``x0`` is cast to complex64; its real and imaginary parts become two
    channels of a ``(1, 2, len(x0))`` tensor that is fed to ``model`` without
    gradient tracking. Figure 3 shows the two input channels (top row) and
    output channels 0 and 1 (bottom row).

    Returns:
        The model output recombined into a complex NumPy array
        (channel 0 + 1j * channel 1).
    """
    device = next(model.parameters()).device

    signal = np.complex64(x0)
    channels = np.c_[np.real(signal), np.imag(signal)].T
    x = torch.tensor(channels).to(device)
    x = x.reshape(1, x.shape[0], x.shape[1])

    with torch.no_grad():
        y = model(x).cpu()
    x = x.cpu()

    plt.figure(3)
    plt.clf()
    traces = (x[0][0], x[0][1], y[0][0], y[0][1])
    for position, trace in enumerate(traces, start=1):
        plt.subplot(2, 2, position)
        plt.plot(trace.numpy(), '.-', markersize=0.5, linewidth=0.1)
    plt.draw()
    plt.pause(0.01)

    return y[0][0].numpy() + 1j * y[0][1].numpy()
def GetEnergy(dataU):
    """Integrate ``u * conj(u)`` over time for every column of ``dataU``.

    Each column ``dataU[:, i]`` is treated as one signal sampled on a uniform
    time grid with step ``dt = 1 / 96e9``. The grid length is taken from
    ``dataU.shape[0]`` instead of a fixed 2**11, so records of any length can
    be integrated; for 2048-sample records the result is unchanged.

    Args:
        dataU: 2-D array with one signal per column.

    Returns:
        1-D float array with one entry per column: the trapezoidal integral
        of ``u * conj(u)`` multiplied by ``P0 = 2.0 / 1.1e-3``.
    """
    # NumPy 2.0 renamed trapz to trapezoid and deprecated the old name; use
    # the new one when available and fall back on older NumPy versions.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    nt = dataU.shape[0]
    dt = 1.0 / 96e9
    T = dt * nt
    t = np.arange(nt) * dt - T / 2
    P0 = 2.0 / 1.1e-3

    E = np.zeros(dataU.shape[1])
    # Integrate column by column to keep the working memory at one record.
    for i in tqdm(range(dataU.shape[1]), desc='GetEnergy'):
        u = dataU[:, i]
        E[i] = np.real(integrate(u * np.conj(u), t))
    return E * P0
def GetPulseByEnergyRange(dataset, energy_low, energy_high):
    """Return the columns of ``dataset.u`` whose energy is in [low, high).

    Energies come from ``GetEnergy(dataset.u)``; the lower bound is inclusive
    and the upper bound exclusive.
    """
    energies = GetEnergy(dataset.u)
    at_or_above_low = energies >= energy_low
    below_high = energies < energy_high
    return dataset.u[:, at_or_above_low & below_high]
def GetPulseByEnergyLogRange(dataset, energy_low, energy_high):
    """Return the columns of ``dataset.u`` whose log10-energy is in [low, high).

    ``energy_low`` and ``energy_high`` are compared against
    ``log10(GetEnergy(dataset.u))``; the lower bound is inclusive and the
    upper bound exclusive.
    """
    log_energies = np.log10(GetEnergy(dataset.u))
    at_or_above_low = log_energies >= energy_low
    below_high = log_energies < energy_high
    return dataset.u[:, at_or_above_low & below_high]
def ScanEnergyRange(dataset, nrange=100):
    """Count the pulses of ``dataset.u`` in ``nrange`` equal-width energy bins.

    The bins span the minimum to the maximum of ``GetEnergy(dataset.u)``.
    Every bin is half-open ``[E1, E2)`` except the last one, which also takes
    everything at or above its lower edge. With a half-open last bin, whether
    the maximum-energy pulse was counted depended on floating-point rounding
    of the accumulated bin edge.

    Returns:
        Integer array of length ``nrange`` with the number of pulses per bin.
    """
    E = GetEnergy(dataset.u)
    Emin, Emax = np.min(E), np.max(E)
    dE = (Emax - Emin) / nrange
    nE = []
    E1 = Emin
    for ii in tqdm(range(nrange), desc='ScanEnergyRange'):
        E2 = E1 + dE
        if ii == nrange - 1:
            # Last bin is closed on the right so the maximum is never lost.
            in_bin = E >= E1
        else:
            in_bin = (E >= E1) & (E < E2)
        nE.append(int(np.count_nonzero(in_bin)))
        E1 = E2
    return np.array(nE)
def ScanEnergyRangeLog(dataset, nrange=100):
    """Count the pulses of ``dataset.u`` in ``nrange`` equal-width log10-energy bins.

    The bins span the minimum to the maximum of
    ``log10(GetEnergy(dataset.u))``; the range and bin width are printed.
    Every bin is half-open ``[E1, E2)`` except the last one, which also takes
    everything at or above its lower edge. With a half-open last bin, whether
    the maximum-energy pulse was counted depended on floating-point rounding
    of the accumulated bin edge.

    Returns:
        Integer array of length ``nrange`` with the number of pulses per bin.
    """
    E = np.log10(GetEnergy(dataset.u))
    Emin, Emax = np.min(E), np.max(E)
    dE = (Emax - Emin) / nrange
    print(Emin, Emax, dE)
    nE = []
    E1 = Emin
    for ii in tqdm(range(nrange), desc='ScanEnergyRangeLog'):
        E2 = E1 + dE
        if ii == nrange - 1:
            # Last bin is closed on the right so the maximum is never lost.
            in_bin = E >= E1
        else:
            in_bin = (E >= E1) & (E < E2)
        nE.append(int(np.count_nonzero(in_bin)))
        E1 = E2
    return np.array(nE)
def SeperatePulseByEnergyRanges(dataset, nrange):
    """Split ``dataset.u`` / ``dataset.q`` columns into equal-width energy bins.

    Energies come from ``GetEnergy(dataset.u)`` and the bins span their
    minimum to maximum. Every bin is half-open ``[E1, E2)`` except the last,
    which takes everything at or above its lower edge so the maximum-energy
    pulse is always assigned to a bin.

    Returns:
        ``(U, Q, (E, Emin, Emax, dE))`` where ``U[k]``, ``Q[k]`` and ``E[k]``
        hold the ``u`` columns, ``q`` columns and energies of bin ``k``.
    """
    e = GetEnergy(dataset.u)
    Emin, Emax = np.min(e), np.max(e)
    dE = (Emax - Emin) / nrange
    U, Q, E = [], [], []
    E1 = Emin
    # The progress label previously named the log-energy variant by mistake.
    for ii in tqdm(range(nrange), desc='SeperatePulseByEnergyRanges'):
        E2 = E1 + dE
        if ii == nrange - 1:
            in_bin = e >= E1
        else:
            in_bin = (e >= E1) & (e < E2)
        # One mask per bin, reused for all three selections.
        U.append(dataset.u[:, in_bin])
        Q.append(dataset.q[:, in_bin])
        E.append(e[in_bin])
        E1 = E2
    return U, Q, (E, Emin, Emax, dE)
def SeperatePulseByEnergyLogRanges(dataset, nrange):
    """Split ``dataset.u`` / ``dataset.q`` columns into equal-width log10-energy bins.

    The bins span the minimum to the maximum of
    ``log10(GetEnergy(dataset.u))``. Every bin is half-open ``[E1, E2)``
    except the last, which takes everything at or above its lower edge so the
    maximum-energy pulse is always assigned to a bin.

    Returns:
        ``(U, Q, (E, Emin, Emax, dE))`` where ``U[k]``, ``Q[k]`` and ``E[k]``
        hold the ``u`` columns, ``q`` columns and log10-energies of bin ``k``.
    """
    e = np.log10(GetEnergy(dataset.u))
    Emin, Emax = np.min(e), np.max(e)
    dE = (Emax - Emin) / nrange
    U, Q, E = [], [], []
    E1 = Emin
    for ii in tqdm(range(nrange), desc='SeperatePulseByEnergyLogRanges'):
        E2 = E1 + dE
        if ii == nrange - 1:
            in_bin = e >= E1
        else:
            in_bin = (e >= E1) & (e < E2)
        # One mask per bin, reused for all three selections.
        U.append(dataset.u[:, in_bin])
        Q.append(dataset.q[:, in_bin])
        E.append(e[in_bin])
        E1 = E2
    return U, Q, (E, Emin, Emax, dE)
def SeperatePulseByEnergyRanges_(u, q, nrange):
    """Split the columns of ``u`` and ``q`` into equal-width energy bins.

    Energies come from ``GetEnergy(u)`` and the bins span their minimum to
    maximum. Every bin is half-open ``[E1, E2)`` except the last, which takes
    everything at or above its lower edge so the maximum-energy pulse is
    always assigned to a bin.

    Returns:
        ``(U, Q, (E, Emin, Emax, dE))`` where ``U[k]``, ``Q[k]`` and ``E[k]``
        hold the ``u`` columns, ``q`` columns and energies of bin ``k``.
    """
    e = GetEnergy(u)
    Emin, Emax = np.min(e), np.max(e)
    dE = (Emax - Emin) / nrange
    U, Q, E = [], [], []
    E1 = Emin
    for ii in tqdm(range(nrange), desc='SeperatePulseByEnergyRanges'):
        E2 = E1 + dE
        if ii == nrange - 1:
            in_bin = e >= E1
        else:
            in_bin = (e >= E1) & (e < E2)
        # One mask per bin, reused for all three selections.
        U.append(u[:, in_bin])
        Q.append(q[:, in_bin])
        E.append(e[in_bin])
        E1 = E2
    return U, Q, (E, Emin, Emax, dE)
def SeperatePulseByEnergyLogRanges_(u, q, nrange):
    """Split the columns of ``u`` and ``q`` into equal-width log10-energy bins.

    The bins span the minimum to the maximum of ``log10(GetEnergy(u))``.
    Every bin is half-open ``[E1, E2)`` except the last, which takes
    everything at or above its lower edge so the maximum-energy pulse is
    always assigned to a bin.

    Returns:
        ``(U, Q, (E, Emin, Emax, dE))`` where ``U[k]``, ``Q[k]`` and ``E[k]``
        hold the ``u`` columns, ``q`` columns and log10-energies of bin ``k``.
    """
    e = np.log10(GetEnergy(u))
    Emin, Emax = np.min(e), np.max(e)
    dE = (Emax - Emin) / nrange
    U, Q, E = [], [], []
    E1 = Emin
    for ii in tqdm(range(nrange), desc='SeperatePulseByEnergyLogRanges'):
        E2 = E1 + dE
        if ii == nrange - 1:
            in_bin = e >= E1
        else:
            in_bin = (e >= E1) & (e < E2)
        # One mask per bin, reused for all three selections.
        U.append(u[:, in_bin])
        Q.append(q[:, in_bin])
        E.append(e[in_bin])
        E1 = E2
    return U, Q, (E, Emin, Emax, dE)
def prepareQ(q, normalization = True):
    """Convert a complex 1-D signal into a ``(1, 2, len(q))`` tensor.

    Channel 0 holds the real part and channel 1 the imaginary part. With
    ``normalization`` the signal is first divided by its largest absolute
    value. The division produces a new array: the caller's ``q`` is never
    modified, which matters when ``q`` is a view into a larger dataset.

    Args:
        q: 1-D array-like of (complex) samples.
        normalization: divide by ``max(|q|)`` before splitting into channels.

    Returns:
        ``(Q, qnorm)``: the tensor and the divisor that was applied. The
        divisor is ``1.0`` when normalization is off or when the signal is
        all zeros (so no division by zero takes place).
    """
    qtemp = np.asarray(q)
    qnorm = 1.0
    if normalization:
        peak = np.max(np.abs(qtemp))
        # An all-zero signal has nothing to scale; dividing would give NaNs.
        if peak > 0:
            qnorm = peak
            # Out-of-place division: an in-place "/=" would overwrite the
            # caller's data.
            qtemp = qtemp / qnorm
    qreal = np.real(qtemp)
    qimag = np.imag(qtemp)
    Q = torch.tensor(np.c_[qreal, qimag].T)
    Q = Q.reshape((1,) + Q.shape)
    return Q, qnorm
def prepareU(u, normalization = True):
    """Transform a raw 1-D signal and convert it to a two-channel tensor.

    ``u`` is inverse-FFT'd, fft-shifted, multiplied by
    ``exp(1j*pi*n) * 2e5``, cast to complex64, conjugated and negated. The
    result is handed to ``prepareQ`` with the same ``normalization`` flag,
    and that function's return value is passed through unchanged.
    """
    phase = np.exp(1j * np.pi * np.arange(len(u))) * 2e5
    shifted = np.fft.fftshift(np.fft.ifft(u))
    scaled = np.complex64(shifted * phase)
    transformed = -np.conj(scaled)
    return prepareQ(transformed, normalization)
def valid_(model, u, q, loss_fn, transformU, transformQ, normalization = True, forward=False):
    """Evaluate ``model`` column by column on raw ``u`` / ``q`` arrays.

    Each column pair is converted with ``transformU`` / ``transformQ`` (both
    are called as ``f(column, normalization)`` and must return a tuple whose
    first element is a tensor). By default the transformed ``u`` is the model
    input and the transformed ``q`` the target; ``forward=True`` swaps them.
    The model is put in eval mode and run under ``no_grad`` and CUDA
    autocast. ``loss_fn`` must return a scalar per sample.

    Returns:
        ``(loss, losses)``: the mean of the per-column losses and the array
        of per-column losses.
    """
    model.eval()
    device = next(model.parameters()).device
    n_columns = u.shape[1]
    losses = np.zeros(n_columns)
    with torch.no_grad():
        for column in tqdm(range(n_columns), desc='Valid'):
            from_u, _ = transformU(u[:, column], normalization)
            from_q, _ = transformQ(q[:, column], normalization)
            from_u, from_q = from_u.to(device), from_q.to(device)
            if forward:
                inputs, target = from_q, from_u
            else:
                inputs, target = from_u, from_q
            with torch.cuda.amp.autocast():
                prediction = model(inputs)
                sample_loss = loss_fn(prediction, target)
                losses[column] = sample_loss.detach().cpu().numpy()
    return np.mean(losses), losses
def valid(model, dataset, loss_fn, batch_size, forward=False):
    """Evaluate ``model`` on ``dataset`` in batches and return the RMS loss.

    Batches are read in order (no shuffling). Each batch must unpack into
    three items; by default the first is the model input and the second the
    target, and ``forward=True`` swaps them. The model is put in eval mode
    and run under ``no_grad``. The per-batch results of ``loss_fn`` are
    concatenated, so ``loss_fn`` is expected to return a non-scalar tensor
    (presumably one value per sample -- confirm against the caller).

    Returns:
        ``(loss, losses)``: ``sqrt(mean(losses**2))`` and the concatenated
        array of losses.
    """
    model.eval()
    device = next(model.parameters()).device
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    collected = []
    progress = tqdm(loader, disable=False)
    with torch.no_grad():
        for first, second, _ in progress:
            progress.set_description("Valid")
            first, second = first.to(device), second.to(device)
            inputs, target = (second, first) if forward else (first, second)
            batch_loss = loss_fn(model(inputs), target)
            collected.append(batch_loss.detach().cpu().numpy())
    losses = np.concatenate(collected)
    rms = np.sqrt(np.mean(losses**2))
    return rms, losses
def valid2(model1, model2, dataset, loss_fn, batch_size, forward=False):
    """Evaluate the chain ``model2(model1(x))`` on ``dataset`` in batches.

    Batches are read in order (no shuffling). Each batch must unpack into
    three items; by default the *second* is the chain input and the *first*
    the target, and ``forward=True`` swaps them. Both models are put in eval
    mode and run under ``no_grad`` and CUDA autocast; the device is taken
    from ``model1``. The per-batch results of ``loss_fn`` are concatenated,
    so ``loss_fn`` is expected to return a non-scalar tensor.

    Returns:
        ``(loss, losses)``: ``sqrt(mean(losses**2))`` and the concatenated
        array of losses.
    """
    model1.eval()
    model2.eval()
    device = next(model1.parameters()).device
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    collected = []
    progress = tqdm(loader, disable=False)
    for first, second, _ in progress:
        progress.set_description("Valid")
        second, first = second.to(device), first.to(device)
        if forward:
            inputs, target = first, second
        else:
            inputs, target = second, first
        with torch.no_grad(), torch.cuda.amp.autocast():
            intermediate = model1(inputs)
            prediction = model2(intermediate)
            batch_loss = loss_fn(prediction, target)
            collected.append(batch_loss.detach().cpu().numpy())
    losses = np.concatenate(collected)
    rms = np.sqrt(np.mean(losses**2))
    return rms, losses
def LossVsEnergy_(model, u, q, loss_fn, nrange=100):
    """Compute the validation loss of ``model`` per energy bin of ``u`` / ``q``.

    The columns are split into ``nrange`` bins by
    ``SeperatePulseByEnergyRanges_`` and each bin is evaluated with
    ``valid_`` using ``prepareU`` / ``prepareQ`` as transforms. ``valid_`` is
    the evaluator whose signature accepts raw arrays plus transforms;
    ``valid`` takes a dataset and batch size and would raise a TypeError
    with these arguments.

    Returns:
        ``(loss, E)``: the list of per-bin losses and the bin description
        tuple returned by ``SeperatePulseByEnergyRanges_``.
    """
    U, Q, E = SeperatePulseByEnergyRanges_(u, q, nrange)
    loss = [
        valid_(model, u_bin, q_bin, loss_fn, prepareU, prepareQ)[0]
        for u_bin, q_bin in tqdm(zip(U, Q), total=len(U), desc='LossVsEnergy')
    ]
    return loss, E
def LossVsEnergy(model, dataset, loss_fn, nrange=100):
    """Compute the validation loss of ``model`` per energy bin of ``dataset``.

    The pulses are split into ``nrange`` bins by
    ``SeperatePulseByEnergyRanges`` and each bin is evaluated with ``valid_``
    using ``prepareU`` / ``prepareQ`` as transforms. ``valid_`` is the
    evaluator whose signature accepts raw arrays plus transforms; ``valid``
    takes a dataset and batch size and would raise a TypeError with these
    arguments.

    Returns:
        ``(loss, E)``: the list of per-bin losses and the bin description
        tuple returned by ``SeperatePulseByEnergyRanges``.
    """
    U, Q, E = SeperatePulseByEnergyRanges(dataset, nrange)
    loss = [
        valid_(model, u_bin, q_bin, loss_fn, prepareU, prepareQ)[0]
        for u_bin, q_bin in tqdm(zip(U, Q), total=len(U), desc='LossVsEnergy')
    ]
    return loss, E
def LossVsEnergyLog_(model, u, q, loss_fn, nrange=100):
    """Compute the validation loss of ``model`` per log10-energy bin of ``u`` / ``q``.

    The columns are split into ``nrange`` bins by
    ``SeperatePulseByEnergyLogRanges_`` and each bin is evaluated with
    ``valid_`` using ``prepareU`` / ``prepareQ`` as transforms. ``valid_`` is
    the evaluator whose signature accepts raw arrays plus transforms;
    ``valid`` takes a dataset and batch size and would raise a TypeError
    with these arguments.

    Returns:
        ``(loss, E)``: the list of per-bin losses and the bin description
        tuple returned by ``SeperatePulseByEnergyLogRanges_``.
    """
    U, Q, E = SeperatePulseByEnergyLogRanges_(u, q, nrange)
    loss = [
        valid_(model, u_bin, q_bin, loss_fn, prepareU, prepareQ)[0]
        for u_bin, q_bin in tqdm(zip(U, Q), total=len(U), desc='LossVsEnergyLog')
    ]
    return loss, E
def LossVsEnergyLog(model, dataset, loss_fn, nrange=100):
    """Compute the validation loss of ``model`` per log10-energy bin of ``dataset``.

    The pulses are split into ``nrange`` bins by
    ``SeperatePulseByEnergyLogRanges`` and each bin is evaluated with
    ``valid_`` using ``prepareU`` / ``prepareQ`` as transforms. ``valid_`` is
    the evaluator whose signature accepts raw arrays plus transforms;
    ``valid`` takes a dataset and batch size and would raise a TypeError
    with these arguments.

    Returns:
        ``(loss, E)``: the list of per-bin losses and the bin description
        tuple returned by ``SeperatePulseByEnergyLogRanges``.
    """
    U, Q, E = SeperatePulseByEnergyLogRanges(dataset, nrange)
    loss = [
        valid_(model, u_bin, q_bin, loss_fn, prepareU, prepareQ)[0]
        for u_bin, q_bin in tqdm(zip(U, Q), total=len(U), desc='LossVsEnergyLog')
    ]
    return loss, E