shahid202 committed on
Commit
5ca3e2f
·
verified ·
1 Parent(s): 0659f06

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +26 -9
main.py CHANGED
@@ -1,22 +1,39 @@
1
  from fastapi import FastAPI, WebSocket
 
 
2
  from transformers import pipeline
3
  from kokoro import KPipeline
4
  import asyncio
 
5
 
6
  app = FastAPI()
 
 
 
 
 
7
  llm = pipeline("text-generation", model="HuggingFaceTB/SmolLM2-360M-Instruct")
8
  tts = KPipeline(lang_code='a', model='shahid202/Kokoro-82M-TTS')
9
 
 
 
 
 
10
  @app.websocket("/ws/chat")
11
  async def websocket_endpoint(websocket: WebSocket):
12
  await websocket.accept()
13
  while True:
14
- user_msg = await websocket.receive_text()
15
- # Get LLM text
16
- response_text = llm(f"User: {user_msg}\nBella:", max_new_tokens=30)[0]['generated_text']
17
- text = response_text.split("Bella:")[-1].strip()
18
-
19
- # Stream audio chunk by chunk
20
- for _, _, audio in tts(text, voice="af_heart", speed=1.0):
21
- await websocket.send_bytes(audio.tobytes())
22
- await asyncio.sleep(0.05) # Keeps flow smooth
 
 
 
 
 
 
1
  from fastapi import FastAPI, WebSocket
2
+ from fastapi.responses import FileResponse
3
+ from fastapi.staticfiles import StaticFiles
4
  from transformers import pipeline
5
  from kokoro import KPipeline
6
  import asyncio
7
+ import numpy as np
8
 
9
  app = FastAPI()
10
+
11
+ # Mount the static folder
12
+ app.mount("/static", StaticFiles(directory="static"), name="static")
13
+
14
+ # Initialize models
15
  llm = pipeline("text-generation", model="HuggingFaceTB/SmolLM2-360M-Instruct")
16
  tts = KPipeline(lang_code='a', model='shahid202/Kokoro-82M-TTS')
17
 
18
+ @app.get("/")
19
+ async def get_index():
20
+ return FileResponse('static/index.html')
21
+
22
  @app.websocket("/ws/chat")
23
  async def websocket_endpoint(websocket: WebSocket):
24
  await websocket.accept()
25
  while True:
26
+ try:
27
+ user_msg = await websocket.receive_text()
28
+ # Generate response
29
+ response = llm(f"User: {user_msg}\nBella:", max_new_tokens=30)[0]['generated_text']
30
+ text = response.split("Bella:")[-1].strip()
31
+
32
+ # Stream audio
33
+ for _, _, audio in tts(text, voice="af_heart", speed=1.0):
34
+ # Ensure float32 format
35
+ await websocket.send_bytes(audio.astype(np.float32).tobytes())
36
+ await asyncio.sleep(0.01)
37
+ except Exception as e:
38
+ print(f"Error: {e}")
39
+ break