"""Bedtime-story API: generate stories with a local LLM and record user feedback.

Endpoints:
    POST /llm_on_cpu         -- generate a story, persist it, return it as JSON.
    POST /llm_on_cpu_stream  -- generate a story and return it as plain text.
    POST /feedback           -- record a like/dislike for a stored story.
"""
import datetime
import io
import logging
import os
import secrets
import threading

from bson.errors import InvalidId
from bson.objectid import ObjectId
from ctransformers import AutoModelForCausalLM
from fastapi import FastAPI, Request, Depends, HTTPException, status
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from motor.motor_asyncio import AsyncIOMotorClient
from pydantic import BaseModel
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)

# Initialize the limiter (requests are keyed by the client's remote address)
limiter = Limiter(key_func=get_remote_address)

# MongoDB client
# SECURITY: the fallback URI embeds database credentials in source control.
# Set MONGODB_URI in the environment, rotate the embedded password, and then
# remove the fallback. It is kept only so existing deployments keep working.
MONGODB_URI = os.environ.get(
    "MONGODB_URI",
    'mongodb+srv://vanaraai:0YzPmeArBjsIVKqd@jobcluster.thp8ohx.mongodb.net/?retryWrites=true&w=majority',
)
client = AsyncIOMotorClient(MONGODB_URI)
db = client['starlight_tales']

# HTTP Basic credentials for the protected endpoint.
# SECURITY: the defaults are the original hard-coded values; override them via
# API_USERNAME / API_PASSWORD in any real deployment.
API_USERNAME = os.environ.get("API_USERNAME", "admin")
API_PASSWORD = os.environ.get("API_PASSWORD", "test123")

# Load the model
llm = AutoModelForCausalLM.from_pretrained(
    "tinyllama-bedtimestories-f16.gguf",
    model_type='llama',
    max_new_tokens=2000,
    threads=3,
    context_length=4096,
)

# Fixed chat-style prompt sent to the model for every request.
# The original literal contained "<\/s>": an invalid escape sequence that
# Python keeps as a literal backslash (and warns about), so the model received
# "<\/s>" rather than the "</s>" end-of-turn marker.
STORY_PROMPT = "<|user|>\nA bedtime tiny story for the children</s>\n<|assistant|>\n"

# Generation used to run directly on the event loop, which implicitly
# serialized it. Now that it runs on worker threads, this lock keeps calls
# serialized; the model object is not known to be safe for concurrent calls.
_llm_lock = threading.Lock()

# FastAPI app
app = FastAPI()

# Add an event handler for rate limit exceeded
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

security = HTTPBasic()


class Story(BaseModel):
    # Text of a generated story.
    story_content: str


class Feedback(BaseModel):
    # Hex string form of the story's MongoDB ObjectId.
    story_id: str
    # True when the user liked the story.
    liked: bool


class Validation(BaseModel):
    # Request body field; currently accepted but not used by the handlers.
    prompt: str


def _utc_now() -> datetime.datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.datetime.now(datetime.timezone.utc)


def _generate_story() -> str:
    """Run the (blocking) model on the fixed story prompt and return its text."""
    with _llm_lock:
        return llm(STORY_PROMPT)


def _constant_time_equal(supplied: str, expected: str) -> bool:
    """Compare two strings in constant time.

    ``secrets.compare_digest`` rejects ``str`` arguments containing non-ASCII
    characters with ``TypeError``; comparing the UTF-8 bytes avoids turning
    such input into an unhandled server error.
    """
    return secrets.compare_digest(
        supplied.encode("utf-8"), expected.encode("utf-8")
    )


def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)):
    """
    Validate HTTP Basic credentials for protected endpoints.

    Args:
        credentials (HTTPBasicCredentials): Credentials parsed from the
            request's ``Authorization`` header.

    Returns:
        str: The authenticated username.

    Raises:
        HTTPException: 401 when the username or password does not match.
    """
    # Evaluate both comparisons unconditionally so timing does not reveal
    # which of the two fields was wrong.
    correct_username = _constant_time_equal(credentials.username, API_USERNAME)
    correct_password = _constant_time_equal(credentials.password, API_PASSWORD)
    if not (correct_username and correct_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Basic"},
        )
    return credentials.username


@app.post("/llm_on_cpu", dependencies=[Depends(verify_credentials)])
@limiter.limit("10/minute")
async def stream(request: Request, item: Validation):
    """
    Generate a bedtime story using the LLM model and store it.

    Args:
        request (Request): The incoming request (required by the rate limiter).
        item (Validation): The validated request body. Its ``prompt`` is
            currently ignored; a fixed prompt is always used.

    Returns:
        dict: ``_id`` (string id of the stored story document) and
        ``story_content`` (the generated text).
    """
    logging.info('Generating story')
    # The model call is blocking; run it off the event loop so other requests
    # are still served while a story is being generated.
    story_content = await run_in_threadpool(_generate_story)
    story_dict = {
        "story_content": story_content,
        "generated_at": _utc_now(),
    }
    result = await db.stories.insert_one(story_dict)
    logging.info('Story generated with id %s', result.inserted_id)
    return {"_id": str(result.inserted_id), "story_content": story_content}


# NOTE(review): unlike /llm_on_cpu, this endpoint has neither authentication
# nor a rate limit. Left unchanged to avoid breaking existing callers --
# confirm whether that is intentional.
@app.post("/llm_on_cpu_stream")
async def stream_stories(item: Validation):
    """
    Return a generated bedtime story as a plain-text streaming response.

    The story is generated in full first and then sent from an in-memory
    buffer, so the client does not receive tokens as they are produced.

    Args:
        item (Validation): The validated request body. Its ``prompt`` is
            currently ignored; a fixed prompt is always used.

    Returns:
        StreamingResponse: The generated story with media type ``text/plain``.
    """
    story_content = await run_in_threadpool(_generate_story)
    # TODO: an earlier experiment (previously left here as commented-out code)
    # iterated ``llm(prompt, stream=True)`` to get incremental output; wiring
    # that into the response would make this endpoint stream token by token.
    return StreamingResponse(io.StringIO(story_content), media_type="text/plain")


@app.post("/feedback")
async def create_feedback(feedback: Feedback):
    """
    Create feedback and insert it into the database.

    Args:
        feedback (Feedback): The feedback object to be created.

    Returns:
        dict: A dictionary containing the ID of the inserted feedback.

    Raises:
        HTTPException: 400 when ``story_id`` is not a valid ObjectId string.
    """
    logging.info('Receiving feedback')
    try:
        story_object_id = ObjectId(feedback.story_id)
    except InvalidId as exc:
        # A malformed id is a client error, not a server failure.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="story_id is not a valid ObjectId",
        ) from exc

    feedback_dict = feedback.dict()
    feedback_dict["feedback_at"] = _utc_now()
    feedback_dict["story_id"] = story_object_id
    result = await db.feedback.insert_one(feedback_dict)
    logging.info('Feedback received with id %s', result.inserted_id)
    return {"_id": str(result.inserted_id)}