# Testing / main.py
# Last commit: "Update main.py" (9323c84, verified) by shahid202
import asyncio
from concurrent.futures import ThreadPoolExecutor

import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from kokoro import KPipeline
from transformers import pipeline
# ASGI application object; the contents of ./static are served under /static.
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
# Initialize pipelines once at import time; the WebSocket handler below reuses
# these module-level objects for every message.
# Text-generation pipeline used to produce the chat reply.
llm = pipeline("text-generation", model="HuggingFaceTB/SmolLM2-360M-Instruct")
# Text-to-speech pipeline used to synthesize the reply as audio.
# NOTE(review): kokoro's KPipeline appears to take the Hub repository via
# `repo_id` and to expect `model` to be a KModel instance or a bool, so this
# string is presumably treated as "load the default model" -- confirm whether
# `repo_id='shahid202/Kokoro-82M-TTS'` was intended.
tts = KPipeline(lang_code='a', model='shahid202/Kokoro-82M-TTS')
@app.get("/")
async def get_index():
    """Serve the front-end page stored at ``static/index.html``."""
    index_page = 'static/index.html'
    return FileResponse(index_page)
# Single worker: model inference stays serialized (as it was when it ran
# directly on the event loop) while the loop itself remains free to serve
# other requests and WebSocket traffic.
_inference_executor = ThreadPoolExecutor(max_workers=1)


def _extract_reply(generated_text: str) -> str:
    """Return the stripped text that follows the last ``Bella:`` marker."""
    return generated_text.split("Bella:")[-1].strip()


def _audio_to_bytes(audio) -> bytes:
    """Serialize one TTS audio chunk as raw float32 sample bytes.

    Accepts a NumPy array or a tensor-like object exposing ``detach()`` /
    ``cpu()`` / ``numpy()``. Tensors have no ``astype`` method, so they are
    converted to NumPy first instead of calling ``astype`` directly.
    """
    if hasattr(audio, "detach"):
        audio = audio.detach().cpu().numpy()
    return np.asarray(audio, dtype=np.float32).tobytes()


@app.websocket("/ws/chat")
async def websocket_endpoint(websocket: WebSocket):
    """Chat over a WebSocket: text in, synthesized speech out.

    For every text message received, a short reply is generated with ``llm``,
    synthesized with ``tts``, and each audio chunk is sent back as a binary
    frame of float32 samples. The loop ends when the client disconnects or
    when any other error occurs (the error is printed).

    Args:
        websocket: The connection handed over by FastAPI.
    """
    await websocket.accept()
    loop = asyncio.get_running_loop()
    end_of_stream = object()  # sentinel so next() never raises StopIteration
    while True:
        try:
            user_msg = await websocket.receive_text()

            # Generate LLM response off the event loop; the call is blocking.
            prompt = f"User: {user_msg}\nBella:"
            output = await loop.run_in_executor(
                _inference_executor,
                lambda: llm(
                    prompt,
                    max_new_tokens=30,
                    clean_up_tokenization_spaces=False,
                ),
            )
            text = _extract_reply(output[0]['generated_text'])
            if not text:
                # Nothing to speak; wait for the next message.
                continue

            # Stream audio chunks. Each step of the TTS generator does
            # blocking synthesis, so advance it in the executor as well.
            chunks = iter(tts(text, voice="af_heart", speed=1.0))
            while True:
                item = await loop.run_in_executor(
                    _inference_executor, next, chunks, end_of_stream
                )
                if item is end_of_stream:
                    break
                _, _, audio = item
                if audio is None:
                    continue
                await websocket.send_bytes(_audio_to_bytes(audio))
                await asyncio.sleep(0.01)  # Small buffer delay for stability
        except WebSocketDisconnect:
            # Normal client disconnect, not an error.
            break
        except Exception as e:
            print(f"Error in WebSocket: {e}")
            break