AlexVplle's picture
Upload app.py with huggingface_hub
1c0b13b verified
import gradio as gr
import joblib
import json
import pandas as pd
from fastapi import FastAPI
from fastapi.responses import JSONResponse
import uvicorn
import threading
# Load model and metadata
# NOTE(security): joblib.load unpickles the file, and unpickling can execute
# arbitrary code. Only ship a model.pkl that comes from a trusted training run.
model = joblib.load("model.pkl")
# Read the metadata as UTF-8 explicitly so parsing does not depend on the
# platform's default text encoding.
with open("metadata.json", "r", encoding="utf-8") as f:
    metadata = json.load(f)
# Column names used to build every DataFrame handed to the model.
feature_names = metadata["feature_names"]
def predict(*features):
    """Make a single prediction with the trained model.

    Args:
        *features: One value per entry of ``feature_names``, in the same
            order; they become the single row passed to the model.

    Returns:
        A ``(text, probabilities)`` tuple: ``text`` is
        ``"Predicted Class: <prediction>"`` and ``probabilities`` maps
        ``"Class <index>"`` to the value at that index of the model's
        ``predict_proba`` output.
    """
    # Create a one-row input DataFrame with the expected column names
    input_data = pd.DataFrame([list(features)], columns=feature_names)

    # Predict
    prediction = model.predict(input_data)[0]
    probabilities = model.predict_proba(input_data)[0]

    # Cast to built-in floats (as predict_batch_from_url does) so the mapping
    # stays JSON-serializable even if the model returns NumPy scalar types
    # such as float32.
    prob_dict = {f"Class {i}": float(prob) for i, prob in enumerate(probabilities)}
    return f"Predicted Class: {prediction}", prob_dict
def predict_batch_from_url(file_url):
    """Make batch predictions for every row of a CSV fetched from a URL.

    Args:
        file_url: ``http://`` or ``https://`` URL of a CSV file that contains
            at least the columns listed in ``feature_names``.

    Returns:
        ``{"predictions": [...]}`` with one ``{"prediction", "probabilities"}``
        entry per CSV row on success, or ``{"error": <message>}`` when the URL
        is rejected, required columns are missing, or any step raises.
    """
    # Local import keeps this block self-contained (no change to file imports).
    from urllib.parse import urlparse

    try:
        # file_url is caller-supplied. pandas.read_csv also opens local paths
        # and file:// URLs, which would let a caller read files on the server,
        # so only plain web URLs are accepted.
        # NOTE(security): this does not stop requests to internal HTTP
        # services (SSRF); restrict hosts as well if that matters here.
        if not isinstance(file_url, str) or urlparse(file_url).scheme not in ("http", "https"):
            return {"error": "file_url must be an http:// or https:// URL"}

        # Download and parse the CSV
        df = pd.read_csv(file_url)

        # Every model feature must be present; extra columns are ignored
        if not all(col in df.columns for col in feature_names):
            return {"error": f"CSV must contain columns: {feature_names}"}

        # Select only the feature columns, in the order the model expects
        X = df[feature_names]

        # Make predictions
        predictions = model.predict(X)
        probabilities = model.predict_proba(X)

        # Convert to built-in int/float so the result is JSON-serializable
        results = [
            {
                "prediction": int(pred),
                "probabilities": {
                    f"Class {j}": float(prob) for j, prob in enumerate(probs)
                },
            }
            for pred, probs in zip(predictions, probabilities)
        ]
        return {"predictions": results}
    except Exception as e:
        # API boundary: report the failure to the caller instead of raising.
        return {"error": str(e)}
# FastAPI for batch predictions
app = FastAPI()


@app.post("/api/predict_batch")
async def api_predict_batch(request: dict):
    """Run batch predictions for the CSV referenced in the JSON request body.

    Args:
        request: Parsed JSON body; must contain a non-empty ``file_url``.

    Returns:
        A ``JSONResponse``: HTTP 400 with an ``error`` message when
        ``file_url`` is missing or empty, otherwise the dict produced by
        ``predict_batch_from_url`` (sent with the default status code, even
        when that dict carries an ``error`` entry).
    """
    # Imported here so the handler does not rely on the file's import block.
    from fastapi.concurrency import run_in_threadpool

    file_url = request.get("file_url")
    if not file_url:
        return JSONResponse({"error": "file_url is required"}, status_code=400)

    # The CSV download and model inference are blocking calls; running them
    # directly inside this coroutine would stall the event loop for every
    # other request, so hand the work to the thread pool.
    result = await run_in_threadpool(predict_batch_from_url, file_url)
    return JSONResponse(result)
# Gradio interface for single predictions
# One numeric input widget per entry of feature_names, in the same order.
inputs = []
for _feature in feature_names:
    inputs.append(gr.Number(label=_feature))

# First output shows the predicted class text, second the probability mapping.
outputs = [gr.Textbox(label="Prediction"), gr.Label(label="Probabilities")]

# Page header text is built from the loaded metadata.
_title = f"{metadata['model_name']} - ML Classifier"
_description = f"Accuracy: {metadata['accuracy']:.4f} | Features: {len(feature_names)}"

interface = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    title=_title,
    description=_description,
)
def run_fastapi(host="0.0.0.0", port=8000):
    """Serve the FastAPI ``app`` with uvicorn.

    Args:
        host: Address to bind; the default ``"0.0.0.0"`` listens on all
            interfaces.
        port: TCP port to listen on (default 8000).
    """
    uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
    # Serve the batch API from a daemon thread, so it does not keep the
    # process alive once the main thread exits.
    api_server = threading.Thread(target=run_fastapi, daemon=True)
    api_server.start()

    # Launch the Gradio UI from the main thread.
    interface.launch(server_port=7860)