import asyncio
import os
import time
from typing import List

from fastapi import FastAPI, HTTPException, Query, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from gradio_client import Client, handle_file
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_google_genai import (
    ChatGoogleGenerativeAI,
    HarmBlockThreshold,
    HarmCategory,
)

from TextGen import app
from TextGen.suno import custom_generate_audio, get_audio_information
|
|
class PlayLastMusic(BaseModel):
    # Tool schema bound to the Bard's LLM; the docstring becomes the tool
    # description the model sees when deciding whether to call it.
    '''plays the lastest created music '''
    # NOTE(review): "Desicion" is misspelled, but it is part of the tool-call
    # interface the LLM emits — renaming it would break existing callers.
    Desicion: str = Field(
        ..., description="Yes or No"
    )
|
|
class CreateLyrics(BaseModel):
    # Tool schema bound to the Bard's LLM (see generate_text).
    # BUGFIX: the original wrote this as f'''...'''. An f-string is not a
    # constant string literal, so it is NOT stored as __doc__ — the tool
    # bound from this model therefore carried no description for the LLM.
    '''create some Lyrics for a new music'''
    # NOTE(review): "Desicion" is misspelled but is part of the tool-call
    # interface the LLM emits — kept for compatibility.
    Desicion: str = Field(
        ..., description="Yes or No"
    )
|
|
class CreateNewMusic(BaseModel):
    # Tool schema bound to the Bard's LLM (see generate_text).
    # BUGFIX: the original wrote this as f'''...'''. An f-string is not a
    # constant string literal, so it is NOT stored as __doc__ — the tool
    # bound from this model therefore carried no description for the LLM.
    '''create a new music with the Lyrics previously computed'''
    # NOTE(review): despite the name "Name", the description says this field
    # holds tags describing the music — kept as-is for interface compatibility.
    Name: str = Field(
        ..., description="tags to describe the new music"
    )
|
|
|
|
|
|
class Message(BaseModel):
    """Request body for /api/generate: a chat history addressed to one NPC."""
    # NPC name, e.g. "Blacksmith"; selects the system prompt when recognized.
    npc: str | None = None
    # Alternating conversation turns; even indices are treated as user
    # messages and odd indices as assistant replies (see generate_text).
    messages: List[str] | None = None
| |
class VoiceMessage(BaseModel):
    """Request body for /generate_wav: one line of dialogue to synthesize."""
    # NPC name; used to pick the reference voice sample to clone.
    npc: str | None = None
    # The text to speak.
    input: str | None = None
    # ISO language code passed to the TTS space.
    language: str | None = "en"
    # Fallback voice gender ("Male"/"Female") for NPCs without a dedicated voice.
    genre:str | None = "Male"
| |
# Base URL of the Suno song-generation proxy (consumed by TextGen.suno helpers).
# NOTE: os.environ[...] raises KeyError at import time if the variable is unset.
song_base_api=os.environ["VERCEL_API"]


# Hugging Face token used to access the private/rate-limited XTTS space.
my_hf_token=os.environ["HF_TOKEN"]


# Remote XTTS text-to-speech Gradio space; used by /generate_wav.
tts_client = Client("Jofthomas/xtts",hf_token=my_hf_token)


# Reference voice samples for the named story NPCs.
main_npcs={
    "Blacksmith":"./voices/Blacksmith.mp3",
    "Herbalist":"./voices/female.mp3",
    "Bard":"./voices/Bard_voice.mp3"
}
# Persona system prompts keyed by the same NPC names as main_npcs.
main_npc_system_prompts={
    "Blacksmith":"You are a blacksmith in a video game",
    "Herbalist":"You are an herbalist in a video game",
    "Bard":"You are a bard in a video game"
}
class Generate(BaseModel):
    """Response model for /api/generate: the LLM's text reply."""
    text:str
|
|
def generate_text(messages: List[str], npc: str) -> Generate:
    """Run the chat history through Gemini and return its reply.

    Args:
        messages: Alternating conversation turns, starting with the user
            (even indices = user, odd indices = assistant). May be None.
        npc: NPC being addressed; selects the persona system prompt and,
            for the Bard, binds the music tools.

    Returns:
        Generate wrapping the model's text response.
    """
    print(npc)
    # Known NPCs get their dedicated persona; anyone else a generic one.
    if npc in main_npcs:
        system_prompt = main_npc_system_prompts[npc]
    else:
        system_prompt = "you're a character in a video game. Play along."
    print(system_prompt)

    # Rebuild the role-tagged history expected by the chat model.
    new_messages = [{"role": "user", "content": system_prompt}]
    for index, message in enumerate(messages or []):  # tolerate messages=None
        role = "user" if index % 2 == 0 else "assistant"
        new_messages.append({"role": role, "content": message})
    print(new_messages)

    llm = ChatGoogleGenerativeAI(
        model="gemini-1.5-pro-latest",
        max_output_tokens=100,
        temperature=1,
        # In-game dialogue: disable all content-blocking categories.
        safety_settings={
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
        },
    )

    # BUGFIX: the original compared npc == "bard" (lowercase) while the NPC
    # registry uses "Bard", so the music tools were never actually bound.
    # Compare case-insensitively to cover both spellings.
    if npc is not None and npc.lower() == "bard":
        llm = llm.bind_tools([PlayLastMusic, CreateNewMusic, CreateLyrics])

    llm_response = llm.invoke(new_messages)
    print(llm_response)
    return Generate(text=llm_response.content)
