# Spaces:
# Runtime error
# Runtime error
| from fastapi import FastAPI | |
| from sqlalchemy import create_engine | |
| from sqlalchemy.orm import sessionmaker | |
| from s2smodels import Base, Audio_segment, AudioGeneration | |
| from pydub import AudioSegment | |
| import os | |
| from fastapi import FastAPI, Response | |
| import torch | |
| from fastapi.responses import JSONResponse | |
| from utils.prompt_making import make_prompt | |
| from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline | |
| from utils.generation import SAMPLE_RATE, generate_audio, preload_models | |
| from io import BytesIO | |
| from pyannote.audio import Pipeline | |
| import soundfile as sf | |
| from fastapi_cors import CORS | |
# SQLite database holding intermediate audio segments and generated audio.
DATABASE_URL = "sqlite:///./sql_app.db"
engine = create_engine(DATABASE_URL)
# Session factory bound to the engine; each request-level helper opens/closes its own session.
Session = sessionmaker(bind=engine)
app = FastAPI()
"""
origins = ["*"]
app.add_middleware(
    CORS,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
"""
# Create the tables declared on the shared declarative Base (from s2smodels).
Base.metadata.create_all(engine)
def root():
    """Fallback handler: report that no result is available."""
    payload = {"message": "No result"}
    return payload
# add audio segments in Audio_segment table
def create_segment(start_time: float, end_time: float, audio: AudioSegment, type: str):
    """Persist one audio segment (speech or non-speech) in the Audio_segment table.

    Args:
        start_time: segment start, in seconds.
        end_time: segment end, in seconds.
        audio: the pydub clip to store (exported as WAV bytes).
        type: "speech" or "non-speech" (name kept for API compatibility,
            even though it shadows the builtin).

    Returns:
        A small status dict on success.
    """
    session = Session()
    try:
        buffer = BytesIO()
        audio.export(buffer, format='wav')
        segment = Audio_segment(
            start_time=start_time,
            end_time=end_time,
            type=type,
            audio=buffer.getvalue(),
        )
        session.add(segment)
        session.commit()
    except Exception:
        # BUG FIX: the session was previously leaked when commit/export raised.
        session.rollback()
        raise
    finally:
        session.close()
    return {"status_code": 200, "message": "success"}
# add target audio to AudioGeneration table
def generate_target(audio: AudioSegment):
    """Store the final reconstructed audio in the AudioGeneration table.

    Args:
        audio: the complete pydub clip to persist (exported as WAV bytes).

    Returns:
        A small status dict on success.
    """
    session = Session()
    try:
        buffer = BytesIO()
        audio.export(buffer, format='wav')
        target_audio = AudioGeneration(audio=buffer.getvalue())
        session.add(target_audio)
        session.commit()
    except Exception:
        # BUG FIX: the session was previously leaked when commit/export raised.
        session.rollback()
        raise
    finally:
        session.close()
    return {"status_code": 200, "message": "success"}
| """ | |
| audio segmentation into speech and non-speech using segmentation model | |
| """ | |
| def audio_speech_nonspeech_detection(audio_url): | |
| pipeline = Pipeline.from_pretrained( | |
| "pyannote/speaker-diarization-3.0" | |
| ) | |
| diarization = pipeline(audio_url) | |
| speaker_regions=[] | |
| for turn, _,speaker in diarization.itertracks(yield_label=True): | |
| speaker_regions.append({"start":turn.start,"end":turn.end}) | |
| sound = AudioSegment.from_wav(audio_url) | |
| speaker_regions.sort(key=lambda x: x['start']) | |
| non_speech_regions = [] | |
| for i in range(1, len(speaker_regions)): | |
| start = speaker_regions[i-1]['end'] | |
| end = speaker_regions[i]['start'] | |
| if end > start: | |
| non_speech_regions.append({'start': start, 'end': end}) | |
| first_speech_start = speaker_regions[0]['start'] | |
| if first_speech_start > 0: | |
| non_speech_regions.insert(0, {'start': 0, 'end': first_speech_start}) | |
| last_speech_end = speaker_regions[-1]['end'] | |
| total_audio_duration = len(sound) | |
| if last_speech_end < total_audio_duration: | |
| non_speech_regions.append({'start': last_speech_end, 'end': total_audio_duration}) | |
| return speaker_regions,non_speech_regions | |
| """ | |
| save speech and non-speech segments in audio_segment table | |
| """ | |
| def split_audio_segments(audio_url): | |
| sound = AudioSegment.from_wav(audio_url) | |
| speech_segments, non_speech_segment = audio_speech_nonspeech_detection(audio_url) | |
| # Process speech segments | |
| for i, speech_segment in enumerate(speech_segments): | |
| start = int(speech_segment['start'] * 1000) | |
| end = int(speech_segment['end'] * 1000) | |
| segment = sound[start:end] | |
| create_segment(start_time=start/1000, | |
| end_time=end/1000, | |
| type="speech",audio=segment) | |
| # Process non-speech segments | |
| for i, non_speech_segment in enumerate(non_speech_segment): | |
| start = int(non_speech_segment['start'] * 1000) | |
| end = int(non_speech_segment['end'] * 1000) | |
| segment = sound[start:end] | |
| create_segment(start_time=start/1000, | |
| end_time=end/1000, | |
| type="non-speech",audio=segment) | |
#@app.post("/translate_en_ar/")
def en_text_to_ar_text_translation(text):
    """Translate English text to Egyptian Arabic with NLLB-200.

    BUG FIX: NLLB-200 expects FLORES-200 language codes, not plain-English
    names — 'English'/'Egyptain Arabic' are not valid codes, so the tokenizer
    could not select the right languages. Use eng_Latn / arz_Arab instead.

    Args:
        text: English source text.

    Returns:
        The Egyptian Arabic translation string.
    """
    pipe = pipeline("translation", model="facebook/nllb-200-distilled-600M")
    result = pipe(text, src_lang="eng_Latn", tgt_lang="arz_Arab")
    return result[0]['translation_text']
