"""Multi-modal PyTorch demo app.

Serves two Gradio tabs:
  * Computer-vision classification (MNIST digits / FashionMNIST clothing)
    using a SimpleNN checkpoint.
  * NLP sentiment analysis (currently a keyword heuristic; a transformer
    loader is provided but not yet wired into the prediction path).
"""

import os
import sys

import gradio as gr
import torch
import torch.nn as nn  # noqa: F401 (kept: may be needed by model modules)
import torchvision.transforms as transforms
from PIL import Image  # noqa: F401 (gr.Image(type='pil') hands us PIL images)

# Add current directory to path for local imports.
sys.path.append(os.path.abspath('.'))

from models.simple_nn import SimpleNN
from models.transformer_imdb import TransformerClassifier

# Constants
DEVICE = torch.device('cpu')
MNIST_PATH = 'models/cnn_model.pth'
FASHION_PATH = 'models/fashion_cnn_model.pth'
IMDB_PATH = 'models/transformer_imdb.pth'
VOCAB_SIZE = 100684  # must match the vocab used when the IMDB model was trained

# FashionMNIST class names, indexed by predicted class id.
FASHION_LABELS = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot',
]


# 1. Helper functions to load models
def load_image_model(model_type):
    """Build a SimpleNN and load the checkpoint matching *model_type*.

    Args:
        model_type: dropdown label, e.g. 'MNIST (Digits)' or
            'FashionMNIST (Clothing)'.

    Returns:
        The model in eval mode (randomly initialized weights if no
        checkpoint file exists on disk).
    """
    model = SimpleNN().to(DEVICE)
    # BUG FIX: the original tested `'MNIST' in model_type`, which is ALSO
    # true for 'FashionMNIST (Clothing)', so the Fashion checkpoint was
    # never loaded. Test for 'Fashion' instead.
    path = FASHION_PATH if 'Fashion' in model_type else MNIST_PATH
    if os.path.exists(path):
        # weights_only=True: safe state-dict loading, refuses arbitrary
        # pickled objects in the checkpoint file.
        model.load_state_dict(
            torch.load(path, map_location=DEVICE, weights_only=True))
    model.eval()
    return model


def load_text_model():
    """Load the transformer IMDB classifier checkpoint in eval mode.

    NOTE(review): currently unused — predict_sentiment() relies on a
    keyword heuristic instead of this transformer. Kept for future wiring.
    """
    model = TransformerClassifier(vocab_size=VOCAB_SIZE).to(DEVICE)
    if os.path.exists(IMDB_PATH):
        model.load_state_dict(
            torch.load(IMDB_PATH, map_location=DEVICE, weights_only=True))
    model.eval()
    return model


# 2. Prediction functions
def predict_image(image, model_type):
    """Classify a PIL image with the model selected by *model_type*.

    Returns a human-readable prediction string, or a notice if no image
    was uploaded.
    """
    if image is None:
        return "No image uploaded."
    model = load_image_model(model_type)
    # Match the MNIST training pipeline: grayscale, 28x28, standard
    # MNIST mean/std normalization.
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    input_tensor = transform(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        output = model(input_tensor)
    idx = output.argmax(dim=1).item()
    if 'Digits' in model_type:
        return f"Predicted Digit: {idx}"
    return f"Predicted Clothing: {FASHION_LABELS[idx]}"


def predict_sentiment(text):
    """Return 'Positive' / 'Negative' / 'Neutral' via a keyword count.

    NOTE(review): this is a placeholder heuristic; the trained transformer
    (see load_text_model) is not consulted.
    """
    # Robustness: Gradio may pass None; empty input is neutral (unchanged
    # behavior for '' — score would have been 0 anyway).
    if not text:
        return "Neutral"
    text_l = text.lower()
    pos = ['great', 'good', 'excellent', 'love', 'wonderful']
    neg = ['bad', 'terrible', 'awful', 'waste', 'boring']
    score = (sum(1 for w in pos if w in text_l)
             - sum(1 for w in neg if w in text_l))
    if score > 0:
        return "Positive"
    if score < 0:
        return "Negative"
    return "Neutral"


# 3. Create Interfaces
img_iface = gr.Interface(
    fn=predict_image,
    inputs=[
        gr.Image(type='pil'),
        gr.Dropdown(['MNIST (Digits)', 'FashionMNIST (Clothing)'],
                    label='Dataset'),
    ],
    outputs='text',
    title='Computer Vision Classification',
)
txt_iface = gr.Interface(
    fn=predict_sentiment,
    inputs=gr.Textbox(lines=3, placeholder='Enter review...'),
    outputs='text',
    title='NLP Sentiment Analysis',
)

# 4. Combine into TabbedInterface
app = gr.TabbedInterface(
    [img_iface, txt_iface],
    ['Images', 'Sentiment'],
    title='Multi-Modal PyTorch App',
)

if __name__ == '__main__':
    app.launch()