|
|
# Allow any origin to call this API (the game client is served elsewhere).
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# very permissive — confirm this is intended for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
@app.get("/", tags=["Home"])
def api_home():
    """Root endpoint: a simple liveness message, nothing is served here."""
    payload = {'detail': 'Everchanging Quest backend, nothing to see here'}
    return payload
|
|
@app.post("/api/generate", summary="Generate text from prompt", tags=["Generate"], response_model=Generate)
def inference(message: Message):
    """Forward the posted chat history to the LLM and return its reply."""
    npc_name = message.npc
    history = message.messages
    return generate_text(messages=history, npc=npc_name)
|
|
| |
def determine_vocie_from_npc(npc, genre, npc_voices=None):
    """Resolve the reference-voice sample file for an NPC.

    Args:
        npc: NPC name; known NPCs have a dedicated voice sample.
        genre: Fallback voice gender ("Male"/"Female") for unknown NPCs.
        npc_voices: Optional mapping of NPC name -> voice path; defaults to
            the module-level main_npcs registry (kept for backward
            compatibility, and to make the lookup table injectable).

    Returns:
        Path to the voice sample the TTS should clone.
    """
    if npc_voices is None:
        npc_voices = main_npcs
    if npc in npc_voices:
        return npc_voices[npc]
    # BUGFIX: the original was missing `return` on the Male branch (the path
    # string was a bare no-op expression), so unknown male NPCs silently
    # fell through to the narrator voice.
    if genre == "Male":
        return "./voices/default_male.mp3"
    if genre == "Female":
        return "./voices/default_female.mp3"
    return "./voices/narator_out.wav"
| |
@app.post("/generate_wav")
async def generate_wav(message:VoiceMessage):
    """Clone an NPC voice and speak ``message.input`` via the remote XTTS space.

    The reference voice is looked up from the NPC name (falling back to a
    gendered default), the Gradio space performs the synthesis, and the
    local file path it returns is streamed back as audio/wav.

    Raises:
        HTTPException: 500 wrapping any failure from the lookup or TTS call.
    """
    try:
        # Pick the reference sample whose timbre the TTS should clone.
        voice=determine_vocie_from_npc(message.npc, message.genre)

        # Blocking remote call to the XTTS space; gradio_client downloads
        # the produced wav locally and returns its path.
        result = tts_client.predict(
            prompt=message.input,
            language=message.language,
            audio_file_pth=handle_file(voice),
            mic_file_path=None,  # no microphone input in this flow
            use_mic=False,
            voice_cleanup=False,
            no_lang_auto_detect=False,
            agree=True,  # accept the space's usage terms
            api_name="/predict"
        )

        # Path of the wav file gradio_client saved locally.
        wav_file_path = result

        return FileResponse(wav_file_path, media_type="audio/wav", filename="output.wav")

    except Exception as e:
        # Surface any failure (voice lookup, remote TTS, file I/O) as a 500.
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@app.get("/generate_song")
async def generate_song(text: str):
    """Start Suno song generation for *text* and wait for streaming URLs.

    Submits the prompt without waiting server-side, then polls (up to
    60 x 5 s) until the first candidate track reaches 'streaming' status.

    Returns:
        The track information (including audio URLs) once streaming starts.

    Raises:
        HTTPException: 504 on polling timeout, 500 on any other failure
            (consistent with /generate_wav's error handling).
    """
    try:
        data = custom_generate_audio({
            "prompt": f"{text}",
            "make_instrumental": False,
            "wait_audio": False
        })
        # Suno returns two candidate tracks per request; poll both ids.
        ids = f"{data[0]['id']},{data[1]['id']}"
        print(f"ids: {ids}")

        for _ in range(60):
            data = get_audio_information(ids)
            if data[0]["status"] == 'streaming':
                print(f"{data[0]['id']} ==> {data[0]['audio_url']}")
                print(f"{data[1]['id']} ==> {data[1]['audio_url']}")
                # BUGFIX: the original only printed the URLs and returned
                # nothing, so the endpoint always answered null.
                return data
            # BUGFIX: time.sleep() inside an async endpoint blocks the whole
            # event loop; await asyncio.sleep() yields instead.
            await asyncio.sleep(5)
        raise HTTPException(status_code=504, detail="Song generation timed out")
    except HTTPException:
        raise
    except Exception as e:
        # Narrowed from a bare `except:` that printed "Error" and swallowed
        # everything; report the failure like the other endpoints do.
        print(f"Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))