# NOTE(review): removed a Hugging Face Spaces status banner ("Spaces: Sleeping")
# that was scraped into this file; it is not part of the program.
# --- Imports and FastAPI application setup -----------------------------------
# NOTE(review): the paste wrapped every line in markdown-table pipes; this
# restores plain Python source without changing behavior.
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
import torch
from src.text_embedding import TextEmbeddingModel
from src.index import Indexer
import os
import pickle
from infer import infer_3_class, infer_model_specific
import uvicorn
from datasets import disable_caching

# Disable HF `datasets` on-disk caching (avoids writing cache files in the container).
disable_caching()

app = FastAPI()

# CORS: open to any origin so the static front-end / external clients can call the API.
# NOTE(review): browsers reject `allow_origins=["*"]` combined with
# `allow_credentials=True`; if credentialed requests are needed, list explicit origins.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class Opt:
    """Runtime configuration for the inference service."""

    def __init__(self):
        # Hugging Face model id of the text-embedding model.
        self.model_name = "ngocminhta/faid-v1"
        # Directory holding the serialized index and metadata pickles.
        self.database_path = "core/seen_db"
        # Dimensionality of the embeddings stored in the index.
        self.embedding_dim = 768
        # presumably a CUDA device ordinal — TODO confirm; unused in this file.
        self.device_num = 1


# Single module-level configuration instance used by the loaders below.
opt = Opt()
def load_pkl(path):
    """Load and return one pickled object from the file at *path*.

    NOTE(review): `pickle.load` executes arbitrary code from the file; this is
    only used on the bundled database directory, never on user-supplied input —
    keep it that way.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)
@app.on_event("startup")
def load_model_resources():
    """Populate the module globals used by ``predict``.

    Loads the embedding model + tokenizer, deserializes the ANN index, and
    reads the three metadata pickles from ``opt.database_path``.

    NOTE(review): nothing in the pasted source ever called this function, so
    the request handlers would hit undefined globals; a startup decorator was
    almost certainly lost in the scrape and is restored here — confirm against
    the original deployment.
    """
    global model, tokenizer, index, label_dict, is_mixed_dict, write_model_dict
    model = TextEmbeddingModel(opt.model_name)
    tokenizer = model.tokenizer
    index = Indexer(opt.embedding_dim)
    index.deserialize_from(opt.database_path)
    label_dict = load_pkl(os.path.join(opt.database_path, 'label_dict.pkl'))
    is_mixed_dict = load_pkl(os.path.join(opt.database_path, 'is_mixed_dict.pkl'))
    write_model_dict = load_pkl(os.path.join(opt.database_path, 'write_model_dict.pkl'))
# NOTE(review): the pasted source had no route decorator, so this handler was
# never registered; the path below is restored by convention — verify it
# matches what the front-end calls.
@app.post("/predict")
async def predict(request: Request):
    """Classify submitted texts as human-/AI-written.

    Expects a JSON body ``{"mode": "normal"|"advanced", "text": [...]}``.
    "normal" runs the 3-class detector; "advanced" additionally attributes the
    generating model (uses ``write_model_dict`` and a per-model ``K_model``).
    Returns ``{"results": ...}``.
    """
    data = await request.json()
    mode = data.get("mode", "normal").lower()
    text_list = data.get("text", [])
    if mode == "normal":
        results = infer_3_class(
            model=model,
            tokenizer=tokenizer,
            index=index,
            label_dict=label_dict,
            is_mixed_dict=is_mixed_dict,
            text_list=text_list,
            K=21,
        )
        return JSONResponse(content={"results": results})
    elif mode == "advanced":
        results = infer_model_specific(
            model=model,
            tokenizer=tokenizer,
            index=index,
            label_dict=label_dict,
            is_mixed_dict=is_mixed_dict,
            write_model_dict=write_model_dict,
            text_list=text_list,
            K=21,
            K_model=9,
        )
        return JSONResponse(content={"results": results})
    # Original fell through and implicitly returned None (serialized as null);
    # report an explicit client error for unrecognized modes instead.
    return JSONResponse(status_code=400, content={"error": f"unknown mode: {mode!r}"})
# Serve the front-end; html=True makes "/" return static/index.html automatically.
app.mount("/", StaticFiles(directory="static", html=True), name="static")


def serve_index() -> FileResponse:
    """Return the SPA entry page as an HTML response.

    NOTE(review): renamed from ``index`` — the original name rebound the
    module-level ``index`` (the Indexer global read by ``predict``) to this
    function, a shadowing bug. The paste also shows no route decorator here;
    since the StaticFiles mount above already serves index.html at "/", this
    function may be dead code — confirm before deleting.
    """
    return FileResponse(path="/app/static/index.html", media_type="text/html")
if __name__ == "__main__":
    # PORT env var lets the hosting platform pick the port; default 8000.
    port = int(os.getenv("PORT", 8000))
    # 0.0.0.0 so the server is reachable from outside the container.
    uvicorn.run(app, host="0.0.0.0", port=port)