Dinuk-Di commited on
Commit
a122f91
·
1 Parent(s): 735e421
Files changed (8) hide show
  1. .gitattributes +2 -0
  2. Dockerfile +10 -20
  3. app/main.py +26 -63
  4. app/model.py +38 -120
  5. app/routes.py +17 -124
  6. app/schema.py +9 -107
  7. app/services.py +0 -247
  8. requirements.txt +7 -13
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.env filter=lfs diff=lfs merge=lfs -text
37
+ *.env.* filter=lfs diff=lfs merge=lfs -text
Dockerfile CHANGED
@@ -3,23 +3,18 @@ FROM python:3.11-slim
3
  ENV DEBIAN_FRONTEND=noninteractive \
4
  PYTHONUNBUFFERED=1 \
5
  PYTHONDONTWRITEBYTECODE=1 \
6
- HF_HOME=/app/.cache/huggingface \
7
- TRANSFORMERS_CACHE=/app/.cache/huggingface \
8
  PORT=7860
9
 
10
- RUN apt-get update && apt-get install -y --no-install-recommends \
11
- ffmpeg libsndfile1 git curl \
12
- && apt-get clean && rm -rf /var/lib/apt/lists/*
13
-
14
  WORKDIR /app
15
 
16
- # Upgrade build tools first
17
- RUN pip install --no-cache-dir --upgrade pip setuptools wheel packaging hf_transfer
18
- ENV HF_HUB_ENABLE_HF_TRANSFER=1
19
 
20
- # Default PyTorch with CUDA support (crucial for A10G inference to avoid CPU OOM)
21
- RUN pip install --no-cache-dir torch torchvision torchaudio
22
 
 
23
  COPY requirements.txt .
24
  RUN pip install --no-cache-dir -r requirements.txt
25
 
@@ -30,13 +25,8 @@ USER appuser
30
 
31
  EXPOSE 7860
32
 
33
- HEALTHCHECK --interval=60s --timeout=15s --start-period=300s --retries=3 \
34
- CMD curl -f http://localhost:7860/ || exit 1
35
 
36
- CMD ["python", "-m", "uvicorn", "main:app", \
37
- "--app-dir", "app", \
38
- "--host", "0.0.0.0", \
39
- "--port", "7860", \
40
- "--workers", "1", \
41
- "--loop", "uvloop", \
42
- "--log-level", "info"]
 
3
  ENV DEBIAN_FRONTEND=noninteractive \
4
  PYTHONUNBUFFERED=1 \
5
  PYTHONDONTWRITEBYTECODE=1 \
 
 
6
  PORT=7860
7
 
 
 
 
 
8
  WORKDIR /app
9
 
10
+ RUN apt-get update && apt-get install -y --no-install-recommends \
11
+ curl \
12
+ && apt-get clean && rm -rf /var/lib/apt/lists/*
13
 
14
+ # Upgrade build tools
15
+ RUN pip install --no-cache-dir --upgrade pip setuptools wheel
16
 
17
+ # Install dependencies (no heavy CUDA packages needed for LangChain HF Endpoint!)
18
  COPY requirements.txt .
19
  RUN pip install --no-cache-dir -r requirements.txt
20
 
 
25
 
26
  EXPOSE 7860
27
 
28
+ HEALTHCHECK --interval=60s --timeout=15s --start-period=30s --retries=3 \
29
+ CMD curl -f http://localhost:7860/api/health || exit 1
30
 
31
+ # Start uvicorn, ensuring sys.path includes the root so `app.` imports work
32
+ CMD ["python", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
 
 
 
 
 
app/main.py CHANGED
@@ -1,60 +1,41 @@
1
- # main.py
2
- import logging
3
- import time
4
- import os
5
- from contextlib import asynccontextmanager
6
-
7
  from fastapi import FastAPI, Request
 
8
  from fastapi.middleware.cors import CORSMiddleware
9
- from fastapi.responses import JSONResponse
10
- from slowapi import Limiter, _rate_limit_exceeded_handler
11
- from slowapi.util import get_remote_address
12
- from slowapi.errors import RateLimitExceeded
13
- import uvicorn
 
14
 
15
- from model import load_model
16
- from routes import router_rag, router_ingest, router_monitor, router_health
17
 
18
- logging.basicConfig(
19
- level=logging.INFO,
20
- format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
21
- )
22
  logger = logging.getLogger(__name__)
23
 
24
- limiter = Limiter(key_func=get_remote_address)
25
-
26
- ENABLE_AUDIO = os.getenv("ENABLE_AUDIO_OUTPUT", "false").lower() == "true"
27
-
28
 
29
  @asynccontextmanager
30
  async def lifespan(app: FastAPI):
31
  print("\n" + "="*60, flush=True)
32
- print("🚀 INITIALIZING APP: Downloading and loading Qwen2.5-Omni-3B", flush=True)
33
- print("⏳ This model is ~15GB and will take a few minutes to load.", flush=True)
34
  print("="*60 + "\n", flush=True)
35
 
36
- logger.info("Starting up loading Qwen2.5-Omni-3B model...")
37
- load_model(enable_audio_output=ENABLE_AUDIO)
38
- logger.info("Model ready. API is live.")
39
 
40
  print("\n✅ API is LIVE on port 7860! Ready for requests.\n", flush=True)
41
  yield
42
- logger.info("Shutting down API.")
43
-
44
 
45
  app = FastAPI(
46
- title="Multimodal RAG API",
47
- description=(
48
- "Production-ready RAG API powered by Qwen2.5-Omni-3B. "
49
- "Supports text, image, audio, and video modalities."
50
- ),
51
  version="1.0.0",
52
  lifespan=lifespan,
53
- docs_url="/docs",
54
- redoc_url="/redoc",
55
  )
56
 
57
- # ── Middleware ────────────────────────────────────────────────────────────────
58
  app.add_middleware(
59
  CORSMiddleware,
60
  allow_origins=os.getenv("CORS_ORIGINS", "*").split(","),
@@ -63,44 +44,26 @@ app.add_middleware(
63
  allow_headers=["*"],
64
  )
65
 
66
- app.state.limiter = limiter
67
- app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
68
-
69
-
70
  @app.middleware("http")
71
  async def request_logging_middleware(request: Request, call_next):
72
  start = time.time()
73
  response = await call_next(request)
74
  duration_ms = (time.time() - start) * 1000
75
- logger.info(
76
- f"{request.method} {request.url.path} "
77
- f"{response.status_code} [{duration_ms:.1f}ms]"
78
- )
 
79
  return response
80
 
81
-
82
- @app.exception_handler(Exception)
83
- async def global_exception_handler(request: Request, exc: Exception):
84
- logger.exception(f"Unhandled exception: {exc}")
85
- return JSONResponse(
86
- status_code=500,
87
- content={"error": "Internal server error", "detail": str(exc)},
88
- )
89
-
90
-
91
- # ── Routers ───────────────────────────────────────────────────────────────────
92
- app.include_router(router_health)
93
- app.include_router(router_rag)
94
- app.include_router(router_ingest)
95
- app.include_router(router_monitor)
96
-
97
 
98
  if __name__ == "__main__":
 
99
  uvicorn.run(
100
- "main:app",
101
  host="0.0.0.0",
102
  port=7860,
103
- workers=1, # Single worker for GPU models
104
- loop="uvloop",
105
  log_level="info",
106
  )
 
 
 
 
 
 
 
1
  from fastapi import FastAPI, Request
2
+ from contextlib import asynccontextmanager
3
  from fastapi.middleware.cors import CORSMiddleware
4
+ from app.routes import router as api_router
5
+ from app.model import load_model
6
+ from dotenv import load_dotenv
7
+ import os
8
+ import time
9
+ import logging
10
 
11
+ load_dotenv()
 
12
 
13
+ logging.basicConfig(level=logging.INFO)
 
 
 
14
  logger = logging.getLogger(__name__)
15
 
16
+ REPO_ID = os.getenv("MODEL_REPO_ID", "deepseek-ai/DeepSeek-R1")
 
 
 
17
 
18
  @asynccontextmanager
19
  async def lifespan(app: FastAPI):
20
  print("\n" + "="*60, flush=True)
21
+ print(f"🚀 INITIALIZING CHAT API: Setup remote LLM ({REPO_ID})", flush=True)
 
22
  print("="*60 + "\n", flush=True)
23
 
24
+ # Store the LLM in the application state so routes can access it
25
+ app.state.llm = load_model(repo_id=REPO_ID)
 
26
 
27
  print("\n✅ API is LIVE on port 7860! Ready for requests.\n", flush=True)
28
  yield
29
+ print("\n" + "="*60, flush=True)
30
+ print("👋 Shutting down API. Goodbye!", flush=True)
31
 
32
  app = FastAPI(
33
+ title="Multimodal Chat API",
34
+ description="Production-ready Chat API powered by LangChain and HuggingFaceEndpoint.",
 
 
 
35
  version="1.0.0",
36
  lifespan=lifespan,
 
 
37
  )
38
 
 
39
  app.add_middleware(
40
  CORSMiddleware,
41
  allow_origins=os.getenv("CORS_ORIGINS", "*").split(","),
 
44
  allow_headers=["*"],
45
  )
46
 
 
 
 
 
47
  @app.middleware("http")
48
  async def request_logging_middleware(request: Request, call_next):
49
  start = time.time()
50
  response = await call_next(request)
51
  duration_ms = (time.time() - start) * 1000
52
+ if request.url.path != "/api/health":
53
+ logger.info(
54
+ f"{request.method} {request.url.path} "
55
+ f"→ {response.status_code} [{duration_ms:.1f}ms]"
56
+ )
57
  return response
58
 
59
+ app.include_router(api_router, prefix="/api")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  if __name__ == "__main__":
62
+ import uvicorn
63
  uvicorn.run(
64
+ "app.main:app",
65
  host="0.0.0.0",
66
  port=7860,
67
+ workers=1,
 
68
  log_level="info",
69
  )
app/model.py CHANGED
@@ -1,121 +1,39 @@
1
- # model.py
2
- import torch
3
- import logging
4
- import time
5
- from typing import Optional, Tuple, List, Dict, Any
6
-
7
- from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
8
- from qwen_omni_utils import process_mm_info
9
-
10
- logger = logging.getLogger(__name__)
11
-
12
- MODEL_ID = "Qwen/Qwen2.5-Omni-3B"
13
-
14
- _model: Optional[Qwen2_5OmniForConditionalGeneration] = None
15
- _processor: Optional[Qwen2_5OmniProcessor] = None
16
- _model_load_time: float = 0.0
17
-
18
-
19
- def load_model(enable_audio_output: bool = False):
20
- global _model, _processor, _model_load_time
21
-
22
- if _model is not None and _processor is not None:
23
- return _model, _processor
24
-
25
- logger.info(f"Loading model: {MODEL_ID}")
26
- start = time.time()
27
-
28
- device = "cuda" if torch.cuda.is_available() else "cpu"
29
- logger.info(f"Using device: {device}")
30
-
31
- load_kwargs: Dict[str, Any] = {
32
- # Use float32 on CPU — bfloat16 is poorly supported on CPU
33
- "torch_dtype": torch.bfloat16 if device == "cuda" else torch.float32,
34
- "device_map": "auto" if device == "cuda" else "cpu",
35
- # NO flash_attention_2 — only works with GPU + nvcc
36
- }
37
-
38
- _model = Qwen2_5OmniForConditionalGeneration.from_pretrained(MODEL_ID, **load_kwargs)
39
-
40
- # Always disable talker on CPU — saves ~2GB and talker requires GPU
41
- _model.disable_talker()
42
- logger.info("Audio talker disabled (CPU mode — saves memory).")
43
-
44
- _processor = Qwen2_5OmniProcessor.from_pretrained(MODEL_ID)
45
- _model_load_time = time.time() - start
46
- logger.info(f"Model loaded in {_model_load_time:.2f}s on {device}")
47
-
48
- return _model, _processor
49
-
50
-
51
- def get_model() -> Qwen2_5OmniForConditionalGeneration:
52
- if _model is None:
53
- raise RuntimeError("Model not loaded. Call load_model() first.")
54
- return _model
55
-
56
-
57
- def get_processor() -> Qwen2_5OmniProcessor:
58
- if _processor is None:
59
- raise RuntimeError("Processor not loaded. Call load_model() first.")
60
- return _processor
61
-
62
-
63
- def run_inference(
64
- conversation: List[Dict],
65
- return_audio: bool = False,
66
- speaker: str = "Chelsie",
67
- max_new_tokens: int = 256,
68
- temperature: float = 0.7,
69
- use_audio_in_video: bool = True,
70
- ) -> Tuple[str, Optional[bytes], int, int]:
71
- model = get_model()
72
- processor = get_processor()
73
-
74
- # Force return_audio=False on CPU since talker is disabled
75
- if not torch.cuda.is_available():
76
- return_audio = False
77
-
78
- text_template = processor.apply_chat_template(
79
- conversation,
80
- add_generation_prompt=True,
81
- tokenize=False,
82
  )
83
- audios, images, videos = process_mm_info(
84
- conversation, use_audio_in_video=use_audio_in_video
85
- )
86
-
87
- inputs = processor(
88
- text=text_template,
89
- audio=audios,
90
- images=images,
91
- videos=videos,
92
- return_tensors="pt",
93
- padding=True,
94
- use_audio_in_video=use_audio_in_video,
95
- ).to(model.device)
96
-
97
- # Match dtype for CPU (float32)
98
- if not torch.cuda.is_available():
99
- inputs = {k: v.float() if v.dtype == torch.float16 else v
100
- for k, v in inputs.items()}
101
-
102
- prompt_tokens = inputs["input_ids"].shape[-1]
103
-
104
- generate_kwargs: Dict[str, Any] = {
105
- "use_audio_in_video": use_audio_in_video,
106
- "max_new_tokens": max_new_tokens,
107
- "temperature": temperature,
108
- "do_sample": temperature > 0,
109
- "return_audio": False, # Always False — talker disabled on CPU
110
- }
111
-
112
- with torch.inference_mode():
113
- outputs = model.generate(**inputs, **generate_kwargs)
114
-
115
- completion_tokens = outputs.shape[-1] - prompt_tokens
116
- decoded = processor.batch_decode(
117
- outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
118
- )
119
- answer = decoded[0] if decoded else ""
120
-
121
- return answer, None, prompt_tokens, completion_tokens
 
1
+ from langchain_core.prompts import PromptTemplate
2
+ from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
3
+ import os
4
+ from app.schema import OutputResponse
5
+
6
+ # We completely remove `getpass` to prevent blocking the Docker container.
7
+ # HuggingFace secrets should be defined in the HF space environment automatically.
8
+ if not os.getenv("HUGGINGFACEHUB_API_TOKEN"):
9
+ print("WARNING: HUGGINGFACEHUB_API_TOKEN is not set in the environment. "
10
+ "Set this as a secret in your HuggingFace Space or .env file.")
11
+
12
+ def load_model(repo_id: str, max_length: int = 512, temperature: float = 0.5):
13
+ llm = HuggingFaceEndpoint(
14
+ repo_id=repo_id,
15
+ task="text-generation",
16
+ max_new_tokens=max_length,
17
+ do_sample=temperature > 0,
18
+ temperature=temperature if temperature > 0 else None,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  )
20
+ return llm
21
+
22
+ def generate_answer(question: str, llm) -> OutputResponse:
23
+ try:
24
+ prompt = PromptTemplate(
25
+ input_variables=["question"],
26
+ template="""
27
+ You are a helpful assistant that provides concise and accurate answers to user questions.
28
+ Question: {question}
29
+ Answer Format:
30
+ Answer: <Your concise answer here>
31
+ Justification: <Why this query is relevant to the user's request>
32
+ """
33
+ )
34
+ chat_model = ChatHuggingFace(llm=llm)
35
+ structured_llm = chat_model.with_structured_output(OutputResponse)
36
+ result = structured_llm.invoke(prompt.format(question=question))
37
+ return result
38
+ except Exception as e:
39
+ return OutputResponse(answer="Error generating answer", justification=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/routes.py CHANGED
@@ -1,131 +1,24 @@
1
- # routes.py
2
- import logging
3
- import time
4
- from typing import Optional
5
 
6
- from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks, Header
7
- from fastapi.responses import JSONResponse
8
 
9
- from schema import (
10
- RAGQueryRequest, RAGQueryResponse,
11
- IngestRequest, IngestResponse,
12
- HealthResponse, ErrorResponse, UserStatsResponse,
13
- )
14
- from services import (
15
- process_rag_query, ingest_documents,
16
- get_global_stats, get_user_stats,
17
- )
18
- from model import get_model, get_processor
19
-
20
- logger = logging.getLogger(__name__)
21
-
22
- router_rag = APIRouter(prefix="/rag", tags=["RAG"])
23
- router_ingest = APIRouter(prefix="/ingest", tags=["Ingestion"])
24
- router_monitor = APIRouter(prefix="/monitor", tags=["Monitoring"])
25
- router_health = APIRouter(tags=["Health"])
26
-
27
- _start_time = time.time()
28
-
29
-
30
- # ── Auth dependency (replace with JWT/OAuth in production) ───────────────────
31
- async def verify_api_key(x_api_key: Optional[str] = Header(default=None)):
32
- import os
33
- expected = os.getenv("API_KEY", "dev-secret")
34
- if x_api_key != expected:
35
- raise HTTPException(status_code=401, detail="Invalid or missing API key.")
36
- return x_api_key
37
-
38
-
39
- # ── RAG Routes ────────────────────────────────────────────────────────────────
40
- @router_rag.post(
41
- "/query",
42
- response_model=RAGQueryResponse,
43
- responses={400: {"model": ErrorResponse}, 500: {"model": ErrorResponse}},
44
- summary="Multimodal RAG Query",
45
- description="Submit text, image, audio, or video inputs to query the RAG pipeline.",
46
- )
47
- async def rag_query(
48
- request: RAGQueryRequest,
49
- background_tasks: BackgroundTasks,
50
- _: str = Depends(verify_api_key),
51
- ):
52
  try:
53
- response = await process_rag_query(request)
54
- return response
55
- except ValueError as e:
56
- raise HTTPException(status_code=400, detail=str(e))
57
  except Exception as e:
58
- logger.exception(f"RAG query failed: {e}")
59
- raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
60
-
61
 
62
- # ── Ingestion Routes ──────────────────────────────────────────────────────────
63
- @router_ingest.post(
64
- "/documents",
65
- response_model=IngestResponse,
66
- summary="Ingest Multimodal Documents",
67
- )
68
- async def ingest_docs(
69
- request: IngestRequest,
70
- _: str = Depends(verify_api_key),
71
- ):
72
  try:
73
- doc_ids = ingest_documents(
74
- user_id=request.user_id,
75
- documents=request.documents,
76
- doc_ids=request.doc_ids,
77
- )
78
- return IngestResponse(
79
- ingested_count=len(doc_ids),
80
- doc_ids=doc_ids,
81
- message=f"Successfully ingested {len(doc_ids)} documents.",
82
- )
83
  except Exception as e:
84
- logger.exception(f"Ingestion failed: {e}")
85
- raise HTTPException(status_code=500, detail=str(e))
86
-
87
-
88
- # ── Monitoring Routes ─────────────────────────────────────────────────────────
89
- @router_monitor.get(
90
- "/health",
91
- response_model=HealthResponse,
92
- summary="API Health Check",
93
- )
94
- async def health_check():
95
- try:
96
- model = get_model()
97
- device = str(next(model.parameters()).device)
98
- model_loaded = True
99
- except RuntimeError:
100
- device = "unavailable"
101
- model_loaded = False
102
-
103
- stats = get_global_stats()
104
- return HealthResponse(
105
- status="ok" if model_loaded else "degraded",
106
- model_loaded=model_loaded,
107
- device=device,
108
- uptime_seconds=time.time() - _start_time,
109
- total_requests=stats["total_requests"],
110
- total_tokens_processed=stats["total_tokens"],
111
- )
112
-
113
-
114
- @router_monitor.get(
115
- "/users/{user_id}",
116
- response_model=UserStatsResponse,
117
- summary="Get Per-User Usage Stats",
118
- )
119
- async def user_stats(
120
- user_id: str,
121
- _: str = Depends(verify_api_key),
122
- ):
123
- stats = get_user_stats(user_id)
124
- if stats is None:
125
- raise HTTPException(status_code=404, detail=f"User '{user_id}' not found.")
126
- return UserStatsResponse(**stats)
127
-
128
-
129
- @router_health.get("/", include_in_schema=False)
130
- async def root():
131
- return {"message": "Multimodal RAG API is running. Visit /docs for API reference."}
 
1
+ from fastapi import APIRouter, Request
2
+ from app.model import generate_answer
3
+ from app.schema import UserRequest
 
4
 
5
+ router = APIRouter()
 
6
 
7
+ @router.get("/health", tags=["Health"])
8
+ async def health_check():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  try:
10
+ return {"status": "ok"}
 
 
 
11
  except Exception as e:
12
+ return {"status": "error", "message": str(e)}
 
 
13
 
14
+ @router.post("/chat", tags=["Chat"])
15
+ async def chat_endpoint(request_body: UserRequest, request: Request):
 
 
 
 
 
 
 
 
16
  try:
17
+ llm = request.app.state.llm
18
+ if not llm:
19
+ return {"status": "error", "message": "LLM not loaded into application state."}
20
+
21
+ response = generate_answer(request_body.question, llm)
22
+ return response
 
 
 
 
23
  except Exception as e:
24
+ return {"status": "error", "message": str(e)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/schema.py CHANGED
@@ -1,107 +1,9 @@
1
- # schema.py
2
- from pydantic import BaseModel, Field, validator
3
- from typing import Optional, List, Literal, Any, Dict
4
- from enum import Enum
5
- import uuid
6
- from datetime import datetime
7
-
8
-
9
- class ModalityType(str, Enum):
10
- TEXT = "text"
11
- IMAGE = "image"
12
- AUDIO = "audio"
13
- VIDEO = "video"
14
-
15
-
16
- class MediaInput(BaseModel):
17
- modality: ModalityType
18
- content: str = Field(..., description="URL, base64 string, or raw text depending on modality")
19
- use_audio_in_video: Optional[bool] = Field(default=True, description="Use embedded audio in video")
20
-
21
-
22
- class RAGQueryRequest(BaseModel):
23
- query_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
24
- user_id: str = Field(..., description="Unique user identifier")
25
- query_text: Optional[str] = Field(default=None, description="Natural language query")
26
- media_inputs: Optional[List[MediaInput]] = Field(default=[], description="List of multimodal inputs")
27
- top_k: int = Field(default=5, ge=1, le=20, description="Number of RAG context chunks to retrieve")
28
- return_audio: bool = Field(
29
- default=False,
30
- description="Audio output (GPU only — disabled on CPU deployments)")
31
- speaker: Literal["Chelsie", "Ethan"] = Field(default="Chelsie")
32
- max_new_tokens: int = Field(default=256, ge=32, le=512)
33
- temperature: float = Field(default=0.7, ge=0.0, le=2.0)
34
-
35
- @validator("media_inputs", always=True)
36
- def validate_at_least_one_input(cls, v, values):
37
- if not v and not values.get("query_text"):
38
- raise ValueError("At least one of query_text or media_inputs must be provided.")
39
- return v
40
-
41
-
42
- class RAGDocument(BaseModel):
43
- doc_id: str
44
- content: str
45
- modality: ModalityType
46
- score: float = Field(..., ge=0.0, le=1.0)
47
- metadata: Optional[Dict[str, Any]] = {}
48
-
49
-
50
- class TokenUsage(BaseModel):
51
- prompt_tokens: int
52
- completion_tokens: int
53
- total_tokens: int
54
-
55
-
56
- class PerformanceMetrics(BaseModel):
57
- latency_ms: float
58
- retrieval_latency_ms: float
59
- generation_latency_ms: float
60
- throughput_tokens_per_sec: float
61
-
62
-
63
- class RAGQueryResponse(BaseModel):
64
- query_id: str
65
- user_id: str
66
- answer_text: str
67
- retrieved_documents: List[RAGDocument]
68
- audio_base64: Optional[str] = None
69
- token_usage: TokenUsage
70
- performance: PerformanceMetrics
71
- timestamp: datetime = Field(default_factory=datetime.utcnow)
72
-
73
-
74
- class IngestRequest(BaseModel):
75
- user_id: str
76
- documents: List[MediaInput]
77
- doc_ids: Optional[List[str]] = None
78
-
79
-
80
- class IngestResponse(BaseModel):
81
- ingested_count: int
82
- doc_ids: List[str]
83
- message: str
84
-
85
-
86
- class HealthResponse(BaseModel):
87
- status: str
88
- model_loaded: bool
89
- device: str
90
- uptime_seconds: float
91
- total_requests: int
92
- total_tokens_processed: int
93
-
94
-
95
- class ErrorResponse(BaseModel):
96
- error: str
97
- detail: Optional[str] = None
98
- query_id: Optional[str] = None
99
- timestamp: datetime = Field(default_factory=datetime.utcnow)
100
-
101
-
102
- class UserStatsResponse(BaseModel):
103
- user_id: str
104
- total_queries: int
105
- total_tokens: int
106
- avg_latency_ms: float
107
- last_active: Optional[datetime] = None
 
1
+ from pydantic import BaseModel, Field
2
+
3
+ class UserRequest(BaseModel):
4
+ question: str = Field(..., description="The user's question or request.")
5
+ class OutputResponse(BaseModel):
6
+ answer: str = Field(..., description="The answer generated by the model.")
7
+ justification: str = Field(
8
+ ..., description="Why this query is relevant to the user's request."
9
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/services.py DELETED
@@ -1,247 +0,0 @@
1
- # services.py
2
- import uuid
3
- import time
4
- import base64
5
- import logging
6
- import asyncio
7
- from datetime import datetime
8
- from collections import defaultdict
9
- from typing import List, Dict, Optional, Tuple, Any
10
-
11
- import numpy as np
12
- from sentence_transformers import SentenceTransformer
13
- from sklearn.metrics.pairwise import cosine_similarity
14
-
15
- from schema import (
16
- MediaInput, RAGDocument, TokenUsage, PerformanceMetrics,
17
- RAGQueryRequest, RAGQueryResponse, ModalityType
18
- )
19
- from model import run_inference
20
-
21
- logger = logging.getLogger(__name__)
22
-
23
- # ── In-memory vector store (swap for FAISS/Qdrant/Chroma in production) ──────
24
- _doc_store: Dict[str, Dict[str, Any]] = {}
25
- _embeddings_store: Dict[str, np.ndarray] = {}
26
- _embed_model: Optional[SentenceTransformer] = None
27
-
28
- # ── Monitoring state ──────────────────────────────────────────────────────────
29
- _global_stats = {
30
- "total_requests": 0,
31
- "total_tokens": 0,
32
- "start_time": time.time(),
33
- }
34
- _user_stats: Dict[str, Dict] = defaultdict(lambda: {
35
- "total_queries": 0,
36
- "total_tokens": 0,
37
- "latencies": [],
38
- "last_active": None,
39
- })
40
-
41
-
42
- def get_embed_model() -> SentenceTransformer:
43
- global _embed_model
44
- if _embed_model is None:
45
- _embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
46
- logger.info("Embedding model loaded.")
47
- return _embed_model
48
-
49
-
50
- def _embed_text(text: str) -> np.ndarray:
51
- model = get_embed_model()
52
- return model.encode([text], normalize_embeddings=True)[0]
53
-
54
-
55
- def ingest_documents(
56
- user_id: str,
57
- documents: List[MediaInput],
58
- doc_ids: Optional[List[str]] = None,
59
- ) -> List[str]:
60
- ids = []
61
- for i, doc in enumerate(documents):
62
- doc_id = doc_ids[i] if doc_ids and i < len(doc_ids) else str(uuid.uuid4())
63
-
64
- # For non-text modalities, embed a descriptor; full multimodal embeddings
65
- # require a separate vision-language embedding model (e.g., CLIP, ImageBind).
66
- if doc.modality == ModalityType.TEXT:
67
- embed_text = doc.content
68
- else:
69
- embed_text = f"[{doc.modality.value.upper()} resource] {doc.content}"
70
-
71
- embedding = _embed_text(embed_text)
72
- _doc_store[doc_id] = {
73
- "doc_id": doc_id,
74
- "user_id": user_id,
75
- "content": doc.content,
76
- "modality": doc.modality,
77
- "metadata": {"ingested_at": datetime.utcnow().isoformat()},
78
- }
79
- _embeddings_store[doc_id] = embedding
80
- ids.append(doc_id)
81
-
82
- logger.info(f"Ingested {len(ids)} documents for user {user_id}")
83
- return ids
84
-
85
-
86
- def retrieve_documents(
87
- query_text: str,
88
- top_k: int = 5,
89
- ) -> List[RAGDocument]:
90
- if not _embeddings_store:
91
- return []
92
-
93
- query_emb = _embed_text(query_text).reshape(1, -1)
94
- doc_ids = list(_embeddings_store.keys())
95
- doc_embs = np.vstack([_embeddings_store[d] for d in doc_ids])
96
- scores = cosine_similarity(query_emb, doc_embs)[0]
97
-
98
- top_indices = np.argsort(scores)[::-1][:top_k]
99
- results = []
100
- for idx in top_indices:
101
- did = doc_ids[idx]
102
- doc = _doc_store[did]
103
- results.append(RAGDocument(
104
- doc_id=did,
105
- content=doc["content"],
106
- modality=doc["modality"],
107
- score=float(scores[idx]),
108
- metadata=doc.get("metadata", {}),
109
- ))
110
- return results
111
-
112
-
113
- def _build_rag_conversation(
114
- request: RAGQueryRequest,
115
- retrieved_docs: List[RAGDocument],
116
- ) -> List[Dict]:
117
- system_prompt = (
118
- "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, "
119
- "capable of perceiving auditory and visual inputs, as well as generating text and speech."
120
- if request.return_audio
121
- else "You are a helpful multimodal AI assistant with access to retrieved context."
122
- )
123
-
124
- context_str = "\n\n".join(
125
- [f"[Context {i+1} | {d.modality.value}]: {d.content}" for i, d in enumerate(retrieved_docs)]
126
- )
127
-
128
- user_content: List[Dict] = []
129
-
130
- for media in (request.media_inputs or []):
131
- if media.modality == ModalityType.TEXT:
132
- user_content.append({"type": "text", "text": media.content})
133
- elif media.modality == ModalityType.IMAGE:
134
- user_content.append({"type": "image", "image": media.content})
135
- elif media.modality == ModalityType.AUDIO:
136
- user_content.append({"type": "audio", "audio": media.content})
137
- elif media.modality == ModalityType.VIDEO:
138
- user_content.append({"type": "video", "video": media.content})
139
-
140
- final_query = (
141
- f"Retrieved Context:\n{context_str}\n\n"
142
- f"User Query: {request.query_text or 'Analyze the provided media.'}"
143
- )
144
- user_content.append({"type": "text", "text": final_query})
145
-
146
- return [
147
- {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
148
- {"role": "user", "content": user_content},
149
- ]
150
-
151
-
152
- def _update_monitoring(
153
- user_id: str,
154
- total_tokens: int,
155
- latency_ms: float,
156
- ):
157
- _global_stats["total_requests"] += 1
158
- _global_stats["total_tokens"] += total_tokens
159
-
160
- user = _user_stats[user_id]
161
- user["total_queries"] += 1
162
- user["total_tokens"] += total_tokens
163
- user["latencies"].append(latency_ms)
164
- # Keep only last 1000 latencies per user to avoid unbounded memory
165
- if len(user["latencies"]) > 1000:
166
- user["latencies"] = user["latencies"][-1000:]
167
- user["last_active"] = datetime.utcnow()
168
-
169
-
170
- async def process_rag_query(request: RAGQueryRequest) -> RAGQueryResponse:
171
- total_start = time.time()
172
-
173
- # ── Retrieval ─────────────────────────────────────────────────────────────
174
- retrieval_start = time.time()
175
- query_for_retrieval = request.query_text or " ".join(
176
- m.content for m in (request.media_inputs or []) if m.modality == ModalityType.TEXT
177
- ) or "multimodal query"
178
- retrieved_docs = retrieve_documents(query_for_retrieval, top_k=request.top_k)
179
- retrieval_latency_ms = (time.time() - retrieval_start) * 1000
180
-
181
- # ── Build conversation ────────────────────────────────────────────────────
182
- conversation = _build_rag_conversation(request, retrieved_docs)
183
-
184
- use_audio_in_video = any(
185
- m.use_audio_in_video for m in (request.media_inputs or [])
186
- if m.modality == ModalityType.VIDEO
187
- )
188
-
189
- # ── Generation (run in thread pool to avoid blocking event loop) ──────────
190
- gen_start = time.time()
191
- loop = asyncio.get_event_loop()
192
- answer, audio_bytes, prompt_tokens, completion_tokens = await loop.run_in_executor(
193
- None,
194
- lambda: run_inference(
195
- conversation=conversation,
196
- return_audio=request.return_audio,
197
- speaker=request.speaker,
198
- max_new_tokens=request.max_new_tokens,
199
- temperature=request.temperature,
200
- use_audio_in_video=use_audio_in_video,
201
- ),
202
- )
203
- gen_latency_ms = (time.time() - gen_start) * 1000
204
-
205
- total_latency_ms = (time.time() - total_start) * 1000
206
- total_tokens = prompt_tokens + completion_tokens
207
- throughput = (completion_tokens / (gen_latency_ms / 1000)) if gen_latency_ms > 0 else 0
208
-
209
- _update_monitoring(request.user_id, total_tokens, total_latency_ms)
210
-
211
- audio_b64 = base64.b64encode(audio_bytes).decode("utf-8") if audio_bytes else None
212
-
213
- return RAGQueryResponse(
214
- query_id=request.query_id,
215
- user_id=request.user_id,
216
- answer_text=answer,
217
- retrieved_documents=retrieved_docs,
218
- audio_base64=audio_b64,
219
- token_usage=TokenUsage(
220
- prompt_tokens=prompt_tokens,
221
- completion_tokens=completion_tokens,
222
- total_tokens=total_tokens,
223
- ),
224
- performance=PerformanceMetrics(
225
- latency_ms=total_latency_ms,
226
- retrieval_latency_ms=retrieval_latency_ms,
227
- generation_latency_ms=gen_latency_ms,
228
- throughput_tokens_per_sec=throughput,
229
- ),
230
- )
231
-
232
-
233
- def get_global_stats() -> Dict:
234
- return _global_stats.copy()
235
-
236
-
237
- def get_user_stats(user_id: str) -> Optional[Dict]:
238
- if user_id not in _user_stats:
239
- return None
240
- u = _user_stats[user_id]
241
- return {
242
- "user_id": user_id,
243
- "total_queries": u["total_queries"],
244
- "total_tokens": u["total_tokens"],
245
- "avg_latency_ms": float(np.mean(u["latencies"])) if u["latencies"] else 0.0,
246
- "last_active": u["last_active"],
247
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,13 +1,7 @@
1
- fastapi>=0.111.0
2
- uvicorn[standard]>=0.29.0
3
- uvloop>=0.19.0
4
- pydantic>=2.7.0
5
- transformers @ git+https://github.com/huggingface/transformers@v4.51.3-Qwen2.5-Omni-preview
6
- accelerate>=0.30.0
7
- qwen-omni-utils[decord]
8
- sentence-transformers>=3.0.0
9
- scikit-learn>=1.4.0
10
- soundfile>=0.12.1
11
- numpy>=1.26.0
12
- slowapi>=0.1.9
13
- # flash-attn REMOVED — requires nvcc/GPU to compile
 
1
+ fastapi[standard]
2
+ langchain_core
3
+ langgraph
4
+ huggingface_hub
5
+ langchain
6
+ langchain-huggingface
7
+ pydantic