Spaces:
Sleeping
Sleeping
Dinuk-Di committed on
Commit ·
a122f91
1
Parent(s): 735e421
Chat Api
Browse files
- .gitattributes +2 -0
- Dockerfile +10 -20
- app/main.py +26 -63
- app/model.py +38 -120
- app/routes.py +17 -124
- app/schema.py +9 -107
- app/services.py +0 -247
- requirements.txt +7 -13
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.env filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.env.* filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
CHANGED
|
@@ -3,23 +3,18 @@ FROM python:3.11-slim
|
|
| 3 |
ENV DEBIAN_FRONTEND=noninteractive \
|
| 4 |
PYTHONUNBUFFERED=1 \
|
| 5 |
PYTHONDONTWRITEBYTECODE=1 \
|
| 6 |
-
HF_HOME=/app/.cache/huggingface \
|
| 7 |
-
TRANSFORMERS_CACHE=/app/.cache/huggingface \
|
| 8 |
PORT=7860
|
| 9 |
|
| 10 |
-
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 11 |
-
ffmpeg libsndfile1 git curl \
|
| 12 |
-
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
| 13 |
-
|
| 14 |
WORKDIR /app
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
|
| 20 |
-
#
|
| 21 |
-
RUN pip install --no-cache-dir
|
| 22 |
|
|
|
|
| 23 |
COPY requirements.txt .
|
| 24 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 25 |
|
|
@@ -30,13 +25,8 @@ USER appuser
|
|
| 30 |
|
| 31 |
EXPOSE 7860
|
| 32 |
|
| 33 |
-
HEALTHCHECK --interval=60s --timeout=15s --start-period=
|
| 34 |
-
CMD curl -f http://localhost:7860/ || exit 1
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
"--host", "0.0.0.0", \
|
| 39 |
-
"--port", "7860", \
|
| 40 |
-
"--workers", "1", \
|
| 41 |
-
"--loop", "uvloop", \
|
| 42 |
-
"--log-level", "info"]
|
|
|
|
| 3 |
ENV DEBIAN_FRONTEND=noninteractive \
|
| 4 |
PYTHONUNBUFFERED=1 \
|
| 5 |
PYTHONDONTWRITEBYTECODE=1 \
|
|
|
|
|
|
|
| 6 |
PORT=7860
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
WORKDIR /app
|
| 9 |
|
| 10 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 11 |
+
curl \
|
| 12 |
+
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
| 13 |
|
| 14 |
+
# Upgrade build tools
|
| 15 |
+
RUN pip install --no-cache-dir --upgrade pip setuptools wheel
|
| 16 |
|
| 17 |
+
# Install dependencies (no heavy CUDA packages needed for LangChain HF Endpoint!)
|
| 18 |
COPY requirements.txt .
|
| 19 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 20 |
|
|
|
|
| 25 |
|
| 26 |
EXPOSE 7860
|
| 27 |
|
| 28 |
+
HEALTHCHECK --interval=60s --timeout=15s --start-period=30s --retries=3 \
|
| 29 |
+
CMD curl -f http://localhost:7860/api/health || exit 1
|
| 30 |
|
| 31 |
+
# Start uvicorn, ensuring sys.path includes the root so `app.` imports work
|
| 32 |
+
CMD ["python", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app/main.py
CHANGED
|
@@ -1,60 +1,41 @@
|
|
| 1 |
-
# main.py
|
| 2 |
-
import logging
|
| 3 |
-
import time
|
| 4 |
-
import os
|
| 5 |
-
from contextlib import asynccontextmanager
|
| 6 |
-
|
| 7 |
from fastapi import FastAPI, Request
|
|
|
|
| 8 |
from fastapi.middleware.cors import CORSMiddleware
|
| 9 |
-
from
|
| 10 |
-
from
|
| 11 |
-
from
|
| 12 |
-
|
| 13 |
-
import
|
|
|
|
| 14 |
|
| 15 |
-
|
| 16 |
-
from routes import router_rag, router_ingest, router_monitor, router_health
|
| 17 |
|
| 18 |
-
logging.basicConfig(
|
| 19 |
-
level=logging.INFO,
|
| 20 |
-
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
| 21 |
-
)
|
| 22 |
logger = logging.getLogger(__name__)
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
ENABLE_AUDIO = os.getenv("ENABLE_AUDIO_OUTPUT", "false").lower() == "true"
|
| 27 |
-
|
| 28 |
|
| 29 |
@asynccontextmanager
|
| 30 |
async def lifespan(app: FastAPI):
|
| 31 |
print("\n" + "="*60, flush=True)
|
| 32 |
-
print("🚀 INITIALIZING
|
| 33 |
-
print("⏳ This model is ~15GB and will take a few minutes to load.", flush=True)
|
| 34 |
print("="*60 + "\n", flush=True)
|
| 35 |
|
| 36 |
-
|
| 37 |
-
load_model(
|
| 38 |
-
logger.info("Model ready. API is live.")
|
| 39 |
|
| 40 |
print("\n✅ API is LIVE on port 7860! Ready for requests.\n", flush=True)
|
| 41 |
yield
|
| 42 |
-
|
| 43 |
-
|
| 44 |
|
| 45 |
app = FastAPI(
|
| 46 |
-
title="Multimodal
|
| 47 |
-
description=
|
| 48 |
-
"Production-ready RAG API powered by Qwen2.5-Omni-3B. "
|
| 49 |
-
"Supports text, image, audio, and video modalities."
|
| 50 |
-
),
|
| 51 |
version="1.0.0",
|
| 52 |
lifespan=lifespan,
|
| 53 |
-
docs_url="/docs",
|
| 54 |
-
redoc_url="/redoc",
|
| 55 |
)
|
| 56 |
|
| 57 |
-
# ── Middleware ────────────────────────────────────────────────────────────────
|
| 58 |
app.add_middleware(
|
| 59 |
CORSMiddleware,
|
| 60 |
allow_origins=os.getenv("CORS_ORIGINS", "*").split(","),
|
|
@@ -63,44 +44,26 @@ app.add_middleware(
|
|
| 63 |
allow_headers=["*"],
|
| 64 |
)
|
| 65 |
|
| 66 |
-
app.state.limiter = limiter
|
| 67 |
-
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
| 68 |
-
|
| 69 |
-
|
| 70 |
@app.middleware("http")
|
| 71 |
async def request_logging_middleware(request: Request, call_next):
|
| 72 |
start = time.time()
|
| 73 |
response = await call_next(request)
|
| 74 |
duration_ms = (time.time() - start) * 1000
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
|
|
|
| 79 |
return response
|
| 80 |
|
| 81 |
-
|
| 82 |
-
@app.exception_handler(Exception)
|
| 83 |
-
async def global_exception_handler(request: Request, exc: Exception):
|
| 84 |
-
logger.exception(f"Unhandled exception: {exc}")
|
| 85 |
-
return JSONResponse(
|
| 86 |
-
status_code=500,
|
| 87 |
-
content={"error": "Internal server error", "detail": str(exc)},
|
| 88 |
-
)
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
# ── Routers ───────────────────────────────────────────────────────────────────
|
| 92 |
-
app.include_router(router_health)
|
| 93 |
-
app.include_router(router_rag)
|
| 94 |
-
app.include_router(router_ingest)
|
| 95 |
-
app.include_router(router_monitor)
|
| 96 |
-
|
| 97 |
|
| 98 |
if __name__ == "__main__":
|
|
|
|
| 99 |
uvicorn.run(
|
| 100 |
-
"main:app",
|
| 101 |
host="0.0.0.0",
|
| 102 |
port=7860,
|
| 103 |
-
workers=1,
|
| 104 |
-
loop="uvloop",
|
| 105 |
log_level="info",
|
| 106 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from fastapi import FastAPI, Request
|
| 2 |
+
from contextlib import asynccontextmanager
|
| 3 |
from fastapi.middleware.cors import CORSMiddleware
|
| 4 |
+
from app.routes import router as api_router
|
| 5 |
+
from app.model import load_model
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
import os
|
| 8 |
+
import time
|
| 9 |
+
import logging
|
| 10 |
|
| 11 |
+
load_dotenv()
|
|
|
|
| 12 |
|
| 13 |
+
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
|
|
|
|
|
| 14 |
logger = logging.getLogger(__name__)
|
| 15 |
|
| 16 |
+
REPO_ID = os.getenv("MODEL_REPO_ID", "deepseek-ai/DeepSeek-R1")
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
@asynccontextmanager
|
| 19 |
async def lifespan(app: FastAPI):
|
| 20 |
print("\n" + "="*60, flush=True)
|
| 21 |
+
print(f"🚀 INITIALIZING CHAT API: Setup remote LLM ({REPO_ID})", flush=True)
|
|
|
|
| 22 |
print("="*60 + "\n", flush=True)
|
| 23 |
|
| 24 |
+
# Store the LLM in the application state so routes can access it
|
| 25 |
+
app.state.llm = load_model(repo_id=REPO_ID)
|
|
|
|
| 26 |
|
| 27 |
print("\n✅ API is LIVE on port 7860! Ready for requests.\n", flush=True)
|
| 28 |
yield
|
| 29 |
+
print("\n" + "="*60, flush=True)
|
| 30 |
+
print("👋 Shutting down API. Goodbye!", flush=True)
|
| 31 |
|
| 32 |
app = FastAPI(
|
| 33 |
+
title="Multimodal Chat API",
|
| 34 |
+
description="Production-ready Chat API powered by LangChain and HuggingFaceEndpoint.",
|
|
|
|
|
|
|
|
|
|
| 35 |
version="1.0.0",
|
| 36 |
lifespan=lifespan,
|
|
|
|
|
|
|
| 37 |
)
|
| 38 |
|
|
|
|
| 39 |
app.add_middleware(
|
| 40 |
CORSMiddleware,
|
| 41 |
allow_origins=os.getenv("CORS_ORIGINS", "*").split(","),
|
|
|
|
| 44 |
allow_headers=["*"],
|
| 45 |
)
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
@app.middleware("http")
|
| 48 |
async def request_logging_middleware(request: Request, call_next):
|
| 49 |
start = time.time()
|
| 50 |
response = await call_next(request)
|
| 51 |
duration_ms = (time.time() - start) * 1000
|
| 52 |
+
if request.url.path != "/api/health":
|
| 53 |
+
logger.info(
|
| 54 |
+
f"{request.method} {request.url.path} "
|
| 55 |
+
f"→ {response.status_code} [{duration_ms:.1f}ms]"
|
| 56 |
+
)
|
| 57 |
return response
|
| 58 |
|
| 59 |
+
app.include_router(api_router, prefix="/api")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
if __name__ == "__main__":
|
| 62 |
+
import uvicorn
|
| 63 |
uvicorn.run(
|
| 64 |
+
"app.main:app",
|
| 65 |
host="0.0.0.0",
|
| 66 |
port=7860,
|
| 67 |
+
workers=1,
|
|
|
|
| 68 |
log_level="info",
|
| 69 |
)
|
app/model.py
CHANGED
|
@@ -1,121 +1,39 @@
|
|
| 1 |
-
|
| 2 |
-
import
|
| 3 |
-
import
|
| 4 |
-
import
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def load_model(enable_audio_output: bool = False):
|
| 20 |
-
global _model, _processor, _model_load_time
|
| 21 |
-
|
| 22 |
-
if _model is not None and _processor is not None:
|
| 23 |
-
return _model, _processor
|
| 24 |
-
|
| 25 |
-
logger.info(f"Loading model: {MODEL_ID}")
|
| 26 |
-
start = time.time()
|
| 27 |
-
|
| 28 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 29 |
-
logger.info(f"Using device: {device}")
|
| 30 |
-
|
| 31 |
-
load_kwargs: Dict[str, Any] = {
|
| 32 |
-
# Use float32 on CPU — bfloat16 is poorly supported on CPU
|
| 33 |
-
"torch_dtype": torch.bfloat16 if device == "cuda" else torch.float32,
|
| 34 |
-
"device_map": "auto" if device == "cuda" else "cpu",
|
| 35 |
-
# NO flash_attention_2 — only works with GPU + nvcc
|
| 36 |
-
}
|
| 37 |
-
|
| 38 |
-
_model = Qwen2_5OmniForConditionalGeneration.from_pretrained(MODEL_ID, **load_kwargs)
|
| 39 |
-
|
| 40 |
-
# Always disable talker on CPU — saves ~2GB and talker requires GPU
|
| 41 |
-
_model.disable_talker()
|
| 42 |
-
logger.info("Audio talker disabled (CPU mode — saves memory).")
|
| 43 |
-
|
| 44 |
-
_processor = Qwen2_5OmniProcessor.from_pretrained(MODEL_ID)
|
| 45 |
-
_model_load_time = time.time() - start
|
| 46 |
-
logger.info(f"Model loaded in {_model_load_time:.2f}s on {device}")
|
| 47 |
-
|
| 48 |
-
return _model, _processor
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
def get_model() -> Qwen2_5OmniForConditionalGeneration:
|
| 52 |
-
if _model is None:
|
| 53 |
-
raise RuntimeError("Model not loaded. Call load_model() first.")
|
| 54 |
-
return _model
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
def get_processor() -> Qwen2_5OmniProcessor:
|
| 58 |
-
if _processor is None:
|
| 59 |
-
raise RuntimeError("Processor not loaded. Call load_model() first.")
|
| 60 |
-
return _processor
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
def run_inference(
|
| 64 |
-
conversation: List[Dict],
|
| 65 |
-
return_audio: bool = False,
|
| 66 |
-
speaker: str = "Chelsie",
|
| 67 |
-
max_new_tokens: int = 256,
|
| 68 |
-
temperature: float = 0.7,
|
| 69 |
-
use_audio_in_video: bool = True,
|
| 70 |
-
) -> Tuple[str, Optional[bytes], int, int]:
|
| 71 |
-
model = get_model()
|
| 72 |
-
processor = get_processor()
|
| 73 |
-
|
| 74 |
-
# Force return_audio=False on CPU since talker is disabled
|
| 75 |
-
if not torch.cuda.is_available():
|
| 76 |
-
return_audio = False
|
| 77 |
-
|
| 78 |
-
text_template = processor.apply_chat_template(
|
| 79 |
-
conversation,
|
| 80 |
-
add_generation_prompt=True,
|
| 81 |
-
tokenize=False,
|
| 82 |
)
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
generate_kwargs: Dict[str, Any] = {
|
| 105 |
-
"use_audio_in_video": use_audio_in_video,
|
| 106 |
-
"max_new_tokens": max_new_tokens,
|
| 107 |
-
"temperature": temperature,
|
| 108 |
-
"do_sample": temperature > 0,
|
| 109 |
-
"return_audio": False, # Always False — talker disabled on CPU
|
| 110 |
-
}
|
| 111 |
-
|
| 112 |
-
with torch.inference_mode():
|
| 113 |
-
outputs = model.generate(**inputs, **generate_kwargs)
|
| 114 |
-
|
| 115 |
-
completion_tokens = outputs.shape[-1] - prompt_tokens
|
| 116 |
-
decoded = processor.batch_decode(
|
| 117 |
-
outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 118 |
-
)
|
| 119 |
-
answer = decoded[0] if decoded else ""
|
| 120 |
-
|
| 121 |
-
return answer, None, prompt_tokens, completion_tokens
|
|
|
|
| 1 |
+
from langchain_core.prompts import PromptTemplate
|
| 2 |
+
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
|
| 3 |
+
import os
|
| 4 |
+
from app.schema import OutputResponse
|
| 5 |
+
|
| 6 |
+
# We completely remove `getpass` to prevent blocking the Docker container.
|
| 7 |
+
# HuggingFace secrets should be defined in the HF space environment automatically.
|
| 8 |
+
if not os.getenv("HUGGINGFACEHUB_API_TOKEN"):
|
| 9 |
+
print("WARNING: HUGGINGFACEHUB_API_TOKEN is not set in the environment. "
|
| 10 |
+
"Set this as a secret in your HuggingFace Space or .env file.")
|
| 11 |
+
|
| 12 |
+
def load_model(repo_id: str, max_length: int = 512, temperature: float = 0.5):
|
| 13 |
+
llm = HuggingFaceEndpoint(
|
| 14 |
+
repo_id=repo_id,
|
| 15 |
+
task="text-generation",
|
| 16 |
+
max_new_tokens=max_length,
|
| 17 |
+
do_sample=temperature > 0,
|
| 18 |
+
temperature=temperature if temperature > 0 else None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
)
|
| 20 |
+
return llm
|
| 21 |
+
|
| 22 |
+
def generate_answer(question: str, llm) -> OutputResponse:
|
| 23 |
+
try:
|
| 24 |
+
prompt = PromptTemplate(
|
| 25 |
+
input_variables=["question"],
|
| 26 |
+
template="""
|
| 27 |
+
You are a helpful assistant that provides concise and accurate answers to user questions.
|
| 28 |
+
Question: {question}
|
| 29 |
+
Answer Format:
|
| 30 |
+
Answer: <Your concise answer here>
|
| 31 |
+
Justification: <Why this query is relevant to the user's request>
|
| 32 |
+
"""
|
| 33 |
+
)
|
| 34 |
+
chat_model = ChatHuggingFace(llm=llm)
|
| 35 |
+
structured_llm = chat_model.with_structured_output(OutputResponse)
|
| 36 |
+
result = structured_llm.invoke(prompt.format(question=question))
|
| 37 |
+
return result
|
| 38 |
+
except Exception as e:
|
| 39 |
+
return OutputResponse(answer="Error generating answer", justification=str(e))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app/routes.py
CHANGED
|
@@ -1,131 +1,24 @@
|
|
| 1 |
-
|
| 2 |
-
import
|
| 3 |
-
import
|
| 4 |
-
from typing import Optional
|
| 5 |
|
| 6 |
-
|
| 7 |
-
from fastapi.responses import JSONResponse
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
IngestRequest, IngestResponse,
|
| 12 |
-
HealthResponse, ErrorResponse, UserStatsResponse,
|
| 13 |
-
)
|
| 14 |
-
from services import (
|
| 15 |
-
process_rag_query, ingest_documents,
|
| 16 |
-
get_global_stats, get_user_stats,
|
| 17 |
-
)
|
| 18 |
-
from model import get_model, get_processor
|
| 19 |
-
|
| 20 |
-
logger = logging.getLogger(__name__)
|
| 21 |
-
|
| 22 |
-
router_rag = APIRouter(prefix="/rag", tags=["RAG"])
|
| 23 |
-
router_ingest = APIRouter(prefix="/ingest", tags=["Ingestion"])
|
| 24 |
-
router_monitor = APIRouter(prefix="/monitor", tags=["Monitoring"])
|
| 25 |
-
router_health = APIRouter(tags=["Health"])
|
| 26 |
-
|
| 27 |
-
_start_time = time.time()
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
# ── Auth dependency (replace with JWT/OAuth in production) ───────────────────
|
| 31 |
-
async def verify_api_key(x_api_key: Optional[str] = Header(default=None)):
|
| 32 |
-
import os
|
| 33 |
-
expected = os.getenv("API_KEY", "dev-secret")
|
| 34 |
-
if x_api_key != expected:
|
| 35 |
-
raise HTTPException(status_code=401, detail="Invalid or missing API key.")
|
| 36 |
-
return x_api_key
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
# ── RAG Routes ────────────────────────────────────────────────────────────────
|
| 40 |
-
@router_rag.post(
|
| 41 |
-
"/query",
|
| 42 |
-
response_model=RAGQueryResponse,
|
| 43 |
-
responses={400: {"model": ErrorResponse}, 500: {"model": ErrorResponse}},
|
| 44 |
-
summary="Multimodal RAG Query",
|
| 45 |
-
description="Submit text, image, audio, or video inputs to query the RAG pipeline.",
|
| 46 |
-
)
|
| 47 |
-
async def rag_query(
|
| 48 |
-
request: RAGQueryRequest,
|
| 49 |
-
background_tasks: BackgroundTasks,
|
| 50 |
-
_: str = Depends(verify_api_key),
|
| 51 |
-
):
|
| 52 |
try:
|
| 53 |
-
|
| 54 |
-
return response
|
| 55 |
-
except ValueError as e:
|
| 56 |
-
raise HTTPException(status_code=400, detail=str(e))
|
| 57 |
except Exception as e:
|
| 58 |
-
|
| 59 |
-
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
|
| 60 |
-
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
"/documents",
|
| 65 |
-
response_model=IngestResponse,
|
| 66 |
-
summary="Ingest Multimodal Documents",
|
| 67 |
-
)
|
| 68 |
-
async def ingest_docs(
|
| 69 |
-
request: IngestRequest,
|
| 70 |
-
_: str = Depends(verify_api_key),
|
| 71 |
-
):
|
| 72 |
try:
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
)
|
| 78 |
-
return
|
| 79 |
-
ingested_count=len(doc_ids),
|
| 80 |
-
doc_ids=doc_ids,
|
| 81 |
-
message=f"Successfully ingested {len(doc_ids)} documents.",
|
| 82 |
-
)
|
| 83 |
except Exception as e:
|
| 84 |
-
|
| 85 |
-
raise HTTPException(status_code=500, detail=str(e))
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
# ── Monitoring Routes ─────────────────────────────────────────────────────────
|
| 89 |
-
@router_monitor.get(
|
| 90 |
-
"/health",
|
| 91 |
-
response_model=HealthResponse,
|
| 92 |
-
summary="API Health Check",
|
| 93 |
-
)
|
| 94 |
-
async def health_check():
|
| 95 |
-
try:
|
| 96 |
-
model = get_model()
|
| 97 |
-
device = str(next(model.parameters()).device)
|
| 98 |
-
model_loaded = True
|
| 99 |
-
except RuntimeError:
|
| 100 |
-
device = "unavailable"
|
| 101 |
-
model_loaded = False
|
| 102 |
-
|
| 103 |
-
stats = get_global_stats()
|
| 104 |
-
return HealthResponse(
|
| 105 |
-
status="ok" if model_loaded else "degraded",
|
| 106 |
-
model_loaded=model_loaded,
|
| 107 |
-
device=device,
|
| 108 |
-
uptime_seconds=time.time() - _start_time,
|
| 109 |
-
total_requests=stats["total_requests"],
|
| 110 |
-
total_tokens_processed=stats["total_tokens"],
|
| 111 |
-
)
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
@router_monitor.get(
|
| 115 |
-
"/users/{user_id}",
|
| 116 |
-
response_model=UserStatsResponse,
|
| 117 |
-
summary="Get Per-User Usage Stats",
|
| 118 |
-
)
|
| 119 |
-
async def user_stats(
|
| 120 |
-
user_id: str,
|
| 121 |
-
_: str = Depends(verify_api_key),
|
| 122 |
-
):
|
| 123 |
-
stats = get_user_stats(user_id)
|
| 124 |
-
if stats is None:
|
| 125 |
-
raise HTTPException(status_code=404, detail=f"User '{user_id}' not found.")
|
| 126 |
-
return UserStatsResponse(**stats)
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
@router_health.get("/", include_in_schema=False)
|
| 130 |
-
async def root():
|
| 131 |
-
return {"message": "Multimodal RAG API is running. Visit /docs for API reference."}
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Request
|
| 2 |
+
from app.model import generate_answer
|
| 3 |
+
from app.schema import UserRequest
|
|
|
|
| 4 |
|
| 5 |
+
router = APIRouter()
|
|
|
|
| 6 |
|
| 7 |
+
@router.get("/health", tags=["Health"])
|
| 8 |
+
async def health_check():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
try:
|
| 10 |
+
return {"status": "ok"}
|
|
|
|
|
|
|
|
|
|
| 11 |
except Exception as e:
|
| 12 |
+
return {"status": "error", "message": str(e)}
|
|
|
|
|
|
|
| 13 |
|
| 14 |
+
@router.post("/chat", tags=["Chat"])
|
| 15 |
+
async def chat_endpoint(request_body: UserRequest, request: Request):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
try:
|
| 17 |
+
llm = request.app.state.llm
|
| 18 |
+
if not llm:
|
| 19 |
+
return {"status": "error", "message": "LLM not loaded into application state."}
|
| 20 |
+
|
| 21 |
+
response = generate_answer(request_body.question, llm)
|
| 22 |
+
return response
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
except Exception as e:
|
| 24 |
+
return {"status": "error", "message": str(e)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app/schema.py
CHANGED
|
@@ -1,107 +1,9 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
TEXT = "text"
|
| 11 |
-
IMAGE = "image"
|
| 12 |
-
AUDIO = "audio"
|
| 13 |
-
VIDEO = "video"
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class MediaInput(BaseModel):
|
| 17 |
-
modality: ModalityType
|
| 18 |
-
content: str = Field(..., description="URL, base64 string, or raw text depending on modality")
|
| 19 |
-
use_audio_in_video: Optional[bool] = Field(default=True, description="Use embedded audio in video")
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
class RAGQueryRequest(BaseModel):
|
| 23 |
-
query_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
| 24 |
-
user_id: str = Field(..., description="Unique user identifier")
|
| 25 |
-
query_text: Optional[str] = Field(default=None, description="Natural language query")
|
| 26 |
-
media_inputs: Optional[List[MediaInput]] = Field(default=[], description="List of multimodal inputs")
|
| 27 |
-
top_k: int = Field(default=5, ge=1, le=20, description="Number of RAG context chunks to retrieve")
|
| 28 |
-
return_audio: bool = Field(
|
| 29 |
-
default=False,
|
| 30 |
-
description="Audio output (GPU only — disabled on CPU deployments)")
|
| 31 |
-
speaker: Literal["Chelsie", "Ethan"] = Field(default="Chelsie")
|
| 32 |
-
max_new_tokens: int = Field(default=256, ge=32, le=512)
|
| 33 |
-
temperature: float = Field(default=0.7, ge=0.0, le=2.0)
|
| 34 |
-
|
| 35 |
-
@validator("media_inputs", always=True)
|
| 36 |
-
def validate_at_least_one_input(cls, v, values):
|
| 37 |
-
if not v and not values.get("query_text"):
|
| 38 |
-
raise ValueError("At least one of query_text or media_inputs must be provided.")
|
| 39 |
-
return v
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
class RAGDocument(BaseModel):
|
| 43 |
-
doc_id: str
|
| 44 |
-
content: str
|
| 45 |
-
modality: ModalityType
|
| 46 |
-
score: float = Field(..., ge=0.0, le=1.0)
|
| 47 |
-
metadata: Optional[Dict[str, Any]] = {}
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
class TokenUsage(BaseModel):
|
| 51 |
-
prompt_tokens: int
|
| 52 |
-
completion_tokens: int
|
| 53 |
-
total_tokens: int
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
class PerformanceMetrics(BaseModel):
|
| 57 |
-
latency_ms: float
|
| 58 |
-
retrieval_latency_ms: float
|
| 59 |
-
generation_latency_ms: float
|
| 60 |
-
throughput_tokens_per_sec: float
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
class RAGQueryResponse(BaseModel):
|
| 64 |
-
query_id: str
|
| 65 |
-
user_id: str
|
| 66 |
-
answer_text: str
|
| 67 |
-
retrieved_documents: List[RAGDocument]
|
| 68 |
-
audio_base64: Optional[str] = None
|
| 69 |
-
token_usage: TokenUsage
|
| 70 |
-
performance: PerformanceMetrics
|
| 71 |
-
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
class IngestRequest(BaseModel):
|
| 75 |
-
user_id: str
|
| 76 |
-
documents: List[MediaInput]
|
| 77 |
-
doc_ids: Optional[List[str]] = None
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
class IngestResponse(BaseModel):
|
| 81 |
-
ingested_count: int
|
| 82 |
-
doc_ids: List[str]
|
| 83 |
-
message: str
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
class HealthResponse(BaseModel):
|
| 87 |
-
status: str
|
| 88 |
-
model_loaded: bool
|
| 89 |
-
device: str
|
| 90 |
-
uptime_seconds: float
|
| 91 |
-
total_requests: int
|
| 92 |
-
total_tokens_processed: int
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
class ErrorResponse(BaseModel):
|
| 96 |
-
error: str
|
| 97 |
-
detail: Optional[str] = None
|
| 98 |
-
query_id: Optional[str] = None
|
| 99 |
-
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
class UserStatsResponse(BaseModel):
|
| 103 |
-
user_id: str
|
| 104 |
-
total_queries: int
|
| 105 |
-
total_tokens: int
|
| 106 |
-
avg_latency_ms: float
|
| 107 |
-
last_active: Optional[datetime] = None
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field
|
| 2 |
+
|
| 3 |
+
class UserRequest(BaseModel):
|
| 4 |
+
question: str = Field(..., description="The user's question or request.")
|
| 5 |
+
class OutputResponse(BaseModel):
|
| 6 |
+
answer: str = Field(..., description="The answer generated by the model.")
|
| 7 |
+
justification: str = Field(
|
| 8 |
+
..., description="Why this query is relevant to the user's request."
|
| 9 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app/services.py
DELETED
|
@@ -1,247 +0,0 @@
|
|
| 1 |
-
# services.py
|
| 2 |
-
import uuid
|
| 3 |
-
import time
|
| 4 |
-
import base64
|
| 5 |
-
import logging
|
| 6 |
-
import asyncio
|
| 7 |
-
from datetime import datetime
|
| 8 |
-
from collections import defaultdict
|
| 9 |
-
from typing import List, Dict, Optional, Tuple, Any
|
| 10 |
-
|
| 11 |
-
import numpy as np
|
| 12 |
-
from sentence_transformers import SentenceTransformer
|
| 13 |
-
from sklearn.metrics.pairwise import cosine_similarity
|
| 14 |
-
|
| 15 |
-
from schema import (
|
| 16 |
-
MediaInput, RAGDocument, TokenUsage, PerformanceMetrics,
|
| 17 |
-
RAGQueryRequest, RAGQueryResponse, ModalityType
|
| 18 |
-
)
|
| 19 |
-
from model import run_inference
|
| 20 |
-
|
| 21 |
-
logger = logging.getLogger(__name__)
|
| 22 |
-
|
| 23 |
-
# ── In-memory vector store (swap for FAISS/Qdrant/Chroma in production) ──────
|
| 24 |
-
_doc_store: Dict[str, Dict[str, Any]] = {}
|
| 25 |
-
_embeddings_store: Dict[str, np.ndarray] = {}
|
| 26 |
-
_embed_model: Optional[SentenceTransformer] = None
|
| 27 |
-
|
| 28 |
-
# ── Monitoring state ──────────────────────────────────────────────────────────
|
| 29 |
-
_global_stats = {
|
| 30 |
-
"total_requests": 0,
|
| 31 |
-
"total_tokens": 0,
|
| 32 |
-
"start_time": time.time(),
|
| 33 |
-
}
|
| 34 |
-
_user_stats: Dict[str, Dict] = defaultdict(lambda: {
|
| 35 |
-
"total_queries": 0,
|
| 36 |
-
"total_tokens": 0,
|
| 37 |
-
"latencies": [],
|
| 38 |
-
"last_active": None,
|
| 39 |
-
})
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
def get_embed_model() -> SentenceTransformer:
|
| 43 |
-
global _embed_model
|
| 44 |
-
if _embed_model is None:
|
| 45 |
-
_embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
| 46 |
-
logger.info("Embedding model loaded.")
|
| 47 |
-
return _embed_model
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
def _embed_text(text: str) -> np.ndarray:
|
| 51 |
-
model = get_embed_model()
|
| 52 |
-
return model.encode([text], normalize_embeddings=True)[0]
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def ingest_documents(
|
| 56 |
-
user_id: str,
|
| 57 |
-
documents: List[MediaInput],
|
| 58 |
-
doc_ids: Optional[List[str]] = None,
|
| 59 |
-
) -> List[str]:
|
| 60 |
-
ids = []
|
| 61 |
-
for i, doc in enumerate(documents):
|
| 62 |
-
doc_id = doc_ids[i] if doc_ids and i < len(doc_ids) else str(uuid.uuid4())
|
| 63 |
-
|
| 64 |
-
# For non-text modalities, embed a descriptor; full multimodal embeddings
|
| 65 |
-
# require a separate vision-language embedding model (e.g., CLIP, ImageBind).
|
| 66 |
-
if doc.modality == ModalityType.TEXT:
|
| 67 |
-
embed_text = doc.content
|
| 68 |
-
else:
|
| 69 |
-
embed_text = f"[{doc.modality.value.upper()} resource] {doc.content}"
|
| 70 |
-
|
| 71 |
-
embedding = _embed_text(embed_text)
|
| 72 |
-
_doc_store[doc_id] = {
|
| 73 |
-
"doc_id": doc_id,
|
| 74 |
-
"user_id": user_id,
|
| 75 |
-
"content": doc.content,
|
| 76 |
-
"modality": doc.modality,
|
| 77 |
-
"metadata": {"ingested_at": datetime.utcnow().isoformat()},
|
| 78 |
-
}
|
| 79 |
-
_embeddings_store[doc_id] = embedding
|
| 80 |
-
ids.append(doc_id)
|
| 81 |
-
|
| 82 |
-
logger.info(f"Ingested {len(ids)} documents for user {user_id}")
|
| 83 |
-
return ids
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
def retrieve_documents(
|
| 87 |
-
query_text: str,
|
| 88 |
-
top_k: int = 5,
|
| 89 |
-
) -> List[RAGDocument]:
|
| 90 |
-
if not _embeddings_store:
|
| 91 |
-
return []
|
| 92 |
-
|
| 93 |
-
query_emb = _embed_text(query_text).reshape(1, -1)
|
| 94 |
-
doc_ids = list(_embeddings_store.keys())
|
| 95 |
-
doc_embs = np.vstack([_embeddings_store[d] for d in doc_ids])
|
| 96 |
-
scores = cosine_similarity(query_emb, doc_embs)[0]
|
| 97 |
-
|
| 98 |
-
top_indices = np.argsort(scores)[::-1][:top_k]
|
| 99 |
-
results = []
|
| 100 |
-
for idx in top_indices:
|
| 101 |
-
did = doc_ids[idx]
|
| 102 |
-
doc = _doc_store[did]
|
| 103 |
-
results.append(RAGDocument(
|
| 104 |
-
doc_id=did,
|
| 105 |
-
content=doc["content"],
|
| 106 |
-
modality=doc["modality"],
|
| 107 |
-
score=float(scores[idx]),
|
| 108 |
-
metadata=doc.get("metadata", {}),
|
| 109 |
-
))
|
| 110 |
-
return results
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
def _build_rag_conversation(
|
| 114 |
-
request: RAGQueryRequest,
|
| 115 |
-
retrieved_docs: List[RAGDocument],
|
| 116 |
-
) -> List[Dict]:
|
| 117 |
-
system_prompt = (
|
| 118 |
-
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, "
|
| 119 |
-
"capable of perceiving auditory and visual inputs, as well as generating text and speech."
|
| 120 |
-
if request.return_audio
|
| 121 |
-
else "You are a helpful multimodal AI assistant with access to retrieved context."
|
| 122 |
-
)
|
| 123 |
-
|
| 124 |
-
context_str = "\n\n".join(
|
| 125 |
-
[f"[Context {i+1} | {d.modality.value}]: {d.content}" for i, d in enumerate(retrieved_docs)]
|
| 126 |
-
)
|
| 127 |
-
|
| 128 |
-
user_content: List[Dict] = []
|
| 129 |
-
|
| 130 |
-
for media in (request.media_inputs or []):
|
| 131 |
-
if media.modality == ModalityType.TEXT:
|
| 132 |
-
user_content.append({"type": "text", "text": media.content})
|
| 133 |
-
elif media.modality == ModalityType.IMAGE:
|
| 134 |
-
user_content.append({"type": "image", "image": media.content})
|
| 135 |
-
elif media.modality == ModalityType.AUDIO:
|
| 136 |
-
user_content.append({"type": "audio", "audio": media.content})
|
| 137 |
-
elif media.modality == ModalityType.VIDEO:
|
| 138 |
-
user_content.append({"type": "video", "video": media.content})
|
| 139 |
-
|
| 140 |
-
final_query = (
|
| 141 |
-
f"Retrieved Context:\n{context_str}\n\n"
|
| 142 |
-
f"User Query: {request.query_text or 'Analyze the provided media.'}"
|
| 143 |
-
)
|
| 144 |
-
user_content.append({"type": "text", "text": final_query})
|
| 145 |
-
|
| 146 |
-
return [
|
| 147 |
-
{"role": "system", "content": [{"type": "text", "text": system_prompt}]},
|
| 148 |
-
{"role": "user", "content": user_content},
|
| 149 |
-
]
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
def _update_monitoring(
    user_id: str,
    total_tokens: int,
    latency_ms: float,
):
    """Record one completed request in the global and per-user counters.

    Args:
        user_id: Key of the per-user record in ``_user_stats``.
        total_tokens: Tokens consumed by the request (added to both scopes).
        latency_ms: Latency of the request, appended to the user's history.

    NOTE(review): this indexes ``_user_stats[user_id]`` directly, so it
    presumably relies on that mapping creating a record for unseen users;
    confirm against its definition.
    """
    # Service-wide totals.
    _global_stats["total_requests"] += 1
    _global_stats["total_tokens"] += total_tokens

    # Per-user totals.
    record = _user_stats[user_id]
    record["total_queries"] += 1
    record["total_tokens"] += total_tokens

    history = record["latencies"]
    history.append(latency_ms)
    # Keep only the last 1000 latencies per user to avoid unbounded memory.
    if len(history) > 1000:
        record["latencies"] = history[-1000:]

    record["last_active"] = datetime.utcnow()
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
async def process_rag_query(request: RAGQueryRequest) -> RAGQueryResponse:
    """Run one RAG query end to end: retrieve, build the prompt, generate, report.

    Steps:
        1. Retrieve ``request.top_k`` documents for the query text (or, when
           absent, the concatenated text media inputs, or a generic fallback).
        2. Build the model conversation from the request and retrieved docs.
        3. Run ``run_inference`` in the default executor so the event loop is
           not blocked while the model generates.
        4. Update monitoring counters and return the answer together with
           token usage and latency metrics.

    Args:
        request: The validated query request.

    Returns:
        A ``RAGQueryResponse`` carrying the answer text, the retrieved
        documents, optional base64-encoded audio, token usage and performance
        metrics (all latencies in milliseconds).
    """
    # perf_counter() is monotonic, so measured durations cannot go negative or
    # jump when the system clock is adjusted (unlike time.time()).
    total_start = time.perf_counter()

    # ── Retrieval ─────────────────────────────────────────────────────────────
    retrieval_start = time.perf_counter()
    query_for_retrieval = request.query_text or " ".join(
        m.content for m in (request.media_inputs or []) if m.modality == ModalityType.TEXT
    ) or "multimodal query"
    retrieved_docs = retrieve_documents(query_for_retrieval, top_k=request.top_k)
    retrieval_latency_ms = (time.perf_counter() - retrieval_start) * 1000

    # ── Build conversation ────────────────────────────────────────────────────
    conversation = _build_rag_conversation(request, retrieved_docs)

    # Enabled as soon as any video input asks for its audio track to be used.
    use_audio_in_video = any(
        m.use_audio_in_video for m in (request.media_inputs or [])
        if m.modality == ModalityType.VIDEO
    )

    # ── Generation (run in thread pool to avoid blocking event loop) ──────────
    gen_start = time.perf_counter()
    # Inside a coroutine the running loop is the one to use; get_event_loop()
    # is the legacy spelling and is discouraged here.
    loop = asyncio.get_running_loop()
    answer, audio_bytes, prompt_tokens, completion_tokens = await loop.run_in_executor(
        None,
        lambda: run_inference(
            conversation=conversation,
            return_audio=request.return_audio,
            speaker=request.speaker,
            max_new_tokens=request.max_new_tokens,
            temperature=request.temperature,
            use_audio_in_video=use_audio_in_video,
        ),
    )
    gen_latency_ms = (time.perf_counter() - gen_start) * 1000

    total_latency_ms = (time.perf_counter() - total_start) * 1000
    total_tokens = prompt_tokens + completion_tokens
    # Guard against a zero-length generation interval before dividing.
    throughput = (completion_tokens / (gen_latency_ms / 1000)) if gen_latency_ms > 0 else 0.0

    _update_monitoring(request.user_id, total_tokens, total_latency_ms)

    # Audio is optional; it is shipped as base64 text when present.
    audio_b64 = base64.b64encode(audio_bytes).decode("utf-8") if audio_bytes else None

    return RAGQueryResponse(
        query_id=request.query_id,
        user_id=request.user_id,
        answer_text=answer,
        retrieved_documents=retrieved_docs,
        audio_base64=audio_b64,
        token_usage=TokenUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
        ),
        performance=PerformanceMetrics(
            latency_ms=total_latency_ms,
            retrieval_latency_ms=retrieval_latency_ms,
            generation_latency_ms=gen_latency_ms,
            throughput_tokens_per_sec=throughput,
        ),
    )
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
def get_global_stats() -> Dict:
    """Return a copy of the service-wide monitoring counters.

    A copy (via ``_global_stats.copy()``) is handed out instead of the live
    object, so rebinding keys on the result does not touch the module state.
    """
    snapshot = _global_stats.copy()
    return snapshot
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
def get_user_stats(user_id: str) -> Optional[Dict]:
    """Summarise the monitoring record of a single user.

    Args:
        user_id: Key to look up in ``_user_stats``.

    Returns:
        ``None`` when no record exists for ``user_id``; otherwise a dict with
        the user's query count, token count, mean latency in milliseconds
        (``0.0`` when no latencies are recorded) and last-active value.
    """
    # Test membership first so an unknown user is reported as None rather
    # than being indexed.
    if user_id not in _user_stats:
        return None

    record = _user_stats[user_id]
    latencies = record["latencies"]
    if latencies:
        avg_latency_ms = float(np.mean(latencies))
    else:
        avg_latency_ms = 0.0

    return {
        "user_id": user_id,
        "total_queries": record["total_queries"],
        "total_tokens": record["total_tokens"],
        "avg_latency_ms": avg_latency_ms,
        "last_active": record["last_active"],
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -1,13 +1,7 @@
|
|
| 1 |
-
fastapi
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
sentence-transformers>=3.0.0
|
| 9 |
-
scikit-learn>=1.4.0
|
| 10 |
-
soundfile>=0.12.1
|
| 11 |
-
numpy>=1.26.0
|
| 12 |
-
slowapi>=0.1.9
|
| 13 |
-
# flash-attn REMOVED — requires nvcc/GPU to compile
|
|
|
|
| 1 |
+
fastapi[standard]
|
| 2 |
+
langchain_core
|
| 3 |
+
langgraph
|
| 4 |
+
huggingface_hub
|
| 5 |
+
langchain
|
| 6 |
+
langchain-huggingface
|
| 7 |
+
pydantic
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|