"""SkyAI Dataset Router + Memory Injection System.

A small FastAPI service that:

* loads a registered Hugging Face dataset and returns rows matching a
  case-insensitive substring query,
* keeps a per-user, in-process memory list, and
* offers a combined pipeline endpoint that queries a dataset and records the
  query in the user's memory.
"""

import json
import os
from typing import Any, Dict, List, Optional

from datasets import load_dataset
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# =========================
# CONFIG
# =========================

HF_TOKEN = os.getenv("HF_TOKEN", None)

# Short name -> Hugging Face dataset repository id.
DATASETS = {
    "matrix_game": "Skywork/Matrix-Game-3.0",
    "framework_storage": "sky-meilin/SkyAiFramework-storage",
}

# Optional: later an S3 / HF storage / custom bucket connector.
MEMORY_BUCKET = "sky-meilin/SkyAiFramework-storage"

API_VERSION = "1.0.1_health_status_fix_2026_06_09"

# Shared by the app title and the informational endpoints so they cannot drift.
SERVICE_NAME = "SkyAI Dataset Router + Memory Injection System"

# Routes advertised by "/" and "/status".
ENDPOINTS = (
    "/",
    "/health",
    "/status",
    "/dataset/query",
    "/memory/inject",
    "/memory/query",
    "/pipeline/run",
)

# Number of most recent memories returned when no query is given.
RECENT_MEMORY_COUNT = 10

app = FastAPI(
    title=SERVICE_NAME,
    version=API_VERSION,
)

# =========================
# MEMORY LAYER
# Simple in-memory first stage.
# Later replaceable by DB / Vector DB / HF Dataset storage.
# =========================

# user_id -> list of stored memory entries, oldest first.
MEMORY_STORE: Dict[str, List[Dict[str, Any]]] = {}


# =========================
# MODELS
# =========================

class QueryRequest(BaseModel):
    """Request body for POST /dataset/query."""

    dataset: str
    query: Optional[str] = None
    limit: int = 5


class MemoryInjectRequest(BaseModel):
    """Request body for POST /memory/inject."""

    user_id: str
    content: Dict[str, Any]


class MemoryQueryRequest(BaseModel):
    """Request body for POST /memory/query."""

    user_id: str
    query: Optional[str] = None


class PipelineRequest(BaseModel):
    """Request body for POST /pipeline/run; every field has a default."""

    user_id: Optional[str] = "default"
    dataset: Optional[str] = "matrix_game"
    query: Optional[str] = None
    limit: Optional[int] = 5


# =========================
# SHARED HELPERS
# =========================

def _search_blob(item: Any) -> str:
    """Return *item* as lower-cased JSON text for substring matching.

    ``default=str`` is deliberate: the text is only searched, never parsed
    back, so values that are not JSON-native are stringified instead of
    raising ``TypeError`` and failing the whole query.
    """
    return json.dumps(item, ensure_ascii=False, default=str).lower()


# =========================
# DATASET ROUTER
# =========================

def load_selected_dataset(name: str):
    """Load the dataset registered under *name* in ``DATASETS``.

    Args:
        name: Short name, i.e. a key of ``DATASETS``.

    Returns:
        Whatever ``load_dataset`` returns for the mapped repository id.

    Raises:
        ValueError: If *name* is not registered.
        RuntimeError: If loading fails; the original error is chained.
    """
    if name not in DATASETS:
        raise ValueError(f"Dataset '{name}' not registered.")

    try:
        return load_dataset(DATASETS[name], token=HF_TOKEN)
    except Exception as e:
        raise RuntimeError(f"Dataset load failed: {str(e)}") from e


def route_dataset(name: str, query: Optional[str], limit: int) -> List[Any]:
    """Return up to *limit* rows of the dataset's first split matching *query*.

    Simple filter logic. Later replaceable with embeddings / vector search.

    Args:
        name: Short dataset name, a key of ``DATASETS``.
        query: Case-insensitive substring looked up in each row's JSON text.
            ``None`` or an empty string selects every row.
        limit: Maximum number of rows to return; values <= 0 yield ``[]``.

    Returns:
        The matching rows in dataset order.

    Raises:
        ValueError: If *name* is not registered.
        RuntimeError: If the dataset cannot be loaded or has no splits.
    """
    # Load first so an unknown name is still reported for any limit.
    dataset = load_selected_dataset(name)

    if limit <= 0:
        return []

    # Standard: first split.
    split_name = next(iter(dataset.keys()), None)
    if split_name is None:
        raise RuntimeError(f"Dataset '{name}' has no splits.")
    data = dataset[split_name]

    # Lower-case the query once instead of once per row.
    needle = query.lower() if query else None

    results: List[Any] = []
    for item in data:
        if needle is None or needle in _search_blob(item):
            results.append(item)
            if len(results) >= limit:
                break

    return results


