Spaces:
Sleeping
Sleeping
| # app/main.py | |
| import os | |
| import json | |
| import logging | |
| import asyncio | |
| from typing import Optional | |
| import uvicorn | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel | |
| from .model import PlutusModel, SummaryModel | |
| from .recommender import Recommender | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("plutus.api")

# Directory holding the recommender's index (.faiss) and metadata files;
# follows HF_HOME when it is set.
_CACHE_DIR = os.getenv("HF_HOME", "/home/user/app")
# Source JSON for the recommender. Override it with the RECOMMEND_JSON
# environment variable. The default is a relative path, so it resolves
# against the process working directory.
DEFAULT_RECOMMEND_JSON = os.getenv("RECOMMEND_JSON", "recommend.json")
RECOMMEND_INDEX_PATH = os.path.join(_CACHE_DIR, "plutus_recommend_index.faiss")
RECOMMEND_META_PATH = os.path.join(_CACHE_DIR, "plutus_recommend_meta.json")
class GenerateCache:
    """Holds the inputs and output of the most recent completed generation.

    Every field starts as ``None`` (at class level, so a fresh instance reads
    ``None`` too). Assigning on an instance shadows the class-level default.
    """

    # Request fields echoed from the last generation request.
    last_query: Optional[str]
    last_topic: Optional[str]
    last_personality: Optional[str]
    last_level: Optional[str]
    # Full generated text from the last finished stream.
    last_output: Optional[str]

    last_query = last_topic = last_personality = last_level = last_output = None
# Single process-wide record of the latest generation; the recommend and
# summary handlers read from it.
GEN_CACHE = GenerateCache()

# Heavy resources are built once at import time and shared by all requests.
logger.info("Loading shared Plutus LLM and recommender...")
plutus_model, summary_model = PlutusModel(), SummaryModel()
recommender = Recommender(
    recommend_json_path=DEFAULT_RECOMMEND_JSON,
    index_path=RECOMMEND_INDEX_PATH,
    meta_path=RECOMMEND_META_PATH,
)

app = FastAPI(title="Plutus Learner API")
class GenerateRequest(BaseModel):
    """Request body for a streamed generation.

    The four text fields build the prompt. The sampling fields are forwarded
    unchanged to ``plutus_model.generate``.
    """

    # Prompt inputs, passed positionally to plutus_model.create_prompt.
    personality: str
    level: str
    topic: str
    query: str

    # Sampling controls, passed as keyword arguments to plutus_model.generate.
    max_new_tokens: int = 700
    temperature: float = 0.5
    top_p: float = 0.9
class RecommendRequest(BaseModel):
    """Request body for recommendations based on the last cached query."""

    # Number of results requested from recommender.recommend_for_query.
    top_k: int = 5
class SummaryRequest(BaseModel):
    """Request body for summarising the last cached generation."""

    # Number of recommendations fetched and handed to the summary model.
    top_k: int = 5
# NOTE(review): route path inferred from the handler name; confirm with clients.
@app.get("/health")
async def health():
    """Report service status together with ``plutus_model.device``."""
    device = plutus_model.device
    # NOTE(review): the type of `device` is not visible here; it may be a
    # non-JSON type such as torch.device. Stringify anything that is not
    # already a str or None so the response always serializes.
    if device is not None and not isinstance(device, str):
        device = str(device)
    return {"status": "ok", "device": device}
@app.post("/generate")
async def generate(req: GenerateRequest):
    """Stream a teaching answer for ``req`` as Server-Sent Events.

    Each chunk from ``plutus_model.generate`` is sent as
    ``data: {"text": <chunk>}``, and the stream ends with ``data: [DONE]``.
    Once the model stream has been fully consumed, the request fields and the
    complete output are stored in ``GEN_CACHE`` for the recommend and summary
    handlers.
    """
    prompt = plutus_model.create_prompt(
        req.personality,
        req.level,
        req.topic,
        req.query,
    )

    async def event_generator():
        loop = asyncio.get_running_loop()
        chunks = iter(
            plutus_model.generate(
                prompt,
                max_new_tokens=req.max_new_tokens,
                temperature=req.temperature,
                top_p=req.top_p,
            )
        )
        finished = object()  # sentinel next() returns once the stream is exhausted
        parts = []
        while True:
            # Advance the blocking model iterator in a worker thread. Iterating
            # it directly here would stall the event loop, and with it every
            # other request, for the whole generation.
            # NOTE(review): assumes plutus_model.generate tolerates being
            # resumed from a worker thread; confirm.
            chunk = await loop.run_in_executor(None, next, chunks, finished)
            if chunk is finished:
                break
            parts.append(chunk)
            yield f"data: {json.dumps({'text': chunk})}\n\n"

        # Cache the final result. These assignments run on the event-loop
        # thread with no await between them, so other handlers never observe a
        # half-updated cache.
        # NOTE(review): GEN_CACHE is shared by every client of this process.
        GEN_CACHE.last_query = req.query
        GEN_CACHE.last_topic = req.topic
        GEN_CACHE.last_personality = req.personality
        GEN_CACHE.last_level = req.level
        # Chunks are newline-joined. NOTE(review): if generate yields token
        # fragments rather than lines, this inserts extra line breaks; confirm.
        GEN_CACHE.last_output = "\n".join(parts).strip()
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
    )
