# factory-ml / app/server_factory_predictor.py
# Author: santanche
# Commit f8f9b79: "fix (app): back to post and https"
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.responses import RedirectResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
import uvicorn
import os
# Import the FactoryPredictor class
from factory_predictor import FactoryPredictor
app = FastAPI()

# Expose the local ./static directory (resolved against the working
# directory) under the /app/static URL prefix.
app.mount("/app/static", StaticFiles(directory="static"), name="static")

# Browser origins permitted to call this API cross-origin: local development
# hosts on ports 5173 and 7860, plus the deployed Hugging Face Space.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://127.0.0.1:5173",
        "http://localhost:5173",
        "http://127.0.0.1:7860",
        "http://localhost:7860",
        "http://0.0.0.0:7860",
        "https://santanche-factory-ml.hf.space",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# One predictor instance shared by all request handlers in this module.
predictor = FactoryPredictor()

# Training progress as a human-readable string. train_model moves it through
# "In progress" to either "Completed" or "Failed: <reason>".
training_status = "Not started"
# Redirect the site root to the editor page at /editor/
@app.get("/")
def root():
    """Send requests for the bare site root to the editor page."""
    editor_url = "/editor/"
    return RedirectResponse(url=editor_url)
@app.get("/editor/")
def get_editor():
    """Serve the editor's ``index.html`` page.

    The file is looked up at ``static/editor/index.html``, relative to the
    process's current working directory.
    """
    # Build the path once. The stray debug ``print`` of this path, which
    # previously ran on every request, has been dropped.
    index_path = os.path.join("static", "editor", "index.html")
    return FileResponse(index_path)
def train_model():
    """Train the shared ``predictor`` and record the outcome in ``training_status``.

    This function is scheduled as a background task by the ``/train`` endpoint.
    Before training starts, the status is set to ``"In progress"``. Afterwards
    it becomes ``"Completed"`` on success, or ``"Failed: <reason>"`` if
    ``predictor.train()`` raises.

    Exceptions are recorded rather than re-raised. Clients learn about
    failures by polling ``/training_status``.
    """
    import logging  # local import keeps this block self-contained

    global training_status
    training_status = "In progress"
    try:
        predictor.train()
    except Exception as e:
        # Some exceptions (e.g. ``RuntimeError()``) have an empty message.
        # Fall back to the type name so the status is never just "Failed: ".
        reason = str(e) or type(e).__name__
        training_status = f"Failed: {reason}"
        # The status holds only a one-line summary, so keep the full
        # traceback in the server logs for debugging.
        logging.getLogger(__name__).exception("Model training failed")
    else:
        training_status = "Completed"
@app.post("/train")
async def train(background_tasks: BackgroundTasks):
    """Queue ``train_model`` as a background task and acknowledge immediately.

    Progress can be followed through the ``/training_status`` endpoint.
    """
    background_tasks.add_task(train_model)
    acknowledgement = {"message": "Training started in the background"}
    return acknowledgement
@app.get("/training_status")
async def get_training_status():
    """Return the current module-level training status string."""
    current = training_status
    return {"status": current}
@app.get("/predict")
async def predict(temperature: int, pressure: int):
    """Ask the predictor for a diagnosis given a temperature and a pressure.

    Returns ``{"diagnosis": <str>}``. Raises HTTP 400 until training has
    completed, and HTTP 500 (carrying the error message) if the predictor
    call or the string conversion of its result raises.
    """
    if training_status != "Completed":
        raise HTTPException(status_code=400, detail="Model not trained yet")
    try:
        # The str() conversion stays inside the try so that its failures are
        # also reported as HTTP 500 with the error message.
        diagnosis_text = str(predictor.predict(temperature, pressure))
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {"diagnosis": diagnosis_text}
@app.post("/inform_temperature")
async def inform_temperature(value: int):
    """Forward ``value`` to ``predictor.inform_temperature`` and return its result.

    The result is returned as a plain string. Raises HTTP 400 until training
    has completed, and HTTP 500 (carrying the error message) if the predictor
    call or the string conversion raises.
    """
    if training_status != "Completed":
        raise HTTPException(status_code=400, detail="Model not trained yet")
    try:
        diagnosis_text = str(predictor.inform_temperature(value))
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return diagnosis_text
@app.post("/inform_pressure")
async def inform_pressure(value: int):
    """Forward ``value`` to ``predictor.inform_pressure`` and return its result.

    The result is returned as a plain string. Raises HTTP 400 until training
    has completed, and HTTP 500 (carrying the error message) if the predictor
    call or the string conversion raises.
    """
    if training_status != "Completed":
        raise HTTPException(status_code=400, detail="Model not trained yet")
    try:
        diagnosis_text = str(predictor.inform_pressure(value))
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return diagnosis_text
if __name__ == "__main__":
    # Run the API directly with uvicorn when executed as a script. The port
    # can be overridden through the PORT environment variable (e.g. 7860);
    # it defaults to 8000 as before.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))