# =========================
# MEMORY SYSTEM
# =========================

def inject_memory(user_id: str, context: Dict[str, Any]) -> None:
    """Append *context* to the memory list of *user_id*, creating it if needed."""
    MEMORY_STORE.setdefault(user_id, []).append(context)


def retrieve_memory(
    user_id: str, query: Optional[str] = None
) -> List[Dict[str, Any]]:
    """Return stored memories for *user_id*.

    Args:
        user_id: Key into ``MEMORY_STORE``.
        query: Optional case-insensitive substring matched against each
            memory's JSON text.

    Returns:
        ``[]`` for an unknown user. Without a query, the most recent
        ``RECENT_MEMORY_COUNT`` memories. With a query, every matching
        memory, oldest first.
    """
    memories = MEMORY_STORE.get(user_id, [])

    if not query:
        return memories[-RECENT_MEMORY_COUNT:]

    needle = query.lower()
    return [entry for entry in memories if needle in _search_blob(entry)]


# =========================
# API ENDPOINTS
# =========================

@app.get("/")
def root():
    """Describe the service: version, datasets, memory bucket and routes."""
    return {
        "status": "SkyAI Dataset Router active",
        "service": SERVICE_NAME,
        "version": API_VERSION,
        "datasets": list(DATASETS.keys()),
        "memory_bucket": MEMORY_BUCKET,
        "endpoints": list(ENDPOINTS),
    }


@app.get("/health")
def health():
    """Report a static "ok" status with the service name and version."""
    return {
        "status": "ok",
        "service": SERVICE_NAME,
        "version": API_VERSION,
    }


@app.get("/status")
def api_status():
    """Like "/", plus the number of users currently held in memory."""
    return {
        "status": "active",
        "service": SERVICE_NAME,
        "version": API_VERSION,
        "datasets": list(DATASETS.keys()),
        "memory_bucket": MEMORY_BUCKET,
        "memory_users": len(MEMORY_STORE),
        "endpoints": list(ENDPOINTS),
    }


@app.post("/dataset/query")
def dataset_query(req: QueryRequest):
    """Run ``route_dataset`` for the request and return the matching rows.

    Raises:
        HTTPException: 500 with the error text as detail on any failure.
    """
    # NOTE(review): an unknown dataset name also surfaces as 500 here. A 4xx
    # would describe it better, but changing it alters the API contract.
    try:
        results = route_dataset(req.dataset, req.query, req.limit)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "success": True,
        "dataset": req.dataset,
        "query": req.query,
        "count": len(results),
        "results": results,
    }


@app.post("/memory/inject")
def memory_inject(req: MemoryInjectRequest):
    """Store ``req.content`` for ``req.user_id`` and report the new total.

    Raises:
        HTTPException: 500 with the error text as detail on any failure.
    """
    try:
        inject_memory(req.user_id, req.content)
        total = len(MEMORY_STORE.get(req.user_id, []))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "success": True,
        "status": "memory_stored",
        "user_id": req.user_id,
        "total": total,
    }


@app.post("/memory/query")
def memory_query(req: MemoryQueryRequest):
    """Return the memories of ``req.user_id``, optionally filtered by query.

    Raises:
        HTTPException: 500 with the error text as detail on any failure.
    """
    try:
        results = retrieve_memory(req.user_id, req.query)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "success": True,
        "user_id": req.user_id,
        "query": req.query,
        "count": len(results),
        "memories": results,
    }


# =========================
# COMBINED PIPELINE
# Dataset + Memory Injection
# =========================

@app.post("/pipeline/run")
def pipeline(req: PipelineRequest):
    """Query a dataset, log the query to the user's memory, return both.

    Falsy request fields fall back to the defaults "default", "matrix_game"
    and 5. The memory entry stores the query, the dataset name and the number
    of rows found, not the rows themselves.

    Raises:
        HTTPException: 500 with the error text as detail on any failure.
    """
    try:
        user_id = req.user_id or "default"
        dataset_name = req.dataset or "matrix_game"
        query = req.query
        limit = req.limit or 5

        dataset_results = route_dataset(dataset_name, query, limit=limit)

        inject_memory(
            user_id,
            {
                "query": query,
                "dataset": dataset_name,
                "results_count": len(dataset_results),
                "pipeline": "dataset_memory_injection",
            },
        )

        # Fetched after the injection, so it includes the entry just added.
        memory = retrieve_memory(user_id)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "success": True,
        "dataset": dataset_name,
        "query": query,
        "dataset_results": dataset_results,
        "memory_context": memory,
    }