| from fastapi import FastAPI, File, UploadFile |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel |
| import shutil |
| import uvicorn |
| from dotenv import load_dotenv |
| import os |
|
|
| |
# Load settings from the .env file located next to this module (resolved from
# __file__, so it does not depend on the current working directory).
load_dotenv(dotenv_path=os.path.join(os.path.dirname(__file__), ".env"))

# Addresses of the related services; each is None when its variable is unset.
fastserver, nodeserver, viteserver, mongoserver = (
    os.getenv(var_name)
    for var_name in ("FAST_SERVER", "NODE_SERVER", "VITE_SERVER", "MONGO_URI")
)
|
|
| |
| |
| |
|
|
# --- Model artifacts ---------------------------------------------------------
# Download the model files from the Hugging Face Hub, copy them into the local
# "models" directory, then run reassemble_chunks(). This has to happen before
# ModelMain is imported below (presumably it loads these files at import time
# — TODO confirm).
from huggingface_hub import hf_hub_download
from huggingface_hub import login

HF_MODEL_REPO = "f16sam/awss-models"
MODEL_FILES = (
    "layer1cnn_aanan.pth",
    "layer2bio_00",
    "layer2bio_01",
    "layer2non_00",
    "layer2non_01",
    "layer3_cnn.keras",
)
# NOTE(review): this path is relative to the current working directory,
# whereas .env is resolved relative to this file — start the server from the
# directory containing this module.
MODELS_DIR = "models"

# Log in only when a token is configured: login(token=None) falls back to an
# interactive prompt, which a headless server cannot answer.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Must be populated before the copy loop below, which reads these cached paths.
# Maps each artifact filename to its local path in the Hugging Face cache.
model_paths = {
    filename: hf_hub_download(repo_id=HF_MODEL_REPO, filename=filename)
    for filename in MODEL_FILES
}

os.makedirs(MODELS_DIR, exist_ok=True)
for filename, cached_path in model_paths.items():
    shutil.copy(cached_path, os.path.join(MODELS_DIR, filename))

# Presumably joins the split *_00 / *_01 chunk files into whole models —
# TODO confirm against models/reconstruct_models.py.
from models.reconstruct_models import reassemble_chunks

reassemble_chunks()
|
|
| from ModelMain import classify_image |
|
|
app = FastAPI()


def _cors_origins(*urls):
    """Return the configured browser origins, normalized for CORS matching.

    Unset (None) and blank values are skipped. A trailing "/" is removed
    because browsers send the Origin header without one, and Starlette's
    CORSMiddleware compares allowed origins by exact string equality.
    """
    normalized = ((url or "").strip().rstrip("/") for url in urls)
    return [origin for origin in normalized if origin]


# MONGO_URI is deliberately not listed: a database connection string is never
# a browser Origin, so it could not match any request anyway.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins(fastserver, nodeserver, viteserver),
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
@app.post("/classify/")
async def classify(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Reads the whole upload into memory, passes the raw bytes to
    ``classify_image`` and returns its result unchanged as the response.

    NOTE(review): ``classify_image`` is called synchronously inside this
    coroutine, so it runs on the event-loop thread and blocks other requests
    while it works. If inference is slow, consider offloading it (e.g.
    ``fastapi.concurrency.run_in_threadpool``) once the model code is
    confirmed safe to call from worker threads.
    """
    return classify_image(await file.read())
|
|
if __name__ == "__main__":
    # Development entry point: listen on all interfaces, port 8000, with
    # auto-reload. The app is given as an import string, the form uvicorn
    # needs for reload=True.
    # NOTE(review): "main:app" assumes this module is importable as `main`
    # — TODO confirm the filename.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )