# Hugging Face Spaces page header captured with the source (status: "Sleeping").
import os
import sys

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

import gradio as gr

# Add current directory to path for local imports (must run before the
# `models.*` imports below).
sys.path.append(os.path.abspath('.'))
from models.simple_nn import SimpleNN
from models.transformer_imdb import TransformerClassifier
| # Constants | |
| DEVICE = torch.device('cpu') | |
| MNIST_PATH = 'models/cnn_model.pth' | |
| FASHION_PATH = 'models/fashion_cnn_model.pth' | |
| IMDB_PATH = 'models/transformer_imdb.pth' | |
| VOCAB_SIZE = 100684 | |
| # 1. Helper functions to load models | |
| def load_image_model(model_type): | |
| # Using correct class name found in file check | |
| model = SimpleNN().to(DEVICE) | |
| path = MNIST_PATH if 'MNIST' in model_type else FASHION_PATH | |
| if os.path.exists(path): | |
| model.load_state_dict(torch.load(path, map_location=DEVICE)) | |
| model.eval() | |
| return model | |
| def load_text_model(): | |
| model = TransformerClassifier(vocab_size=VOCAB_SIZE).to(DEVICE) | |
| if os.path.exists(IMDB_PATH): | |
| model.load_state_dict(torch.load(IMDB_PATH, map_location=DEVICE)) | |
| model.eval() | |
| return model | |
| # 2. Prediction functions | |
| def predict_image(image, model_type): | |
| if image is None: return "No image uploaded." | |
| model = load_image_model(model_type) | |
| transform = transforms.Compose([ | |
| transforms.Grayscale(num_output_channels=1), | |
| transforms.Resize((28, 28)), | |
| transforms.ToTensor(), | |
| transforms.Normalize((0.1307,), (0.3081,)) | |
| ]) | |
| input_tensor = transform(image).unsqueeze(0).to(DEVICE) | |
| with torch.no_grad(): | |
| output = model(input_tensor) | |
| idx = output.argmax(dim=1).item() | |
| if 'Digits' in model_type: | |
| return f"Predicted Digit: {idx}" | |
| else: | |
| labels = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] | |
| return f"Predicted Clothing: {labels[idx]}" | |
| def predict_sentiment(text): | |
| text_l = text.lower() | |
| pos = ['great', 'good', 'excellent', 'love', 'wonderful'] | |
| neg = ['bad', 'terrible', 'awful', 'waste', 'boring'] | |
| score = sum(1 for w in pos if w in text_l) - sum(1 for w in neg if w in text_l) | |
| if score > 0: return "Positive" | |
| elif score < 0: return "Negative" | |
| return "Neutral" | |
| # 3. Create Interfaces | |
| img_iface = gr.Interface( | |
| fn=predict_image, | |
| inputs=[gr.Image(type='pil'), gr.Dropdown(['MNIST (Digits)', 'FashionMNIST (Clothing)'], label='Dataset')], | |
| outputs='text', | |
| title='Computer Vision Classification' | |
| ) | |
| txt_iface = gr.Interface( | |
| fn=predict_sentiment, | |
| inputs=gr.Textbox(lines=3, placeholder='Enter review...'), | |
| outputs='text', | |
| title='NLP Sentiment Analysis' | |
| ) | |
| # 4. Combine into TabbedInterface | |
| app = gr.TabbedInterface([img_iface, txt_iface], ['Images', 'Sentiment'], title='Multi-Modal PyTorch App') | |
| if __name__ == '__main__': | |
| app.launch() | |