Spaces:
Build error
Build error
| import os | |
| import uuid | |
| import torch | |
| import soundfile as sf | |
| import uvicorn | |
| from fastapi import FastAPI, Form, Request, HTTPException | |
| from fastapi.responses import HTMLResponse, FileResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| from model import Dia | |
| from audio import * | |
app = FastAPI()

# StaticFiles verifies at construction time that its directory exists and
# raises otherwise, so create the tree (including the sub-directory where
# generated audio files are written) *before* mounting it.
os.makedirs("static/audio", exist_ok=True)
app.mount("/static", StaticFiles(directory="static"), name="static")

# HTML templates (e.g. index.html) are loaded from ./templates.
templates = Jinja2Templates(directory="templates")
# Load the Dia model once at import time so every request reuses the same
# instance; startup fails loudly if loading does not succeed.
try:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    os.makedirs("static/audio", exist_ok=True)
    model = Dia.from_pretrained(
        "nari-labs/Dia-1.6B",
        compute_dtype="float32",  # Use float32 if float16 causes issues
        device=device,
    )
    if device == "cpu":
        # NOTE(review): eval mode is only set on CPU here; confirm whether
        # Dia.from_pretrained already does this (and whether Dia exposes eval()).
        model = model.eval()
        # Cap intra-op CPU parallelism to a fixed thread count.
        torch.set_num_threads(4)
    print("Model loaded successfully with optimizations")
except Exception as e:
    print(f"Error loading Dia model: {str(e)}")
    # HTTPException is only meaningful inside a request handler; at import
    # time it is just an opaque crash. Abort startup with a plain RuntimeError
    # and keep the original error chained for the traceback.
    raise RuntimeError(f"Error loading Dia model: {str(e)}") from e
# NOTE(review): no route decorator was present in the source, so this handler
# was never registered; "/" is assumed to be the landing page path.
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the landing page from ``templates/index.html``.

    Args:
        request: The incoming request, passed to the template context.

    Returns:
        The rendered template response.

    Raises:
        HTTPException: 500 if the template cannot be loaded or rendered.
    """
    try:
        return templates.TemplateResponse("index.html", {"request": request})
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error loading template: {str(e)}"
        ) from e
# NOTE(review): no route decorator was present in the source; the path is
# assumed — confirm it matches the form action in templates/index.html.
@app.post("/process")
async def process(request: Request, paragraph: str = Form(...), action: str = Form(...)):
    """Handle a text form submission.

    Args:
        request: The incoming request (not read by this handler).
        paragraph: Text submitted by the user.
        action: ``"audio"`` to generate a WAV file from ``paragraph``;
            ``"summarize"`` is recognised but not implemented.

    Returns:
        FileResponse: the generated WAV file when ``action == "audio"``.

    Raises:
        HTTPException: 400 for blank text, the unimplemented ``"summarize"``
            action, or an unknown action; 500 if the model output is not a
            1-D waveform or any other step fails.
    """
    try:
        # Whitespace-only input is treated the same as empty input.
        if not paragraph or not paragraph.strip():
            raise HTTPException(status_code=400, detail="Text is required")
        if action == "audio":
            print(f"Generating audio for: {paragraph}")
            # NOTE(review): model.generate is synchronous and runs on the event
            # loop thread, so no other request is served while it runs.
            output = model.generate(paragraph)
            print(f"Generated output type: {type(output)}")
            # Tensors may live on the GPU or track gradients; detach and move
            # them to host memory before converting to numpy.
            if isinstance(output, torch.Tensor):
                output = output.detach().cpu().numpy()
            # getattr keeps outputs without .ndim (None, lists, ...) from
            # surfacing as an AttributeError; a malformed model output is a
            # server-side failure, hence 500 rather than 400.
            if getattr(output, "ndim", None) != 1:
                raise HTTPException(status_code=500, detail="Output is not a valid 1D audio array.")
            filename = f"audio_{uuid.uuid4()}.wav"
            filepath = os.path.join("static", "audio", filename)
            os.makedirs(os.path.dirname(filepath), exist_ok=True)
            sf.write(filepath, output, 44100)  # written with a 44100 Hz sample rate
            return FileResponse(filepath, media_type="audio/wav", filename=filename)
        if action == "summarize":
            raise HTTPException(status_code=400, detail="Summarization not implemented")
        raise HTTPException(status_code=400, detail="Invalid action")
    except HTTPException:
        # Errors raised deliberately above keep their own status code instead
        # of being re-wrapped as a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e
if __name__ == "__main__":
    # When executed directly, serve the app on all interfaces at port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")