File size: 2,855 Bytes
2fb1978
 
 
 
 
feb0732
 
 
 
b44c041
7f7ab2e
feb0732
 
 
 
 
25caad3
feb0732
 
 
 
 
25caad3
feb0732
 
25caad3
 
feb0732
25caad3
feb0732
25caad3
 
 
feb0732
 
 
25caad3
feb0732
 
 
25caad3
 
 
 
feb0732
 
 
 
 
 
 
 
25caad3
feb0732
25caad3
 
 
 
 
 
 
 
 
feb0732
 
25caad3
 
2fb1978
25caad3
 
feb0732
25caad3
feb0732
 
25caad3
 
feb0732
 
25caad3
feb0732
27ac591
25caad3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import uuid
import torch
import soundfile as sf
import uvicorn
from fastapi import FastAPI, Form, Request, HTTPException
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from model import Dia
from audio import *

app = FastAPI()

# Create the output directory *before* mounting: StaticFiles verifies at
# construction time that its directory exists and raises RuntimeError if not,
# so mounting first would crash on a fresh checkout without a "static" folder.
os.makedirs("static/audio", exist_ok=True)
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# Load the Dia model once at import time. Any failure aborts startup: there is
# no point serving requests without a model.
try:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    model = Dia.from_pretrained(
        "nari-labs/Dia-1.6B",
        compute_dtype="float32",  # Use float32 if float16 causes issues
        device=device,
    )

    if device == "cpu":
        model = model.eval()
        # Cap intra-op CPU parallelism; 4 is a hard-coded choice here.
        torch.set_num_threads(4)

    print("Model loaded successfully with optimizations")
except Exception as e:
    print(f"Error loading Dia model: {str(e)}")
    # HTTPException only makes sense inside a request handler; at import time
    # a plain RuntimeError (chained to the original error) is the honest signal.
    raise RuntimeError(f"Error loading Dia model: {str(e)}") from e

@app.get('/')
async def index(request: Request):
    """Render the landing page from ``templates/index.html``.

    Raises:
        HTTPException: 500 if the template cannot be loaded or rendered.
    """
    try:
        return templates.TemplateResponse("index.html", {'request': request})
    except Exception as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(status_code=500, detail=f"Error loading template: {str(e)}") from e

@app.post("/convertor")
async def process(request: Request, paragraph: str = Form(...), action: str = Form(...)):
    """Handle the form submission: synthesize ``paragraph`` as speech.

    Args:
        request: The incoming request (unused; part of the route signature).
        paragraph: Text to synthesize; must contain non-whitespace characters.
        action: ``"audio"`` generates a WAV file; ``"summarize"`` is
            recognised but not implemented; anything else is rejected.

    Returns:
        A ``FileResponse`` serving the generated WAV file, which is also kept
        under ``static/audio/``.

    Raises:
        HTTPException: 400 for empty text, an unknown action, or the
            unimplemented ``"summarize"`` action; 500 if generation fails or
            the model output is not a 1-D waveform.
    """
    try:
        if not paragraph.strip():
            raise HTTPException(status_code=400, detail="Text is required")

        if action == "audio":
            print(f"Generating audio for: {paragraph}")
            # NOTE(review): model.generate is a synchronous call inside an
            # async handler, so it blocks the event loop for the whole
            # generation; consider run_in_threadpool (with a lock if the model
            # must not run concurrently).
            output = model.generate(paragraph)
            print(f"Generated output type: {type(output)}")

            # Normalise a torch tensor to numpy; detach in case it tracks grad.
            if isinstance(output, torch.Tensor):
                output = output.detach().cpu().numpy()

            # A malformed model output is a server-side failure, not a client
            # error. getattr also covers None / non-array results.
            if getattr(output, "ndim", None) != 1:
                raise HTTPException(status_code=500, detail="Output is not a valid 1D audio array.")

            filename = f"audio_{uuid.uuid4()}.wav"
            filepath = f"static/audio/{filename}"
            os.makedirs(os.path.dirname(filepath), exist_ok=True)

            # Written at 44.1 kHz, as in the original code.
            sf.write(filepath, output, 44100)

            return FileResponse(filepath, media_type="audio/wav", filename=filename)

        if action == "summarize":
            raise HTTPException(status_code=400, detail="Summarization not implemented")

        raise HTTPException(status_code=400, detail="Invalid action")

    except HTTPException:
        # Deliberate HTTP errors raised above must reach the client unchanged;
        # without this clause the generic handler below rewrapped them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)