| import os |
| import time |
| from pydantic import BaseModel |
| from fastapi import FastAPI, HTTPException, Query, Request |
| from fastapi.responses import FileResponse |
| from fastapi.middleware.cors import CORSMiddleware |
|
|
| from langchain.chains import LLMChain |
| from langchain.prompts import PromptTemplate |
| from TextGen.suno import custom_generate_audio, get_audio_information |
| from langchain_google_genai import ( |
| ChatGoogleGenerativeAI, |
| HarmBlockThreshold, |
| HarmCategory, |
| ) |
| from TextGen import app |
| from gradio_client import Client |
|
|
# NOTE(review): not referenced by the visible code in this file; the name
# suggests the base URL of the song-generation deployment — confirm.
song_base_api = os.environ["VERCEL_API"]

# Hugging Face access token handed to the XTTS Space client below. Both
# lookups index os.environ directly, so a missing variable raises KeyError
# at import time.
my_hf_token = os.environ["HF_TOKEN"]

# Single gradio client for the XTTS text-to-speech Space, created once at
# import time and reused by the /generate_wav route.
tts_client = Client(
    "https://jofthomas-xtts.hf.space/",
    hf_token=my_hf_token,
)
|
|
|
|
|
|
class Generate(BaseModel):
    """Response model for the text-generation endpoint."""

    # The text produced by the language model.
    text: str
|
|
def generate_text(prompt: str):
    """Send ``prompt`` to Gemini (``gemini-pro``) and wrap the reply in ``Generate``.

    The user's text is supplied as the *value* of the ``Prompt`` template
    variable instead of being used as the template itself. Braces in user
    input are therefore passed through literally rather than being parsed as
    f-string placeholders.

    Args:
        prompt: Raw prompt text supplied by the caller.

    Returns:
        A ``Generate`` instance holding the model's reply.

    Raises:
        HTTPException: 400 when ``prompt`` is empty or whitespace-only.
    """
    if not prompt or not prompt.strip():
        # Raise instead of returning a dict: the route declares
        # response_model=Generate, so a {"detail": ...} dict would fail
        # response validation and surface as a 500.
        raise HTTPException(status_code=400, detail="Please provide a prompt.")

    # A fixed template; the user text only ever fills the placeholder.
    template = PromptTemplate(template="{Prompt}", input_variables=["Prompt"])

    llm = ChatGoogleGenerativeAI(
        model="gemini-pro",
        safety_settings={
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        },
    )

    llmchain = LLMChain(prompt=template, llm=llm)

    # Pass the raw user string (not the PromptTemplate object) as the variable value.
    llm_response = llmchain.run({"Prompt": prompt})
    return Generate(text=llm_response)
|
|
| |
|
|
# Fully permissive CORS: any origin, method and header is accepted, and
# credentialed cross-origin requests are enabled.
# NOTE(review): the CORS spec does not honour a literal "*" origin for
# credentialed requests; check how the installed Starlette version handles
# wildcard origins together with allow_credentials=True before relying on
# cookies or auth headers from browsers here.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
|
|
@app.get("/", tags=["Home"])
def api_home():
    """Root route: respond with a fixed welcome message."""
    welcome = {"detail": "Welcome to FastAPI TextGen Tutorial!"}
    return welcome
|
|
@app.post(
    "/api/generate",
    summary="Generate text from prompt",
    tags=["Generate"],
    response_model=Generate,
)
def inference(input_prompt: str):
    """Generate text for ``input_prompt`` (read by FastAPI as a query parameter).

    All work is delegated to ``generate_text``.
    """
    generated = generate_text(prompt=input_prompt)
    return generated
|
|
| from fastapi import FastAPI, HTTPException, Request |
| from fastapi.responses import FileResponse |
| import json |
|
|
# NOTE: deliberately no second ``app = FastAPI()`` here. Rebinding ``app``
# would attach the routes that follow to a brand-new application, separate
# from the ``app`` imported from ``TextGen`` at the top of this module. That
# imported app already carries the CORS middleware and the "/" and
# "/api/generate" routes. Only one of the two apps can be the one actually
# served, so a whole group of routes would silently disappear.
|
|
# Reference voices accepted by /generate_wav; anything else is rejected with 400.
# Kept as a tuple (not a set) so an unhashable JSON value such as a list still
# yields a clean "not in" instead of a TypeError.
_VALID_VOICES = (
    "./blacksmith.mp3",
    "./female.wav",
    "./female.mp3",
    "./narator_out.wav",
    "./blacksmith2.mp3",
)


