CDOM201 commited on
Commit
5c31db9
·
verified ·
1 Parent(s): eb90de9

Upload 6 files

Browse files
Files changed (6) hide show
  1. .dockerignore +9 -0
  2. Dockerfile +42 -0
  3. app.py +23 -0
  4. download_model.py +8 -0
  5. main.py +74 -0
  6. requirements.txt +11 -0
.dockerignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ venv/
2
+ __pycache__/
3
+ *.wav
4
+ .git/
5
+ .gitignore
6
+ .env
7
+ testing.js
8
+ node_modules/
9
+
Dockerfile ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use an official PyTorch image with CUDA support
2
+ FROM pytorch/pytorch:2.5.1-cuda11.8-cudnn9-runtime
3
+
4
+ # Set environment variables
5
+ ENV PYTHONUNBUFFERED=1 \
6
+ PYTHONDONTWRITEBYTECODE=1 \
7
+ PORT=7860
8
+
9
+ # Set the working directory in the container
10
+ WORKDIR /app
11
+
12
+ # Install system dependencies
13
+ RUN apt-get update && apt-get install -y \
14
+ ffmpeg \
15
+ libsndfile1 \
16
+ git \
17
+ && rm -rf /var/lib/apt/lists/*
18
+
19
+ # Install Python dependencies
20
+ COPY requirements.txt .
21
+ RUN pip install --no-cache-dir --upgrade pip && \
22
+ pip install --no-cache-dir -r requirements.txt
23
+
24
+ # Create a non-root user and switch to it
25
+ # Hugging Face Spaces runs as user 1000
26
+ RUN useradd -m -u 1000 user
27
+ USER user
28
+ ENV HOME=/home/user \
29
+ PATH=/home/user/.local/bin:$PATH
30
+
31
+ # Copy the rest of the application code
32
+ COPY --chown=user . .
33
+
34
+ # Pre-download the model weights during build time
35
+ RUN python download_model.py
36
+
37
+ # Expose the port (Hugging Face Spaces expects 7860)
38
+ EXPOSE 7860
39
+
40
+ # Command to run the application
41
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
42
+
app.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchaudio as ta
2
+ import torch
3
+ from chatterbox.mtl_tts import ChatterboxMultilingualTTS
4
+ import functools
5
+
6
+ # torch.load = functools.partial(torch.load, map_location='cpu')
7
+
8
+ # device_map = torch.device('cpu')
9
+ device_map = None
10
+ if torch.cuda.is_available():
11
+ device_map = torch.device('cuda')
12
+ else:
13
+ device_map = torch.device('cpu')
14
+
15
+ print(f"Using device: {device_map}")
16
+
17
+ tts_model = ChatterboxMultilingualTTS.from_pretrained(device=device_map)
18
+ streamer_lang = "es"
19
+
20
+ msg = "CDOM201 dice: Como estas pandita, igual de puto como siempre?"
21
+ audio_file = tts_model.generate(msg, language_id=streamer_lang)
22
+
23
+ ta.save("sleeplespanda.wav", audio_file, tts_model.sr);
download_model.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from chatterbox.mtl_tts import ChatterboxMultilingualTTS
3
+
4
+ print("Downloading model...")
5
+ # We use cpu here just to download the weights to the cache during build time
6
+ ChatterboxMultilingualTTS.from_pretrained(device="cpu")
7
+ print("Model downloaded successfully.")
8
+
main.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchaudio as ta
4
+ from fastapi import FastAPI, HTTPException, BackgroundTasks
5
+ from fastapi.responses import FileResponse
6
+ from pydantic import BaseModel
7
+ from chatterbox.mtl_tts import ChatterboxMultilingualTTS
8
+ import functools
9
+ import uvicorn
10
+
11
+ # Patch torch.load for CPU if necessary (as in app.py)
12
+ # torch.load = functools.partial(torch.load, map_location='cpu')
13
+
14
+ app = FastAPI()
15
+
16
+ # 1. Determine device dynamically
17
+ device_map = "cuda" if torch.cuda.is_available() else "cpu"
18
+
19
+ print(f"CUDA Available: {torch.cuda.is_available()}")
20
+ print(f"Using device: {device_map} with name: {torch.cuda.get_device_name(torch.cuda.current_device())}")
21
+
22
+ print("Loading TTS model...")
23
+ tts_model = ChatterboxMultilingualTTS.from_pretrained(device=device_map)
24
+ print("Model loaded.")
25
+
26
+ class TTSRequest(BaseModel):
27
+ message: str
28
+ language: str
29
+ channelID: str
30
+ username: str
31
+ messageid: str
32
+
33
+ def cleanup_file(filepath: str):
34
+ """Deletes the file after it has been sent."""
35
+ try:
36
+ if os.path.exists(filepath):
37
+ os.remove(filepath)
38
+ print(f"Deleted temporary file: {filepath}")
39
+ except Exception as e:
40
+ print(f"Error deleting file {filepath}: {e}")
41
+
42
+ def generate_audio(req: TTSRequest) -> str:
43
+ """Generates audio and returns the filename."""
44
+ filename = f"{req.channelID}-{req.username}-{req.messageid}.wav"
45
+ try:
46
+ audio_tensor = tts_model.generate(req.message, language_id=req.language)
47
+ ta.save(filename, audio_tensor, tts_model.sr)
48
+ return filename
49
+ except Exception as e:
50
+ raise HTTPException(status_code=500, detail=f"TTS Generation failed: {str(e)}")
51
+
52
+ @app.post("/tts")
53
+ async def tts_endpoint(req: TTSRequest, background_tasks: BackgroundTasks):
54
+ filename = generate_audio(req)
55
+ background_tasks.add_task(cleanup_file, filename)
56
+ return FileResponse(path=filename, filename=filename, media_type='audio/wav')
57
+
58
+ @app.post("/stream")
59
+ async def stream_endpoint(req: TTSRequest, background_tasks: BackgroundTasks):
60
+ filename = generate_audio(req)
61
+ background_tasks.add_task(cleanup_file, filename)
62
+ # FileResponse handles streaming efficiently for large files
63
+ return FileResponse(path=filename, media_type='audio/wav')
64
+
65
+ @app.post("/test")
66
+ async def test_endpoint(req: TTSRequest):
67
+ filename = generate_audio(req)
68
+ # For /test, we don't delete the file and just return "ok"
69
+ return {"status": "ok", "filename": filename}
70
+
71
+ if __name__ == "__main__":
72
+ port = int(os.environ.get("PORT", 7860))
73
+ uvicorn.run(app, host="0.0.0.0", port=port)
74
+
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.127.0
2
+ uvicorn==0.40.0
3
+ pydantic==2.11.10
4
+ chatterbox-tts==0.1.6
5
+ python-multipart==0.0.21
6
+ numpy==1.25.2
7
+ scipy==1.16.3
8
+ librosa==0.11.0
9
+ soundfile==0.13.1
10
+ aiofiles==24.1.0
11
+