Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException, BackgroundTasks | |
| from fastapi.responses import RedirectResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse | |
| from fastapi.staticfiles import StaticFiles | |
| import uvicorn | |
| import os | |
| # Import the FactoryPredictor class | |
| from factory_predictor import FactoryPredictor | |
# The FastAPI application object configured below.
app = FastAPI()

# Serve files from the relative "static" directory (resolved against the
# process working directory) under the /app/static URL prefix.
app.mount("/app/static", StaticFiles(directory="static"), name="static")

# Browser origins allowed to make cross-origin requests: local hosts on ports
# 5173 and 7860, plus the santanche-factory-ml.hf.space deployment.
_CORS_ALLOWED_ORIGINS = [
    "http://127.0.0.1:5173",
    "http://localhost:5173",
    "http://127.0.0.1:7860",
    "http://localhost:7860",
    "http://0.0.0.0:7860",
    "https://santanche-factory-ml.hf.space",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# One shared FactoryPredictor, trained by train_model() and queried by
# predict(), inform_temperature() and inform_pressure().
predictor = FactoryPredictor()

# Progress of the most recent training run. train_model() sets it to
# "In progress", then "Completed" or "Failed: <exception text>".
training_status: str = "Not started"
# Redirect the root URL to the editor page (/editor/), not to /docs.
# NOTE(review): no route decorator is visible on this handler in this copy of
# the file; presumably it should be registered with @app.get("/") — confirm.
def root():
    """Return a redirect response that sends the client to ``/editor/``."""
    return RedirectResponse(url="/editor/")
# NOTE(review): no route decorator is visible on this handler in this copy of
# the file; root() redirects to /editor/, so presumably it should be registered
# with @app.get("/editor/") — confirm.
def get_editor():
    """Serve the editor's ``index.html`` from ``static/editor/``.

    The path is relative to the process working directory, like the
    ``StaticFiles(directory="static")`` mount.

    Raises:
        HTTPException: 404 if the editor's ``index.html`` does not exist.
    """
    index_path = os.path.join("static", "editor", "index.html")
    # Starlette's FileResponse only stats the file when the response is sent
    # and raises RuntimeError for a missing file (surfacing as a 500), so check
    # up front and answer with a proper 404 instead.
    if not os.path.isfile(index_path):
        raise HTTPException(status_code=404, detail="Editor page not found")
    return FileResponse(index_path)
def train_model():
    """Train the shared ``predictor`` and record the outcome in ``training_status``.

    Exceptions from ``predictor.train()`` are not propagated. They are stored
    as ``"Failed: <exception text>"`` so that get_training_status() can report
    them.
    """
    global training_status
    training_status = "In progress"
    try:
        predictor.train()
    except Exception as exc:
        training_status = "Failed: " + str(exc)
    else:
        training_status = "Completed"
async def train(background_tasks: BackgroundTasks):
    """Queue ``train_model`` as a background task and acknowledge immediately.

    FastAPI runs background tasks after the response has been returned, so this
    handler does not wait for training to finish.
    """
    # NOTE(review): nothing here checks training_status, so a second request can
    # queue another run while one is still "In progress" — confirm that is OK.
    acknowledgement = {"message": "Training started in the background"}
    background_tasks.add_task(train_model)
    return acknowledgement
async def get_training_status():
    """Report the current ``training_status`` wrapped as ``{"status": ...}``."""
    return dict(status=training_status)
def _ensure_trained():
    """Raise HTTP 400 unless ``training_status`` is exactly ``"Completed"``.

    This also rejects requests while a re-training run is ``"In progress"`` or
    after a ``"Failed: ..."`` run, even if an earlier run had succeeded.
    """
    if training_status != "Completed":
        raise HTTPException(status_code=400, detail="Model not trained yet")


def _run_predictor(call):
    """Invoke ``call()`` and return its result converted with ``str``.

    Any exception (including one from looking up the predictor method inside
    ``call``) is re-raised as HTTP 500 with the exception text as the detail,
    chained to the original exception.
    """
    # NOTE(review): str(e) forwards internal error text to the client — confirm
    # that this is acceptable for the deployment.
    try:
        return str(call())
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


# NOTE(review): these handlers are `async def` but call the predictor
# synchronously, so a slow call blocks the event loop; no route decorators are
# visible on them in this copy of the file — confirm both.
async def predict(temperature: int, pressure: int):
    """Return ``{"diagnosis": <str>}`` for a temperature/pressure pair.

    Raises:
        HTTPException: 400 if the model is not trained, 500 if prediction fails.
    """
    _ensure_trained()
    diagnosis = _run_predictor(lambda: predictor.predict(temperature, pressure))
    return {"diagnosis": diagnosis}


async def inform_temperature(value: int):
    """Pass a temperature reading to the predictor and return its result as a str.

    Raises:
        HTTPException: 400 if the model is not trained, 500 if the call fails.
    """
    _ensure_trained()
    return _run_predictor(lambda: predictor.inform_temperature(value))


async def inform_pressure(value: int):
    """Pass a pressure reading to the predictor and return its result as a str.

    Raises:
        HTTPException: 400 if the model is not trained, 500 if the call fails.
    """
    _ensure_trained()
    return _run_predictor(lambda: predictor.inform_pressure(value))
if __name__ == "__main__":
    # Local entry point: serve on all interfaces. The port defaults to 8000 but
    # can be overridden through the PORT environment variable, e.g. PORT=7860
    # to match the port-7860 origins allowed by the CORS configuration.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))