# Uploaded with huggingface_hub (commit 80b708b, verified) by Zolisa.
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
import sys
import os
# Add current directory to path for local imports
# (lets `models.simple_nn` / `models.transformer_imdb` resolve when the app
# is launched from a different working directory, e.g. by a Space runner).
sys.path.append(os.path.abspath('.'))
from models.simple_nn import SimpleNN
from models.transformer_imdb import TransformerClassifier
# Constants
DEVICE = torch.device('cpu')  # inference-only app; CPU keeps it portable
MNIST_PATH = 'models/cnn_model.pth'          # weights for the digit classifier
FASHION_PATH = 'models/fashion_cnn_model.pth'  # weights for the clothing classifier
IMDB_PATH = 'models/transformer_imdb.pth'      # weights for the IMDB transformer
# NOTE(review): presumably the tokenizer vocabulary size used at training
# time — must match the checkpoint's embedding table; confirm against the
# training script.
VOCAB_SIZE = 100684
# 1. Helper functions to load models
def load_image_model(model_type):
    """Build a SimpleNN classifier and load the weights for *model_type*.

    ``model_type`` is the dropdown label; any string containing 'MNIST'
    selects the digit checkpoint, everything else the FashionMNIST one.
    If the checkpoint file is absent the untrained network is returned
    (best-effort: the demo still runs, just with random predictions).
    """
    net = SimpleNN().to(DEVICE)
    weights_path = MNIST_PATH if 'MNIST' in model_type else FASHION_PATH
    if os.path.exists(weights_path):
        net.load_state_dict(torch.load(weights_path, map_location=DEVICE))
    net.eval()
    return net
def load_text_model():
    """Build the IMDB TransformerClassifier and load its checkpoint.

    Returns the model in eval mode; if the checkpoint file is missing the
    freshly-initialized (untrained) model is returned instead.
    """
    net = TransformerClassifier(vocab_size=VOCAB_SIZE).to(DEVICE)
    if os.path.exists(IMDB_PATH):
        net.load_state_dict(torch.load(IMDB_PATH, map_location=DEVICE))
    net.eval()
    return net
# 2. Prediction functions
def predict_image(image, model_type):
    """Classify a PIL image as a digit or a clothing item.

    Parameters:
        image: PIL image from the Gradio widget, or None if nothing uploaded.
        model_type: dropdown label; 'Digits' in it selects MNIST output,
            otherwise FashionMNIST class names are used.

    Returns a human-readable prediction string.
    """
    if image is None:
        return "No image uploaded."

    model = load_image_model(model_type)
    # Match the training-time preprocessing: grayscale, 28x28, MNIST-style
    # normalization. NOTE(review): the mean/std pair is the MNIST one and is
    # applied to FashionMNIST inputs too — confirm that matches training.
    preprocess = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    batch = preprocess(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = model(batch)
    predicted = logits.argmax(dim=1).item()

    if 'Digits' in model_type:
        return f"Predicted Digit: {predicted}"

    class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
    return f"Predicted Clothing: {class_names[predicted]}"
def predict_sentiment(text):
    """Classify review *text* as "Positive", "Negative", or "Neutral".

    Uses a simple keyword heuristic: counts how many distinct positive
    vs. negative cue words appear in the text.

    Fix: the original used substring matching (``'bad' in text``), which
    fired on unrelated words such as "badge" or "goods". We now tokenize
    on whitespace and strip surrounding punctuation so only whole words
    are counted.

    NOTE(review): the transformer loaded by ``load_text_model`` is never
    used here — the heuristic stands in for real inference; confirm
    whether the model was meant to be wired in.
    """
    # Whole-word token set, lowercased, with leading/trailing punctuation removed.
    words = {w.strip('.,!?;:"\'()') for w in text.lower().split()}
    pos = {'great', 'good', 'excellent', 'love', 'wonderful'}
    neg = {'bad', 'terrible', 'awful', 'waste', 'boring'}
    score = len(words & pos) - len(words & neg)
    if score > 0:
        return "Positive"
    elif score < 0:
        return "Negative"
    return "Neutral"
# 3. Build the two Gradio interfaces.
# Image tab: upload widget plus a dataset selector feeding predict_image.
img_iface = gr.Interface(
    fn=predict_image,
    inputs=[
        gr.Image(type='pil'),
        gr.Dropdown(['MNIST (Digits)', 'FashionMNIST (Clothing)'], label='Dataset'),
    ],
    outputs='text',
    title='Computer Vision Classification',
)

# Text tab: free-form review box feeding the sentiment heuristic.
txt_iface = gr.Interface(
    fn=predict_sentiment,
    inputs=gr.Textbox(lines=3, placeholder='Enter review...'),
    outputs='text',
    title='NLP Sentiment Analysis',
)

# 4. Combine both into a single tabbed app.
app = gr.TabbedInterface(
    [img_iface, txt_iface],
    ['Images', 'Sentiment'],
    title='Multi-Modal PyTorch App',
)

if __name__ == '__main__':
    app.launch()