# Tensor / app.py
# yamunagovindha's picture
# Update app.py
# 22da10a verified
import gradio as gr
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
# Debugging print: emitted once when the module starts executing, before any
# of the definitions that follow are created.
print("πŸš€ The script has started running...") # Debugging print: marks the start of module execution.
# Define a simple neural network
class SimpleNN(nn.Module):
    """Feed-forward network with one hidden layer and a tanh nonlinearity.

    The computation is ``output(tanh(hidden(x)))``. The submodule attribute
    names (``hidden``, ``activation``, ``output``) determine the keys of the
    model's ``state_dict``.

    Args:
        input_size: Number of features in the last dimension of the input.
        hidden_size: Width of the hidden layer.
        output_size: Number of features produced in the last dimension.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Submodules are registered in this order: hidden, activation, output.
        self.hidden = nn.Linear(input_size, hidden_size)
        self.activation = nn.Tanh()
        self.output = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map ``x`` through the hidden layer, tanh, then the output layer."""
        hidden_features = self.activation(self.hidden(x))
        return self.output(hidden_features)
# Generate sample data
def generate_data(noise=0, n_points=100, x_min=-5.0, x_max=5.0):
    """Sample ``y = sin(x)`` on an evenly spaced grid, optionally adding noise.

    Args:
        noise: Scale (standard deviation) of the additive Gaussian noise.
            With the default ``0`` the targets are exactly ``sin(x)``.
        n_points: Number of grid points (default 100).
        x_min: Left end of the grid, inclusive (default -5).
        x_max: Right end of the grid, inclusive (default 5).

    Returns:
        ``(x, y)``: two 1-D NumPy float arrays of length ``n_points``.
    """
    x = np.linspace(x_min, x_max, n_points)
    # Size the noise from x itself so the two arrays can never disagree in
    # length. randn is drawn even when noise == 0, so the global NumPy RNG
    # advances the same way regardless of the noise level.
    y = np.sin(x) + noise * np.random.randn(x.size)
    return x, y
# Train the model
def train_model(epochs, learning_rate):
    """Train a fresh SimpleNN on noiseless sin(x) data and plot its loss curve.

    Args:
        epochs: Number of full-batch SGD steps. The value is coerced with
            ``int()`` because UI widgets may deliver it as a float (possibly
            fractional), and ``range()`` rejects floats.
        learning_rate: SGD learning rate.

    Returns:
        Path of the saved loss plot, ``"loss_plot.png"`` (relative to the
        current working directory).
    """
    epochs = int(epochs)
    print(f"πŸ”„ Training started with {epochs} epochs and learning rate {learning_rate}...")
    losses = _train_losses(epochs, learning_rate)
    print("βœ… Training complete. Plotting loss...")
    return _save_loss_plot(losses, "loss_plot.png")


def _train_losses(epochs, learning_rate):
    """Run full-batch SGD on a new SimpleNN(1, 4, 1); return per-epoch MSE losses."""
    model = SimpleNN(1, 4, 1)
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    x_np, y_np = generate_data()
    # unsqueeze(1) turns the 1-D samples into (num_samples, 1) column tensors,
    # matching the model's single input and output feature.
    x_train = torch.tensor(x_np, dtype=torch.float32).unsqueeze(1)
    y_train = torch.tensor(y_np, dtype=torch.float32).unsqueeze(1)

    losses = []
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = criterion(model(x_train), y_train)
        loss.backward()
        optimizer.step()
        # The recorded value is the loss computed before this step's update.
        losses.append(loss.item())
    return losses


def _save_loss_plot(losses, path):
    """Plot ``losses`` against epoch index, save the figure to ``path``, return ``path``."""
    fig, ax = plt.subplots(figsize=(5, 3))
    try:
        ax.plot(losses, label="Training Loss", color="blue")
        ax.set_xlabel("Epochs")
        ax.set_ylabel("Loss")
        ax.set_title("Neural Network Training Loss")
        ax.legend()
        ax.grid(True)
        fig.savefig(path)
    finally:
        # Close even if savefig raises; otherwise the figure stays registered
        # with pyplot for the lifetime of the process.
        plt.close(fig)
    return path
# Gradio interface
iface = gr.Interface(
    fn=train_model,
    inputs=[
        # step=1 keeps the epoch count a whole number. Without an explicit
        # step, Gradio derives one from the slider range, which can be
        # fractional, and a non-integer epoch count breaks range().
        gr.Slider(1, 100, value=10, step=1, label="Epochs"),
        gr.Slider(0.001, 0.1, value=0.03, label="Learning Rate"),
    ],
    # train_model returns the path of a PNG file, displayed as an image.
    outputs="image",
    title="Neural Network Trainer",
    description="Train a simple neural network and visualize the loss curve.",
)
print("βœ… Gradio is about to launch...")
iface.launch()