Spaces:
Runtime error
Runtime error
import asyncio
import datetime
import io
import logging
import os
import secrets
import threading

from bson.objectid import ObjectId
from ctransformers import AutoModelForCausalLM
from fastapi import FastAPI, Request, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from motor.motor_asyncio import AsyncIOMotorClient
from pydantic import BaseModel
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
# Configure root logging for the whole service.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Rate limiter keyed on the client's remote address.
limiter = Limiter(key_func=get_remote_address)
# MongoDB client. The connection string embeds credentials, so it is read
# from the MONGODB_URI environment variable instead of being committed to
# source; startup fails loudly if it is missing rather than silently
# connecting somewhere unintended.
# NOTE(review): the URI that used to be hard-coded here exposed a database
# password -- rotate that credential. The old URI also carried the options
# "?retryWrites=true&w=majority"; include them in MONGODB_URI if still wanted.
_mongodb_uri = os.environ.get("MONGODB_URI")
if not _mongodb_uri:
    raise RuntimeError("MONGODB_URI environment variable is not set")
client = AsyncIOMotorClient(_mongodb_uri)
db = client['starlight_tales']
# Load the GGUF model once at import time; the same `llm` object is shared
# by every story-generation endpoint in this module.
llm = AutoModelForCausalLM.from_pretrained(
    "tinyllama-bedtimestories-f16.gguf",
    model_type="llama",
    context_length=4096,   # maximum context length, in tokens
    max_new_tokens=2000,   # cap on tokens generated per call
    threads=3,             # CPU threads used for inference
)
# HTTP Basic authentication scheme; injected into verify_credentials via Depends.
security = HTTPBasic()

# FastAPI application. slowapi expects the limiter on app.state, and its
# default handler is registered for RateLimitExceeded.
app = FastAPI()
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
class Story(BaseModel):
    """Pydantic model carrying the text of a single story."""

    # NOTE(review): not referenced by the handlers visible in this file.
    story_content: str
class Feedback(BaseModel):
    """Request body for a reader's vote on a previously generated story."""

    # String form of the story's ObjectId; create_feedback converts it.
    story_id: str
    # Whether the reader liked the story.
    liked: bool
class Validation(BaseModel):
    """Request body accepted by the story-generation endpoints."""

    # NOTE(review): stream/stream_stories never read this field; they always
    # send a fixed prompt to the model.
    prompt: str
def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)):
    """
    FastAPI dependency that checks HTTP Basic credentials.

    The expected username/password are read from the BASIC_AUTH_USERNAME and
    BASIC_AUTH_PASSWORD environment variables, defaulting to "admin" /
    "test123" when they are not set.

    Args:
        credentials (HTTPBasicCredentials): Credentials supplied through the
            ``security`` HTTPBasic scheme.
    Returns:
        str: The authenticated username.
    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Basic`` header when the
            username or password does not match.
    """
    # NOTE(review): the fallback password is weak; set the env vars in production.
    expected_username = os.environ.get("BASIC_AUTH_USERNAME", "admin")
    expected_password = os.environ.get("BASIC_AUTH_PASSWORD", "test123")
    # compare_digest accepts str only when it is ASCII-only, so compare UTF-8
    # bytes instead: a non-ASCII username/password then yields 401 rather than
    # a TypeError (HTTP 500). Both comparisons always run so a wrong username
    # and a wrong password take the same code path.
    correct_username = secrets.compare_digest(
        credentials.username.encode("utf-8"), expected_username.encode("utf-8")
    )
    correct_password = secrets.compare_digest(
        credentials.password.encode("utf-8"), expected_password.encode("utf-8")
    )
    if not (correct_username and correct_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Basic"},
        )
    return credentials.username
async def stream(request: Request, item: Validation):
    """
    Generate one bedtime story with the LLM and store it in MongoDB.

    The model always receives the fixed prompt below.
    NOTE(review): ``item.prompt`` is accepted but not used.

    Args:
        request (Request): The incoming request (not read in the body).
        item (Validation): The validated request body.
    Returns:
        dict: ``{"_id": <id of the inserted story, as str>,
        "story_content": <generated text>}``.
    """
    logging.info('Generating story')
    # The user turn ends with the "</s>" marker. Writing it as "<\/s>" is an
    # invalid escape sequence that left a literal backslash in the prompt.
    prompt = "<|user|>\nA bedtime tiny story for the children</s>\n<|assistant|>\n"
    # Generation is CPU-bound and slow; run it in a worker thread so the event
    # loop keeps serving other requests while the model is busy.
    loop = asyncio.get_running_loop()
    story_content = await loop.run_in_executor(None, _generate_story, prompt)
    story_dict = {
        "story_content": story_content,
        "generated_at": datetime.datetime.now(datetime.timezone.utc),
    }
    result = await db.stories.insert_one(story_dict)
    logging.info('Story generated with id %s', result.inserted_id)
    return {"_id": str(result.inserted_id), "story_content": story_content}


async def stream_stories(item: Validation):
    """
    Stream a bedtime story to the client while the model generates it.

    The fixed prompt is passed to the LLM with ``stream=True`` and each text
    chunk is forwarded as soon as it is produced, rather than generating the
    whole story before responding.

    Parameters:
        item (Validation): The validated request body.
            NOTE(review): ``item.prompt`` is accepted but not used.
    Returns:
        StreamingResponse: A ``text/plain`` response whose body is the story.
    """
    prompt = "<|user|>\nA bedtime tiny story for the children</s>\n<|assistant|>\n"
    # _stream_story is a plain (sync) generator; Starlette iterates sync
    # iterators in a worker thread, so generation does not block the event loop.
    return StreamingResponse(_stream_story(prompt), media_type="text/plain")


# Serialises all use of the shared `llm` object. Generation now runs in worker
# threads (see stream / stream_stories above); the ctransformers model is
# presumably not safe to drive from two threads at once -- confirm before
# removing this lock.
_llm_lock = threading.Lock()


def _generate_story(prompt):
    """Run one complete, non-streaming generation while holding _llm_lock."""
    with _llm_lock:
        return llm(prompt)


def _stream_story(prompt):
    """
    Yield generated text chunks for ``prompt`` while holding _llm_lock.

    The lock is released when generation finishes or when the generator is
    closed (presumably what happens when the client disconnects and the
    response iterator is dropped -- confirm).
    """
    with _llm_lock:
        yield from llm(prompt, stream=True)
async def create_feedback(feedback: Feedback):
    """
    Store a reader's vote on a story in the ``feedback`` collection.

    Parameters:
    - feedback (Feedback): The feedback to store; ``story_id`` must be the
      24-character hex form of an ObjectId.
    Returns:
    - dict: ``{"_id": <id of the inserted feedback document, as str>}``.
    Raises:
    - HTTPException: 400 if ``story_id`` is not a valid ObjectId.
    """
    logging.info('Receiving feedback')
    # Reject malformed ids up front: ObjectId(...) would otherwise raise
    # InvalidId and reach the client as a 500 Internal Server Error.
    if not ObjectId.is_valid(feedback.story_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid story_id",
        )
    feedback_dict = feedback.dict()
    feedback_dict["feedback_at"] = datetime.datetime.now(datetime.timezone.utc)
    # Store the reference as an ObjectId, presumably so it can be matched
    # against the _id of documents in the stories collection.
    feedback_dict["story_id"] = ObjectId(feedback.story_id)
    result = await db.feedback.insert_one(feedback_dict)
    logging.info('Feedback received with id %s', result.inserted_id)
    return {"_id": str(result.inserted_id)}