# NOTE(review): route path inferred from the handler name; confirm with clients.
@app.post("/recommend")
async def recommend(req: RecommendRequest):
    """Return ``req.top_k`` recommendations for the last cached query.

    Results come from ``recommender.recommend_for_query``, with the cached
    topic passed as ``topic_boost``. Each result is reduced to its ``topic``,
    ``type`` and ``url`` keys.

    Raises:
        HTTPException: 400 if no query has been cached yet.
    """
    query = GEN_CACHE.last_query
    if query is None:
        raise HTTPException(400, "No query found. Call /generate first.")
    results = recommender.recommend_for_query(
        query=query,
        top_k=req.top_k,
        topic_boost=GEN_CACHE.last_topic,
    )
    return {
        "query": query,
        "results": [
            {"topic": r["topic"], "type": r["type"], "url": r["url"]}
            for r in results
        ],
    }
# NOTE(review): route path inferred from the handler name; confirm with clients.
@app.post("/summary")
async def summary(req: SummaryRequest):
    """Stream a summary of the last cached generation as Server-Sent Events.

    First fetches ``req.top_k`` recommendations for the cached query, boosted
    by the cached topic, and hands them to the summary model. Each summary
    chunk is sent as ``data: {"summary": <chunk>}``, and the stream ends with
    ``data: [DONE]``.

    Raises:
        HTTPException: 400 if no generation output has been cached yet.
    """
    if GEN_CACHE.last_output is None:
        raise HTTPException(400, "No generate output found. Call /generate first.")

    # Snapshot the cache now. The stream body runs only after this handler
    # returns, and a generation finishing in between would otherwise pair
    # these recommendations with a different teaching text.
    query = GEN_CACHE.last_query
    topic = GEN_CACHE.last_topic
    level = GEN_CACHE.last_level
    teaching = GEN_CACHE.last_output

    # NOTE(review): synchronous call on the event loop; offload it if
    # recommend_for_query turns out to be slow.
    recs = recommender.recommend_for_query(
        query=query,
        top_k=req.top_k,
        topic_boost=topic,
    )

    async def event_generator():
        loop = asyncio.get_running_loop()
        chunks = iter(
            summary_model.summarize_text(
                full_teaching=teaching,
                topic=topic,
                level=level,
                recommended=recs,
                max_new_tokens=300,
            )
        )
        finished = object()  # sentinel next() returns once the stream is exhausted
        while True:
            # Advance the blocking model iterator in a worker thread so other
            # requests keep being served while the summary is produced.
            chunk = await loop.run_in_executor(None, next, chunks, finished)
            if chunk is finished:
                break
            yield f"data: {json.dumps({'summary': chunk})}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
    )
# NOTE(review): route path inferred from the handler name; confirm with clients.
@app.post("/build_index")
async def build_index(force: bool = False):
    """Build the recommender index and report how many entries it holds.

    Args:
        force: forwarded to ``recommender.build_index``. FastAPI exposes this
            plain ``bool`` parameter as a query parameter.

    Returns:
        ``{"indexed": len(recommender.meta)}``.
    """
    # NOTE(review): the build runs synchronously on the event loop, so other
    # requests wait until it finishes. That presumably also keeps recommend
    # calls from seeing a half-built index, so confirm before moving it to a
    # thread. Nothing here restricts who may trigger a rebuild.
    recommender.build_index(force=force)
    return {"indexed": len(recommender.meta)}
if __name__ == "__main__":
    # Direct-run entry point. The port defaults to 7860 and can be overridden
    # through the PORT environment variable.
    uvicorn.run(
        "app.main:app",
        host="0.0.0.0",
        port=int(os.getenv("PORT", "7860")),
    )