# awss / main.py — uploaded by F16Sam
# Last commit: c052b32 "Update main.py" (verified)
import os
import shutil

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# Configuration is read from the .env file that sits next to this script
# (resolved from __file__, not from the current working directory).
# Any variable that is not set ends up as None.
load_dotenv(dotenv_path=os.path.join(os.path.dirname(__file__), ".env"))
fastserver, nodeserver, viteserver, mongoserver = (
    os.getenv(env_key)
    for env_key in ("FAST_SERVER", "NODE_SERVER", "VITE_SERVER", "MONGO_URI")
)
# Download models on startup: fetch every model file from the Hugging Face
# Hub, copy it into ./models, then reassemble the models from those files.
# NOTE(review): "models" is resolved against the current working directory,
# not this file's directory (unlike the .env lookup above) — confirm the
# server is always launched from the project root.
from huggingface_hub import hf_hub_download
from huggingface_hub import login

HF_MODEL_REPO = "f16sam/awss-models"
# Repo filenames to fetch; each is also the name it gets under ./models.
MODEL_FILES = (
    "layer1cnn_aanan.pth",
    "layer2bio_00",
    "layer2bio_01",
    "layer2non_00",
    "layer2non_01",
    "layer3_cnn.keras",
)

# Only authenticate when a token is configured: login(token=None) falls back
# to an interactive prompt, which a server process presumably cannot answer.
# Without a token, downloads are anonymous (works only if the repo is public).
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Repo filename -> local path returned by hf_hub_download. Must be built
# before the copy step below, which reads from it.
model_paths = {
    model_file: hf_hub_download(repo_id=HF_MODEL_REPO, filename=model_file)
    for model_file in MODEL_FILES
}

# Ensure the ./models directory exists, then copy each downloaded file into
# it under its repo filename (e.g. models/layer3_cnn.keras).
os.makedirs("models", exist_ok=True)
for _model_file, _cached_path in model_paths.items():
    shutil.copy(_cached_path, os.path.join("models", _model_file))

# Reconstruct the models from the copied files (presumably joining the
# *_00 / *_01 chunk pairs — see models.reconstruct_models).
from models.reconstruct_models import reassemble_chunks
reassemble_chunks()
from ModelMain import classify_image
app = FastAPI()

# CORS: allow browser requests from the configured service hosts.
# Only HTTP(S) origins belong here — MONGO_URI is a database connection
# string, never a browser Origin, so it is deliberately left out.
# A serialized Origin is scheme://host[:port] with no trailing slash, so a
# trailing "/" is stripped from each configured URL to let values such as
# "http://localhost:5173/" still match. Unset/empty variables are skipped.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        url.rstrip("/")
        for url in (fastserver, nodeserver, viteserver)
        if url
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.post("/classify/")
async def classify(file: UploadFile = File(...)):
    """Classify an uploaded image file.

    The raw upload bytes are passed unchanged to ``classify_image`` (from
    ``ModelMain``), and its return value is returned as the response.

    Raises:
        HTTPException: 400 if the uploaded file is empty.
    """
    contents = await file.read()
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")
    # classify_image is a synchronous call (model inference). Invoking it
    # directly inside this async handler would block the event loop and stall
    # all other requests, so run it in FastAPI's worker thread pool instead.
    return await run_in_threadpool(classify_image, contents)
if __name__ == "__main__":
    # Development launcher: serve on all interfaces, port 8000, auto-reload on.
    # reload=True requires the "module:attribute" import string, so uvicorn
    # imports this file again as "main" to obtain `app`.
    # NOTE(review): the module-level model download/copy above has already
    # run once in this launcher process and runs again on that import —
    # consider a production entry point without reload.
    _server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("main:app", **_server_options)