def make_prompt_audio(name, audio_path):
    # Register a VALL-E-X voice prompt under `name` so generate_audio()
    # can later clone this speaker's voice (see utils.prompt_making).
    make_prompt(name=name, audio_prompt_path=audio_path)
# whisper model for speech to text process (english language)
#@app.post("/en_speech_ar_text/")
def en_speech_to_en_text_process(segment):
    """Transcribe English speech to English text with Whisper large-v3.

    PERF FIX: the model and pipeline are built once and cached on the
    function object; previously every call re-downloaded/reloaded ~3 GB of
    weights (this runs once per speech segment).

    Args:
        segment: audio input accepted by the ASR pipeline (e.g. WAV bytes).

    Returns:
        The transcription text.
    """
    pipe = getattr(en_speech_to_en_text_process, "_pipe", None)
    if pipe is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        model_id = "openai/whisper-large-v3"
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
        model.to(device)
        processor = AutoProcessor.from_pretrained(model_id)
        pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            max_new_tokens=128,
            chunk_length_s=30,
            batch_size=16,
            return_timestamps=True,
            torch_dtype=torch_dtype,
            device=device,
        )
        en_speech_to_en_text_process._pipe = pipe
    result = pipe(segment)
    return result["text"]
#text to speech using VALL-E-X model
#@app.post("/text_to_speech/")
def text_to_speech(segment_id, target_text, audio_prompt):
    """Synthesize `target_text` in the segment speaker's voice (VALL-E-X).

    Overwrites the stored audio of the Audio_segment row `segment_id` with
    the generated waveform.

    Args:
        segment_id: primary key of the Audio_segment row to update.
        target_text: text to synthesize.
        audio_prompt: path to the audio used as the voice prompt.

    Raises:
        ValueError: if the segment id does not exist.
    """
    preload_models()
    session = Session()
    try:
        segment = session.query(Audio_segment).get(segment_id)
        if segment is None:
            # BUG FIX: a missing id previously crashed with AttributeError
            # on `segment.audio`.
            raise ValueError(f"Audio_segment {segment_id} not found")
        make_prompt_audio(name=f"audio_{segment_id}", audio_path=audio_prompt)
        audio_array = generate_audio(target_text, f"audio_{segment_id}")
        buffer = BytesIO()
        sf.write(buffer, audio_array, SAMPLE_RATE, format='wav')
        segment.audio = buffer.getvalue()
        session.commit()
    except Exception:
        # BUG FIX: the session was previously leaked on any failure.
        session.rollback()
        raise
    finally:
        session.close()
