text_ / app.py
UDface11jkj's picture
Update app.py
7f7ab2e verified
import os
import uuid
import torch
import soundfile as sf
import uvicorn
from fastapi import FastAPI, Form, Request, HTTPException
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from model import Dia
from audio import *
app = FastAPI()
# StaticFiles verifies that its directory exists when it is constructed and
# raises RuntimeError otherwise. Create it (including the subfolder the
# /convertor route writes generated audio into) before mounting, so a fresh
# checkout without a static/ directory can still start.
os.makedirs("static/audio", exist_ok=True)
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
# Load the Dia text-to-speech model once at import time; every request
# handler below reuses this single module-level `model`.
try:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    # Output directory for the .wav files written by the /convertor route.
    os.makedirs("static/audio", exist_ok=True)
    model = Dia.from_pretrained(
        "nari-labs/Dia-1.6B",
        compute_dtype="float32",  # Use float32 if float16 causes issues
        device=device,
    )
    if device == "cpu":
        model = model.eval()
        # Cap intra-op CPU threads used by torch.
        torch.set_num_threads(4)
    print("Model loaded successfully with optimizations")
except Exception as e:
    print(f"Error loading Dia model: {str(e)}")
    # HTTPException is only meaningful inside a request handler; at import
    # time there is no request to answer. Fail startup with a plain error
    # and keep the original exception attached as the cause.
    raise RuntimeError(f"Error loading Dia model: {str(e)}") from e
@app.get('/')
async def index(request: Request):
    """Render the landing page from ``templates/index.html``.

    Any failure while building the template response is reported to the
    client as an HTTP 500 whose detail carries the underlying error text.
    """
    context = {'request': request}
    try:
        page = templates.TemplateResponse("index.html", context)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error loading template: {str(exc)}")
    return page
@app.post("/convertor")
async def process(request: Request, paragraph: str = Form(...), action: str = Form(...)):
    """Handle the text form submitted from the index page.

    Args:
        request: The incoming request (not used; part of the route signature).
        paragraph: Text submitted by the user.
        action: ``"audio"`` synthesizes speech with the Dia model and returns
            it as a WAV file; ``"summarize"`` is not implemented; anything
            else is rejected.

    Returns:
        A ``FileResponse`` serving the generated ``audio/wav`` file, which is
        also left on disk under ``static/audio/``.

    Raises:
        HTTPException: 400 for empty/whitespace-only text, an unknown action,
            or the unimplemented ``"summarize"`` action; 500 if the model
            output is not a 1-D waveform or any other step fails.
    """
    try:
        if not paragraph or not paragraph.strip():
            raise HTTPException(status_code=400, detail="Text is required")
        if action == "audio":
            print(f"Generating audio for: {paragraph}")
            # NOTE(review): model.generate is synchronous, so it blocks the
            # event loop for the whole generation. Consider
            # fastapi.concurrency.run_in_threadpool (plus a lock if concurrent
            # use of the shared model is unsafe) if responsiveness matters.
            output = model.generate(paragraph)
            print(f"Generated output type: {type(output)}")
            # Ensure output is a valid waveform
            if isinstance(output, torch.Tensor):
                # detach() so tensors still tracking gradients can be converted.
                output = output.detach().cpu().numpy()
            # getattr guards against None / non-array outputs, which would
            # otherwise fail with an AttributeError.
            if getattr(output, "ndim", None) != 1:
                # A malformed model output is a server-side fault, not a client error.
                raise HTTPException(status_code=500, detail="Output is not a valid 1D audio array.")
            filename = f"audio_{uuid.uuid4()}.wav"
            filepath = f"static/audio/{filename}"
            os.makedirs(os.path.dirname(filepath), exist_ok=True)
            # 44.1 kHz sample rate; presumably matches Dia's output rate — confirm.
            sf.write(filepath, output, 44100)
            return FileResponse(filepath, media_type="audio/wav", filename=filename)
        if action == "summarize":
            raise HTTPException(status_code=400, detail="Summarization not implemented")
        raise HTTPException(status_code=400, detail="Invalid action")
    except HTTPException:
        # Deliberate HTTP errors raised above must reach the client with
        # their own status code instead of being re-wrapped as a 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e
if __name__ == "__main__":
    # Running the file directly (e.g. `python app.py`) starts uvicorn,
    # listening on every interface at port 7860.
    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port)