@app.api_route("/generate_wav", methods=["GET", "POST"])
async def generate_wav(request: Request):
    """Synthesize speech with the XTTS Space and return it as a WAV file.

    The JSON request body may contain:

    * ``text`` (required): the text to speak.
    * ``language`` (default ``"en"``): forwarded unchanged to the Space.
    * ``voice_choice`` (default ``"./blacksmith.mp3"``): the reference voice,
      which must be one of ``_VALID_VOICES``.

    GET is kept for existing callers. POST is also accepted because browser
    ``fetch`` refuses to send a request body with GET.

    Raises:
        HTTPException: 400 for a malformed body, a missing/empty ``text`` or an
            unknown voice; 500 if the TTS call fails.
    """
    # Function-scope import keeps this block self-contained; fastapi.concurrency
    # re-exports Starlette's helper for running blocking code in a worker thread.
    from fastapi.concurrency import run_in_threadpool

    # Validate input *outside* the broad try below so these 400s are not
    # rewritten into 500s.
    try:
        body = await request.json()
    except ValueError as exc:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from exc
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    text = body.get("text")
    language = body.get("language", "en")
    voice_choice = body.get("voice_choice", "./blacksmith.mp3")

    if not isinstance(text, str) or not text.strip():
        raise HTTPException(status_code=400, detail="Missing 'text'")
    if voice_choice not in _VALID_VOICES:
        raise HTTPException(status_code=400, detail="Invalid voice choice")

    try:
        # Client.predict blocks until the remote Space answers; running it in a
        # worker thread keeps the event loop free for other requests.
        # NOTE(review): the positional inputs (voice passed twice, then four
        # booleans) mirror the Space's fn_index=1 signature — confirm there.
        result = await run_in_threadpool(
            tts_client.predict,
            text,
            language,
            voice_choice,
            voice_choice,
            False,
            False,
            False,
            True,
            fn_index=1,
        )
        # Assumes the second output of the Space is the generated WAV file path.
        wav_file_path = result[1]
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    return FileResponse(wav_file_path, media_type="audio/wav", filename="output.wav")
|
|
|
|
# Polling budget for /generate_song: up to 60 status checks, 5 seconds apart.
_SONG_POLL_ATTEMPTS = 60
_SONG_POLL_INTERVAL_S = 5


@app.get("/generate_song")
async def generate_song(text: str):
    """Start a song generation from ``text`` and wait until it is streaming.

    Submits ``text`` as the prompt to ``custom_generate_audio`` (non-instrumental,
    ``wait_audio`` off). It then polls ``get_audio_information`` for the first
    two returned clip ids until the first clip reports status ``"streaming"``.

    Returns:
        The clip information from the final ``get_audio_information`` call; its
        first two entries are read for ``id`` and ``audio_url``.

    Raises:
        HTTPException: 500 if the song service fails or returns an unexpected
            payload; 504 if the first clip is still not streaming after
            ``_SONG_POLL_ATTEMPTS`` polls.
    """
    import asyncio

    from fastapi.concurrency import run_in_threadpool

    try:
        # The suno helpers are called synchronously; presumably they perform
        # blocking HTTP, so run them off the event loop.
        started = await run_in_threadpool(
            custom_generate_audio,
            {"prompt": text, "make_instrumental": False, "wait_audio": False},
        )
        ids = f"{started[0]['id']},{started[1]['id']}"
        print(f"ids: {ids}")

        for _ in range(_SONG_POLL_ATTEMPTS):
            clips = await run_in_threadpool(get_audio_information, ids)
            if clips[0]["status"] == "streaming":
                print(f"{clips[0]['id']} ==> {clips[0]['audio_url']}")
                print(f"{clips[1]['id']} ==> {clips[1]['audio_url']}")
                # The original only broke out of the loop and returned nothing.
                return clips
            # asyncio.sleep yields to other requests; time.sleep here would
            # freeze the whole event loop for the entire polling window.
            await asyncio.sleep(_SONG_POLL_INTERVAL_S)
    except Exception as exc:
        # Previously a bare except printed "Error" and answered 200 with null.
        print(f"Error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    raise HTTPException(status_code=504, detail="Song generation timed out")