|
|
import gradio as gr |
|
|
import pandas as pd |
|
|
import joblib |
|
|
import os |
|
|
from sklearn.ensemble import RandomForestClassifier |
|
|
from sklearn.model_selection import train_test_split |
|
|
|
|
|
# File the fitted model is cached in between runs (a path relative to the
# working directory). If it already exists at startup it is loaded instead of
# retraining.
MODEL_PATH: str = "rf_model.pkl"

# Source CSV for the UCI Wine Quality (white wine) dataset. The file is
# semicolon-separated, which is why it is read with sep=';'.
DATA_URL: str = "https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-white.csv"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_model():
    """Download the UCI white-wine data, fit a Random Forest, and persist it.

    The classifier is fitted on a random 80% of the rows. The remaining 20%
    are used only to report held-out accuracy in the logs. The fitted model is
    written to ``MODEL_PATH`` with joblib so later runs can reuse it.

    Returns:
        The fitted ``RandomForestClassifier``.

    Raises:
        Whatever ``pd.read_csv`` raises if ``DATA_URL`` cannot be fetched or
        parsed (for example network errors). Nothing is caught here.
    """
    print("Downloading white wine dataset...")
    # The UCI file uses ';' as its delimiter, not ','.
    df = pd.read_csv(DATA_URL, sep=';')

    # The model is fitted on these named columns, so this list must stay
    # identical (names and order) to the module-level feature_names used by
    # predict_quality. It is defined locally because this function runs
    # before that module-level list is created.
    feature_names = [
        'fixed acidity', 'volatile acidity', 'citric acid', 'residual sugar',
        'chlorides', 'free sulfur dioxide', 'total sulfur dioxide', 'density',
        'pH', 'sulphates', 'alcohol',
    ]

    X = df[feature_names]
    y = df['quality']

    # A fixed seed keeps the split, and therefore the reported score,
    # reproducible across runs.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    print("Training Random Forest model...")
    model = RandomForestClassifier(
        n_estimators=300,
        max_depth=12,
        random_state=42,
    )
    model.fit(X_train, y_train)

    # Evaluate on the rows held out above so the quality of the model being
    # saved is visible in the logs; otherwise the split would serve no purpose.
    test_accuracy = model.score(X_test, y_test)
    print(f"Held-out accuracy: {test_accuracy:.3f}")

    joblib.dump(model, MODEL_PATH)
    # Report the real destination rather than a hard-coded filename.
    print(f"Model saved as {MODEL_PATH}")

    return model
|
|
|
|
|
|
|
|
|
|
|
# Obtain the classifier used by the UI. If no cached model file exists, fit a
# fresh one (train_model() also writes it to MODEL_PATH for the next run).
# Otherwise reuse the cached file.
# NOTE(review): joblib.load unpickles the file, so MODEL_PATH must only ever
# hold a file this app produced itself. Never point it at untrusted input.
if not os.path.exists(MODEL_PATH):
    model = train_model()
else:
    print("Loading existing model...")
    model = joblib.load(MODEL_PATH)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Input columns the classifier expects, in fit-time order. predict_quality
# labels the incoming UI values with these names, and the UI creates one
# numeric box per entry.
# Keep this in sync with the copy inside train_model(). That copy exists
# because train_model() can run before this line executes.
feature_names = [
    'fixed acidity',
    'volatile acidity',
    'citric acid',
    'residual sugar',
    'chlorides',
    'free sulfur dioxide',
    'total sulfur dioxide',
    'density',
    'pH',
    'sulphates',
    'alcohol',
]
|
|
|
|
|
def predict_quality(*inputs):
    """Predict the quality label for one wine from the UI's numeric inputs.

    Args:
        *inputs: One value per entry of ``feature_names``, in the same order.
            Gradio passes the ``gr.Number`` components' values positionally.

    Returns:
        A display string with the model's predicted quality. If any input is
        empty, returns a prompt to fill in every field instead.
    """
    # Gradio passes None for an empty gr.Number field. Sending that to the
    # model would create a missing value in the feature row and either raise
    # inside scikit-learn or yield a meaningless prediction, so stop here.
    if any(value is None for value in inputs):
        return "Please enter a value for every feature."

    df = pd.DataFrame([inputs], columns=feature_names)
    prediction = model.predict(df)[0]
    return f"Predicted Wine Quality: {prediction}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# One numeric entry box per model feature, labelled with the column name and
# listed in the same order predict_quality expects its positional arguments.
inputs_ui = [gr.Number(label=column) for column in feature_names]

# Single text box that shows the string returned by predict_quality.
prediction_output = gr.Textbox(label="Prediction")

demo = gr.Interface(
    fn=predict_quality,
    inputs=inputs_ui,
    outputs=prediction_output,
    title="🍾 White Wine Quality Predictor (Trains on HF Space)",
    description="Random Forest model trained on the UCI White Wine Quality dataset.",
)

# Start the web server when the script runs.
demo.launch()
|
|
|