| """ | |
| reconstruct target audio using all updated segment | |
| in audio_segment table and then remove all audio_Segment records | |
| """ | |
| def construct_audio(): | |
| session = Session() | |
| # Should be ordered by start_time | |
| segments = session.query(Audio_segment).order_by('start_time').all() | |
| audio_files = [] | |
| for segment in segments: | |
| audio_files.append(AudioSegment.from_file(BytesIO(segment.audio), format='wav')) | |
| target_audio = sum(audio_files, AudioSegment.empty()) | |
| generate_target(audio=target_audio) | |
| # Delete all records in Audio_segment table | |
| session.query(Audio_segment).delete() | |
| session.commit() | |
| session.close() | |
| """ | |
| source => english speech | |
| target => arabic speeech | |
| """ | |
#@app.post("/en_speech_ar_speech/")
def speech_to_speech_translation_en_ar(audio_url):
    """English speech -> Egyptian Arabic speech, segment by segment.

    Splits the input, transcribes each speech segment, translates it, and
    re-synthesizes it in the original speaker's voice, then reconstructs
    the full audio via construct_audio().

    Args:
        audio_url: path to the source WAV file.
    """
    session = Session()
    try:
        split_audio_segments(audio_url)
        # filtering by type
        speech_segments = session.query(Audio_segment).filter(Audio_segment.type == "speech").all()
        for segment in speech_segments:
            # BUG FIX: target_text is reset per segment; previously a failed
            # transcription silently reused the PREVIOUS segment's translation.
            target_text = None
            text = en_speech_to_en_text_process(segment.audio)
            if text:
                target_text = en_text_to_ar_text_translation(text)
            else:
                print("speech_to_text_process function not return result. ")
            if target_text is None:
                print("Target text is None.")
                continue
            segment_duration = segment.end_time - segment.start_time
            if segment_duration <= 15:
                text_to_speech(segment.id, target_text, segment.audio)
            else:
                # VALL-E-X prompts are capped at ~15 s; use a truncated clip.
                prompt_path = extract_15_seconds(segment.audio, segment.start_time, segment.end_time)
                text_to_speech(segment.id, target_text, prompt_path)
                os.remove(prompt_path)
    finally:
        # BUG FIX: the session was previously never closed.
        session.close()
    construct_audio()
    # BUG FIX: corrected the "succcessfully" typo in the response payload.
    return JSONResponse(status_code=200, content={"status_code": "successfully"})
async def get_ar_audio(audio_url):
    """Run the EN->AR pipeline and return the generated audio as a WAV response.

    Args:
        audio_url: path to the source English WAV file.

    Raises:
        ValueError: if no generated audio exists in the database.
    """
    speech_to_speech_translation_en_ar(audio_url)
    session = Session()
    try:
        # Get the most recent target audio from AudioGeneration.
        target_audio = session.query(AudioGeneration).order_by(AudioGeneration.id.desc()).first()
        # NOTE(review): generated rows are never deleted here (cleanup was
        # commented out in the original), so the table grows per request.
        if target_audio is None:
            raise ValueError("No audio found in the database")
        audio_bytes = target_audio.audio
    finally:
        # BUG FIX: the session was previously never closed.
        session.close()
    return Response(content=audio_bytes, media_type="audio/wav")
# speech to speech from arabic to english processes
#@app.post("/ar_speech_to_en_text/")
def ar_speech_to_ar_text_process(segment):
    """Transcribe Arabic speech to Arabic text with Whisper large-v3.

    PERF FIX: the model and pipeline are built once and cached on the
    function object; previously every call re-downloaded/reloaded ~3 GB of
    weights (this runs once per speech segment).

    Args:
        segment: audio input accepted by the ASR pipeline (e.g. WAV bytes).

    Returns:
        The Arabic transcription text.
    """
    pipe = getattr(ar_speech_to_ar_text_process, "_pipe", None)
    if pipe is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        model_id = "openai/whisper-large-v3"
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
        model.to(device)
        processor = AutoProcessor.from_pretrained(model_id)
        pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            max_new_tokens=128,
            chunk_length_s=30,
            batch_size=16,
            return_timestamps=True,
            torch_dtype=torch_dtype,
            device=device,
        )
        ar_speech_to_ar_text_process._pipe = pipe
    result = pipe(segment, generate_kwargs={"language": "arabic"})
    return result["text"]
#@app.post("/ar_translate/")
def ar_text_to_en_text_translation(text):
    """Translate Egyptian Arabic text to English with NLLB-200.

    BUG FIX: NLLB-200 expects FLORES-200 language codes, not plain-English
    names — 'Egyptain Arabic'/'English' are not valid codes, so the tokenizer
    could not select the right languages. Use arz_Arab / eng_Latn instead.

    Args:
        text: Egyptian Arabic source text.

    Returns:
        The English translation string.
    """
    pipe = pipeline("translation", model="facebook/nllb-200-distilled-600M")
    result = pipe(text, src_lang="arz_Arab", tgt_lang="eng_Latn")
    return result[0]['translation_text']
