Spaces:
Running
Running
import logging
import os
import sys

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(BASE_DIR)

from model import SentiTaxEngine, SentiTaxRLM  # noqa: E402  (needs BASE_DIR on sys.path)
# The ASGI application object served by uvicorn (see the __main__ launcher).
app = FastAPI(
    version="2.0.0",
    title="SentiTax RLM API",
    description=(
        "SML and RLM service for taxation analysis and compliance checking."
    ),
)

# Cross-origin access is fully open: any origin, method and header.
# NOTE(review): wildcard CORS is acceptable only if this API is meant to be
# public; restrict allow_origins if it sits behind an authenticated frontend.
_open_cors = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_open_cors)
# Weights file for the shallow engine: <module dir>/weights/tax_model.pt.
_WEIGHTS_DIR = os.path.join(BASE_DIR, "weights")
WEIGHTS_PATH = os.path.join(_WEIGHTS_DIR, "tax_model.pt")

# Both engines are built eagerly at import time; a failure here aborts startup.
# `engine` serves the default (non-deep) prediction path.
engine = SentiTaxEngine(WEIGHTS_PATH)
# `rlm_engine` serves deep requests (deep=True or tier "C"/"D") and /tax reasoning.
rlm_engine = SentiTaxRLM()
class RequestBody(BaseModel):
    """JSON payload accepted by the prediction and reasoning handlers.

    Attributes:
        text: The input text to analyse.
        tier: Service tier label, "A" when omitted. The prediction handlers
            escalate tiers "C" and "D" to the deep RLM engine.
    """

    text: str  # required
    tier: str = "A"  # optional; free-form string, not validated here
async def root():
    """Return a static status payload naming the service and its port.

    The port (9206) matches the one used by the ``__main__`` launcher.
    """
    # NOTE(review): no route decorator is visible for this handler; confirm
    # it is registered on ``app`` (presumably as GET "/").
    payload = dict(status="active", service="sentitax", port=9206)
    return payload
async def health():
    """Report service health together with the RLM engine's own check.

    Returns:
        dict: ``{"status": "ok", "rlm_health": ...}``, where ``rlm_health`` is
        whatever ``rlm_engine.engine.health_check()`` resolves to. The
        top-level ``status`` is always "ok"; it does not reflect
        ``rlm_health``.

    Any exception raised by ``health_check`` propagates to the caller.
    """
    report = {"status": "ok"}
    report["rlm_health"] = await rlm_engine.engine.health_check()
    return report
async def predict_endpoint(body: RequestBody, deep: bool = False):
    """Score ``body.text``, escalating to the deep RLM engine when required.

    Args:
        body: Request payload (text and tier).
        deep: Force the deep RLM path. Presumably supplied as a query
            parameter once this handler is registered as a route.

    Returns:
        The result of ``rlm_engine.predict_deep`` for deep requests, otherwise
        the result of ``engine.predict``.

    Raises:
        HTTPException: status 500 with the engine's error message as
            ``detail`` when either engine fails.
    """
    return await _dispatch_prediction(body, deep)


async def score_endpoint(body: RequestBody, deep: bool = False):
    """Legacy scoring entry point; behaves exactly like ``predict_endpoint``.

    Kept for clients using the legacy ``/score`` path mapped in
    ``sml_client.py``.
    """
    return await _dispatch_prediction(body, deep)


async def _dispatch_prediction(body: RequestBody, deep: bool):
    """Route one request to the deep RLM engine or the shallow SML engine.

    Deep routing applies when ``deep`` is true or ``body.tier`` is exactly
    "C" or "D" (case-sensitive). Engine failures are logged with their
    traceback and converted into an HTTP 500. An ``HTTPException`` raised
    downstream is propagated unchanged instead of being masked as a 500.
    """
    try:
        if deep or body.tier in ("C", "D"):
            return await rlm_engine.predict_deep(body.text, body.tier)
        # NOTE(review): engine.predict is synchronous and runs on the event
        # loop thread, blocking other requests while it executes; consider
        # fastapi.concurrency.run_in_threadpool if the engine is thread-safe.
        return engine.predict(body.text)
    except HTTPException:
        raise
    except Exception as e:
        # Log tier/deep only, never the text, which may hold taxpayer data.
        logging.getLogger(__name__).exception(
            "prediction failed (tier=%s, deep=%s)", body.tier, deep
        )
        raise HTTPException(status_code=500, detail=str(e)) from e
async def tax_reason(body: RequestBody):
    """Run the deep RLM reasoning engine on ``body.text``, whatever the tier.

    Args:
        body: Request payload; ``body.tier`` is forwarded to the engine.

    Returns:
        The result of ``rlm_engine.predict_deep(body.text, body.tier)``.

    Raises:
        HTTPException: status 500 with the engine's error message as
            ``detail`` when the engine fails. An ``HTTPException`` raised
            downstream is propagated unchanged.
    """
    try:
        return await rlm_engine.predict_deep(body.text, body.tier)
    except HTTPException:
        raise
    except Exception as e:
        # Keep the traceback server-side; do not log the (possibly sensitive) text.
        logging.getLogger(__name__).exception(
            "tax reasoning failed (tier=%s)", body.tier
        )
        raise HTTPException(status_code=500, detail=str(e)) from e
async def shutdown_event():
    """Release RLM engine resources via the class-level ``RLMEngine.shutdown``.

    The import is deferred to call time, so the ``senti`` package is only
    needed when this coroutine actually runs.
    """
    # NOTE(review): no ``@app.on_event("shutdown")`` / lifespan registration is
    # visible here; confirm this coroutine is actually wired to app shutdown.
    from senti.core.engines.superpacks.rlm_engine import RLMEngine as engine_cls

    release = engine_cls.shutdown
    await release()
async def predict_endpoint_legacy_alias(body: RequestBody):
    """Legacy alias: always score ``body.text`` with the shallow SML engine.

    ``body.tier`` is ignored; this path never escalates to the RLM engine.

    Args:
        body: Request payload; only ``body.text`` is used.

    Returns:
        The result of ``engine.predict(body.text)``.

    Raises:
        HTTPException: status 500 with the engine's error message as
            ``detail`` when the engine fails. An ``HTTPException`` raised
            downstream is propagated unchanged.
    """
    # ``engine`` is created unconditionally at module import (a failure there
    # aborts startup), so no "which engine exists?" fallback is needed.
    try:
        return engine.predict(body.text)
    except HTTPException:
        raise
    except Exception as e:
        logging.getLogger(__name__).exception("legacy predict alias failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Direct launch: serve ``app`` on every interface at port 9206.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=9206)