"""FastAPI service that converts submitted text to speech with the Dia model.

Routes:
    GET  /           -- renders ``templates/index.html``.
    POST /convertor  -- form fields ``paragraph`` and ``action``; for
                        ``action == "audio"`` returns a generated WAV file.

The model is loaded once at import time; if loading fails the module raises
and the application does not start.
"""
import os
import uuid

import torch
import soundfile as sf
import uvicorn
from fastapi import FastAPI, Form, Request, HTTPException
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

# NOTE: keep this order. The wildcard import comes last, exactly as in the
# original file, because it may rebind any of the names imported above.
from model import Dia
from audio import *  # noqa: F401,F403

# Directory (served under /static) where generated WAV files are written.
AUDIO_DIR = "static/audio"
# Sample rate passed to soundfile when writing the WAV.
# NOTE(review): presumably this matches the model's output rate -- confirm.
SAMPLE_RATE = 44100

app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# Load the model once at start-up; fail fast if that is not possible.
try:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    os.makedirs(AUDIO_DIR, exist_ok=True)

    model = Dia.from_pretrained(
        "nari-labs/Dia-1.6B",
        compute_dtype="float32",  # Use float32 if float16 causes issues
        device=device,
    )

    if device == "cpu":
        model = model.eval()
        torch.set_num_threads(4)

    print("Model loaded successfully with optimizations")
except Exception as e:
    print(f"Error loading Dia model: {str(e)}")
    # HTTPException only has meaning inside a request handler; at import time
    # there is no request to answer, so surface a regular error instead and
    # keep the original cause attached.
    raise RuntimeError(f"Error loading Dia model: {str(e)}") from e


@app.get('/')
async def index(request: Request):
    """Render the landing page.

    Raises:
        HTTPException: 500 if the template cannot be rendered.
    """
    try:
        return templates.TemplateResponse("index.html", {'request': request})
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error loading template: {str(e)}"
        ) from e


@app.post("/convertor")
async def process(request: Request, paragraph: str = Form(...), action: str = Form(...)):
    """Handle a conversion request submitted from the form.

    Args:
        request: The incoming request (unused, kept for interface stability).
        paragraph: Text to convert; must contain non-whitespace characters.
        action: ``"audio"`` to synthesize speech. ``"summarize"`` is
            recognized but not implemented.

    Returns:
        A ``FileResponse`` streaming the generated WAV for ``action == "audio"``.

    Raises:
        HTTPException: 400 for missing text, an unimplemented action or an
            unknown action; 500 for an unusable model output or any
            unexpected failure.
    """
    try:
        if not paragraph or not paragraph.strip():
            raise HTTPException(status_code=400, detail="Text is required")

        if action == "audio":
            print(f"Generating audio for: {paragraph}")
            # NOTE: this is a plain synchronous call inside an async handler,
            # so the event loop is occupied until generation returns.
            output = model.generate(paragraph)
            print(f"Generated output type: {type(output)}")

            # soundfile needs a host-side array, so move tensors off-device.
            if isinstance(output, torch.Tensor):
                output = output.cpu().numpy()

            # A non-1D (or non-array) result is a server-side fault, not a
            # client error, hence 500 rather than 400.
            if getattr(output, "ndim", None) != 1:
                raise HTTPException(
                    status_code=500,
                    detail="Output is not a valid 1D audio array.",
                )

            filename = f"audio_{uuid.uuid4()}.wav"
            filepath = f"{AUDIO_DIR}/{filename}"
            # Re-create the directory in case it was removed after start-up.
            os.makedirs(os.path.dirname(filepath), exist_ok=True)

            # Save audio file
            sf.write(filepath, output, SAMPLE_RATE)

            return FileResponse(filepath, media_type="audio/wav", filename=filename)

        if action == "summarize":
            raise HTTPException(status_code=400, detail="Summarization not implemented")

        raise HTTPException(status_code=400, detail="Invalid action")
    except HTTPException:
        # Deliberate HTTP errors raised above must keep their status code;
        # without this clause the generic handler below rewrote them as 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error processing request: {str(e)}"
        ) from e


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)