| """ | |
| source => arabic speech | |
| target => english speeech | |
| """ | |
def speech_to_speech_translation_ar_en(audio_url):
    """Arabic speech -> English speech, segment by segment.

    Splits the input, transcribes each speech segment, translates it, and
    re-synthesizes it in the original speaker's voice, then reconstructs
    the full audio via construct_audio().

    Args:
        audio_url: path to the source WAV file.
    """
    session = Session()
    try:
        split_audio_segments(audio_url)
        # filtering by type
        speech_segments = session.query(Audio_segment).filter(Audio_segment.type == "speech").all()
        for segment in speech_segments:
            # BUG FIX: target_text is reset per segment; previously a failed
            # transcription silently reused the PREVIOUS segment's translation.
            target_text = None
            text = ar_speech_to_ar_text_process(segment.audio)
            if text:
                target_text = ar_text_to_en_text_translation(text)
            else:
                print("speech_to_text_process function not return result. ")
            if target_text is None:
                print("Target text is None.")
                continue
            segment_duration = segment.end_time - segment.start_time
            if segment_duration <= 15:
                text_to_speech(segment.id, target_text, segment.audio)
            else:
                # VALL-E-X prompts are capped at ~15 s; use a truncated clip.
                prompt_path = extract_15_seconds(segment.audio, segment.start_time, segment.end_time)
                text_to_speech(segment.id, target_text, prompt_path)
                os.remove(prompt_path)
    finally:
        # BUG FIX: the session was previously never closed.
        session.close()
    construct_audio()
    # BUG FIX: corrected the "succcessfully" typo in the response payload.
    return JSONResponse(status_code=200, content={"status_code": "successfully"})
async def get_en_audio(audio_url):
    """Run the AR->EN pipeline and return the generated audio as a WAV response.

    Args:
        audio_url: path to the source Arabic WAV file.

    Raises:
        ValueError: if no generated audio exists in the database.
    """
    speech_to_speech_translation_ar_en(audio_url)
    session = Session()
    try:
        # Get the most recent target audio from AudioGeneration.
        target_audio = session.query(AudioGeneration).order_by(AudioGeneration.id.desc()).first()
        # NOTE(review): generated rows are never deleted here (cleanup was
        # commented out in the original), so the table grows per request.
        if target_audio is None:
            raise ValueError("No audio found in the database")
        audio_bytes = target_audio.audio
    finally:
        # BUG FIX: the session was previously never closed.
        session.close()
    return Response(content=audio_bytes, media_type="audio/wav")
def get_all_audio_segments():
    """Dump every stored segment to the segments/ directory as WAV files
    and return their metadata.

    Returns:
        {"segments": [{"id", "start_time", "end_time", "type", "audio_url"}, ...]}

    Raises:
        ValueError: if a stored segment has no audio bytes.
    """
    session = Session()
    try:
        segments = session.query(Audio_segment).all()
        # BUG FIX: the output directory may not exist; create it up front.
        os.makedirs("segments", exist_ok=True)
        segment_dicts = []
        for segment in segments:
            if segment.audio is None:
                raise ValueError("No audio found in the database")
            file_path = os.path.join("segments", f"segment{segment.id}_audio.wav")
            with open(file_path, "wb") as file:
                file.write(segment.audio)
            segment_dicts.append({
                "id": segment.id,
                "start_time": segment.start_time,
                "end_time": segment.end_time,
                "type": segment.type,
                "audio_url": file_path,
            })
    finally:
        # BUG FIX: a mid-loop ValueError previously leaked the session.
        session.close()
    return {"segments": segment_dicts}
def extract_15_seconds(audio_data, start_time, end_time):
    """Write at most the first 15 seconds of a segment's audio to temp.wav.

    Args:
        audio_data: WAV bytes of an ALREADY-CUT segment, so its own timeline
            starts at 0 (create_segment stores sound[start:end]).
        start_time: segment's absolute start in the source recording (seconds).
        end_time: segment's absolute end in the source recording (seconds).

    Returns:
        Path to the temporary WAV file (caller removes it).
    """
    audio_segment = AudioSegment.from_file(BytesIO(audio_data), format='wav')
    # BUG FIX: the old code sliced [start_time*1000 : end_ms] into a clip
    # whose own timeline starts at 0, yielding an empty clip for any segment
    # not at the very beginning of the recording. start/end now only bound
    # the duration; the slice is taken from the clip's start.
    duration_s = min(15.0, end_time - start_time)
    extracted_segment = audio_segment[0:int(duration_s * 1000)]
    temp_wav_path = "temp.wav"
    extracted_segment.export(temp_wav_path, format="wav")
    return temp_wav_path