# tinyllama-test / main.py
# shashank-indukuri's picture
# implement logging
# 4d2d90e verified
import datetime
import io
import logging
import os
import secrets

from bson.objectid import ObjectId
from ctransformers import AutoModelForCausalLM
from fastapi import FastAPI, Request, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from motor.motor_asyncio import AsyncIOMotorClient
from pydantic import BaseModel
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)

# Initialize the limiter; requests are bucketed by the key returned from
# get_remote_address.
limiter = Limiter(key_func=get_remote_address)

# MongoDB client
# SECURITY: the connection string (including credentials) used to be
# hard-coded only. It can now be supplied through the MONGODB_URI environment
# variable. The literal below is kept solely as a backward-compatible
# fallback; the embedded password should be rotated and this default removed.
MONGODB_URI = os.environ.get(
    "MONGODB_URI",
    'mongodb+srv://vanaraai:0YzPmeArBjsIVKqd@jobcluster.thp8ohx.mongodb.net/?retryWrites=true&w=majority',
)
client = AsyncIOMotorClient(MONGODB_URI)
db = client['starlight_tales']

# Load the model once at import time; the endpoints below share this instance.
llm = AutoModelForCausalLM.from_pretrained(
    "tinyllama-bedtimestories-f16.gguf",
    model_type='llama',
    max_new_tokens=2000,
    threads=3,
    context_length=4096,
)

# FastAPI app
app = FastAPI()

# Attach the limiter to the app state and register the handler that is
# invoked when a RateLimitExceeded error is raised.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# HTTP Basic auth scheme used as a dependency by verify_credentials.
security = HTTPBasic()
class Story(BaseModel):
    """Pydantic model carrying the text of a single story.

    No endpoint in the visible part of this file references this model.
    """

    # Text of the story.
    story_content: str
class Feedback(BaseModel):
    """Request body accepted by the POST /feedback endpoint."""

    # String form of the story's id; create_feedback converts it to an
    # ObjectId before storing the feedback document.
    story_id: str

    # Like/dislike flag supplied by the client.
    liked: bool
class Validation(BaseModel):
    """Request body accepted by the story-generation endpoints."""

    # Required by the schema. The endpoints in this file build their own
    # fixed prompt and do not read this value.
    prompt: str
def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)) -> str:
    """Check HTTP Basic credentials and return the authenticated username.

    The expected username and password are read from the
    ``BASIC_AUTH_USERNAME`` and ``BASIC_AUTH_PASSWORD`` environment variables
    and fall back to the previous built-in values when those are unset.

    Args:
        credentials: Username/password pair extracted by the ``security``
            HTTP Basic scheme.

    Returns:
        The username from the supplied credentials.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Basic`` header when the
            username or the password does not match.
    """
    # SECURITY: the defaults preserve the original hard-coded credentials for
    # backward compatibility; set the environment variables in any real
    # deployment.
    expected_username = os.environ.get("BASIC_AUTH_USERNAME", "admin")
    expected_password = os.environ.get("BASIC_AUTH_PASSWORD", "test123")

    # compare_digest() only accepts str operands that are pure ASCII and
    # raises TypeError otherwise, which would surface as a 500 for a client
    # sending non-ASCII credentials. Comparing UTF-8 bytes avoids that.
    # Both comparisons are always evaluated (no short-circuit) before the
    # combined check below.
    correct_username = secrets.compare_digest(
        credentials.username.encode("utf-8"),
        expected_username.encode("utf-8"),
    )
    correct_password = secrets.compare_digest(
        credentials.password.encode("utf-8"),
        expected_password.encode("utf-8"),
    )

    if not (correct_username and correct_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Basic"},
        )
    return credentials.username
@app.post("/llm_on_cpu", dependencies=[Depends(verify_credentials)])
@limiter.limit("10/minute")
async def stream(request: Request, item: Validation):
    """Generate a bedtime story with the LLM and store it in MongoDB.

    The endpoint requires valid HTTP Basic credentials and is limited to
    10 requests per minute.

    Args:
        request (Request): The incoming request. Not read by the body; kept
            in the signature for the rate-limit decorator.
        item (Validation): The validated request body. Its ``prompt`` is not
            used; a fixed prompt is sent to the model.

    Returns:
        dict: ``{"_id": <inserted document id as str>, "story_content": <generated text>}``.
    """
    logging.info('Generating story')

    # The end-of-turn marker must be the literal "</s>". The previous "<\/s>"
    # was an invalid escape sequence that put a backslash into the prompt.
    prompt = "<|user|>\nA bedtime tiny story for the children</s>\n<|assistant|>\n"

    # NOTE(review): this call is synchronous, so the event loop is blocked
    # for the whole generation.
    story_content = llm(prompt)

    story_dict = {
        "story_content": story_content,
        # Timezone-aware UTC timestamp; datetime.utcnow() is deprecated and
        # returns a naive value.
        "generated_at": datetime.datetime.now(datetime.timezone.utc),
    }
    result = await db.stories.insert_one(story_dict)
    logging.info('Story generated with id %s', result.inserted_id)
    return {"_id": str(result.inserted_id), "story_content": story_content}
@app.post("/llm_on_cpu_stream")
async def stream_stories(item: Validation):
    """Generate a bedtime story and return it as a plain-text streaming response.

    Parameters:
        item (Validation): The validated request body. Its ``prompt`` is not
            used; a fixed prompt is sent to the model.

    Returns:
        StreamingResponse: A ``text/plain`` response whose body is the
        generated story.
    """
    # NOTE(review): unlike /llm_on_cpu, this route has neither the credential
    # dependency nor a rate limit - confirm that this is intentional.

    # The end-of-turn marker must be the literal "</s>". The previous "<\/s>"
    # was an invalid escape sequence that put a backslash into the prompt.
    prompt = "<|user|>\nA bedtime tiny story for the children</s>\n<|assistant|>\n"

    # NOTE(review): the whole story is generated before the response starts,
    # so nothing is delivered incrementally. The original author recorded
    # that iterating `llm(prompt, stream=True)` yields text piece by piece;
    # before passing that iterator to StreamingResponse, confirm the shared
    # model can safely be driven outside the event-loop thread.
    return StreamingResponse(io.StringIO(llm(prompt)), media_type="text/plain")
@app.post("/feedback")
async def create_feedback(feedback: Feedback):
    """Store a feedback entry for a story in the database.

    Parameters:
    - feedback (Feedback): The feedback to store. ``story_id`` must be the
      string form of a valid ObjectId.

    Returns:
    - dict: ``{"_id": <inserted feedback id as str>}``.

    Raises:
    - HTTPException: 400 when ``story_id`` is not a valid ObjectId.
    """
    logging.info('Receiving feedback')

    # Reject malformed ids up front; ObjectId(...) would otherwise raise and
    # the client would receive a 500 for what is a bad request.
    if not ObjectId.is_valid(feedback.story_id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="story_id is not a valid ObjectId",
        )

    feedback_dict = feedback.dict()
    # Timezone-aware UTC timestamp; datetime.utcnow() is deprecated and
    # returns a naive value.
    feedback_dict["feedback_at"] = datetime.datetime.now(datetime.timezone.utc)
    # Store the reference as an ObjectId rather than as a plain string.
    feedback_dict["story_id"] = ObjectId(feedback.story_id)

    result = await db.feedback.insert_one(feedback_dict)
    logging.info('Feedback received with id %s', result.inserted_id)
    return {"_id": str(result.inserted_id)}