Mohamed7733 committed on
Commit
ae52466
·
verified ·
1 Parent(s): 77480d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -40
app.py CHANGED
@@ -1,53 +1,125 @@
1
- from fastapi import FastAPI
2
- from pydantic import BaseModel
3
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
 
 
4
  import torch
5
- import numpy as np
6
- import scipy.io.wavfile as wav
7
-
8
# Hugging Face model id to load; replace it with a valid TTS model.
# NOTE(review): AutoTokenizer / AutoModelForSeq2SeqLM are text-to-text
# classes -- confirm they can actually load this ESPnet TTS checkpoint.
model_name = "espnet/kan-bayashi_ljspeech_tts"

# Both objects are loaded once at import time and shared by every request.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
12
 
13
  # Initialize FastAPI app
14
  app = FastAPI()
15
 
16
- # Function to convert text to speech
17
def text_to_speech(text: str):
    """Generate audio for *text* and save it as a 16 kHz, 16-bit PCM WAV file.

    Args:
        text: The text to synthesize.

    Returns:
        The path of the written WAV file (``/tmp/output.wav``).
    """
    # Convert text to model format
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

    # Generate the speech; no gradients are needed for inference
    with torch.no_grad():
        output = model.generate(**inputs)

    # NOTE(review): generate() on a seq2seq LM normally yields token ids, not
    # audio samples -- confirm the loaded model really returns a waveform.
    # atleast_1d keeps a single-sample result from collapsing to a 0-d array.
    waveform = np.atleast_1d(output.numpy().squeeze()).astype(np.float64)

    # Normalize to [-1, 1]. Skip it for empty or all-zero (silent) output:
    # np.max raises on an empty array, and dividing by a zero peak would fill
    # the signal with NaNs that turn into garbage when cast to int16.
    peak = np.max(np.abs(waveform)) if waveform.size else 0.0
    if peak > 0:
        waveform = waveform / peak

    # NOTE(review): this fixed path is shared by all requests, so concurrent
    # calls can overwrite each other's output.
    file_path = "/tmp/output.wav"
    wav.write(file_path, 16000, (waveform * 32767).astype(np.int16))  # 16-bit PCM

    return file_path
 
 
 
36
 
37
- # Define request model
38
class TextRequest(BaseModel):
    # Request body of the text-to-speech endpoint: the text to synthesize.
    text: str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- # API endpoint to convert text to speech
42
@app.post("/text_to_speech/")
async def convert_text_to_speech(request: TextRequest):
    """Synthesize the request text and return the WAV audio in a JSON body.

    Args:
        request: Request body carrying the text to synthesize.

    Returns:
        ``{"audio": <base64 string>}`` where the string is the base64-encoded
        content of the generated WAV file.
    """
    import base64

    # Generate speech from text
    file_path = text_to_speech(request.text)

    with open(file_path, "rb") as f:
        audio_data = f.read()

    # Raw bytes cannot be placed in a JSON response: binary WAV data is not
    # valid UTF-8, so encoding the response fails. Base64 keeps the same
    # {"audio": ...} shape while making the payload JSON-safe.
    return {"audio": base64.b64encode(audio_data).decode("ascii")}
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import subprocess
4
+ from fastapi import FastAPI, UploadFile, File
5
+ import whisper
6
+ from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
7
  import torch
8
+ from datetime import timedelta
9
+ from deep_translator import GoogleTranslator
10
+ import ffmpeg
 
 
 
 
11
 
12
  # Initialize FastAPI app
13
  app = FastAPI()
14
 
15
def format_time(seconds):
    """Format an offset in seconds as an SRT timestamp (``HH:MM:SS,mmm``).

    Args:
        seconds: Offset from the start of the media, in seconds (int or
            float). Negative values are clamped to zero because SRT has no
            negative timestamps.

    Returns:
        The timestamp string, e.g. ``"01:01:01,500"``. Hours are not wrapped
        at 24, so an offset of 90000 s yields ``"25:00:00,000"``.
    """
    # Work in whole milliseconds. Dividing one timedelta by another gives an
    # int and, unlike ``td.seconds``, includes the ``days`` component.
    total_ms = max(timedelta(seconds=seconds) // timedelta(milliseconds=1), 0)
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, milliseconds = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{milliseconds:03d}"
 
 
 
 
22
 
23
def extract_audio(video_path):
    """Extract the audio track of a video into a temporary WAV file.

    Args:
        video_path: Path of the input video file.

    Returns:
        Path of the newly created WAV file in the system temp directory.
        The caller is responsible for deleting it.

    Raises:
        ffmpeg.Error: If the ffmpeg process exits with a non-zero status.
    """
    # Reserve a unique name per call. A fixed "extracted_audio.wav" is shared
    # by concurrent requests and already exists on every call after the first.
    fd, audio_path = tempfile.mkstemp(prefix="extracted_audio_", suffix=".wav")
    os.close(fd)  # ffmpeg writes the file itself; only the name is needed

    try:
        # mkstemp has already created the file, so ffmpeg must be allowed to
        # overwrite it (the equivalent of passing -y on the command line).
        ffmpeg.input(video_path).output(audio_path, format='wav').run(
            overwrite_output=True
        )
    except Exception:
        # Do not leave an empty placeholder behind when extraction fails.
        if os.path.exists(audio_path):
            os.remove(audio_path)
        raise

    return audio_path
32
 
33
# Whisper models that have already been loaded, keyed by model name.
_WHISPER_MODEL_CACHE = {}


def _get_whisper_model(name="base"):
    """Return the Whisper model called *name*, loading it on first use only."""
    if name not in _WHISPER_MODEL_CACHE:
        _WHISPER_MODEL_CACHE[name] = whisper.load_model(name)
    return _WHISPER_MODEL_CACHE[name]


def transcribe_audio(audio_path):
    """Transcribe an audio file with the Whisper ``base`` model.

    Args:
        audio_path: Path of the audio file to transcribe.

    Returns:
        The ``"segments"`` entry of the Whisper result, or an empty list when
        loading the model or transcribing fails.
    """
    try:
        # Loading the model is expensive, so it is cached across calls
        # instead of being reloaded for every request.
        model = _get_whisper_model("base")
        result = model.transcribe(audio_path)
        return result["segments"]
    except Exception as e:
        # Best-effort by design: report the problem and let the caller
        # continue with no segments rather than failing the whole request.
        print(f"Error using whisper model: {e}")
        return []
43
 
44
def translate_text(text, source='en', target='ar'):
    """Translate *text* with Google Translate (English to Arabic by default).

    Args:
        text: The text to translate.
        source: Source language code; defaults to ``'en'``.
        target: Target language code; defaults to ``'ar'``.

    Returns:
        The translated text. Empty, whitespace-only, or non-string input
        returns ``""`` without calling the translation service.
    """
    stripped = text.strip() if isinstance(text, str) else ""
    if not stripped:
        # Nothing to translate; skipping the call avoids a pointless network
        # request and any empty-payload error from the translator.
        return ""

    translated = GoogleTranslator(source=source, target=target).translate(stripped)
    # Defensive: fall back to the source text so callers never receive None.
    return translated if translated is not None else stripped
48
 
49
def create_srt(segments, output_path):
    """Write *segments* to *output_path* as an SRT subtitle file.

    Each segment is either a mapping with ``start``/``end``/``text`` and an
    optional ``translation`` key, or an object exposing the same names as
    attributes. The translation is used as the cue text; when it is missing
    or empty the original text is written instead, for both segment kinds.

    Args:
        segments: Iterable of segments, in display order.
        output_path: Destination path of the SRT file.
    """
    # UTF-8 with BOM for compatibility with players that sniff the encoding.
    with open(output_path, 'w', encoding='utf-8-sig') as srt_file:
        for index, segment in enumerate(segments, start=1):
            if hasattr(segment, 'get'):  # mapping-style segment
                start_time = segment.get('start', 0)
                end_time = segment.get('end', 0)
                text = segment.get('text', '')
                translation = segment.get('translation')
            else:  # attribute-style segment
                start_time = segment.start
                end_time = segment.end
                text = segment.text
                translation = getattr(segment, 'translation', None)

            # Same fallback for both kinds. Strip so leading or trailing
            # whitespace cannot add blank lines, which would end the cue early.
            cue_text = (translation or text or '').strip()

            srt_file.write(
                f"{index}\n"
                f"{format_time(start_time)} --> {format_time(end_time)}\n"
                f"{cue_text}\n\n"
            )
68
 
69
def _escape_filter_value(value):
    """Escape *value* for use as an option value in an ffmpeg filtergraph."""
    # First level: characters special inside a filter option value.
    for char in ("\\", "'", ":"):
        value = value.replace(char, "\\" + char)
    # Second level: characters special in the filtergraph description.
    for char in ("\\", "'", "[", "]", ",", ";"):
        value = value.replace(char, "\\" + char)
    return value


def burn_subtitles(video_path, srt_path, output_path):
    """Burn an SRT subtitle file into a video with ffmpeg (Arabic-capable font).

    Args:
        video_path: Path of the input video.
        srt_path: Path of the UTF-8 SRT file to render.
        output_path: Path of the re-encoded output video.

    Returns:
        *output_path* on success, or ``None`` when ffmpeg fails or cannot be
        started.
    """
    font_path = "/usr/share/fonts/truetype/Amiri-Regular.ttf"  # Path to Amiri font
    # force_style's FontName takes a font family name, not a file path, so the
    # font's directory is exposed through ``fontsdir`` and the family is named.
    fonts_dir = os.path.dirname(font_path)
    style = (
        "FontName=Amiri,FontSize=24,PrimaryColour=&HFFFFFF,"
        "OutlineColour=&H000000,BorderStyle=3,Alignment=2,Encoding=1"
    )
    subtitle_filter = (
        f"subtitles={_escape_filter_value(srt_path)}"
        f":fontsdir={_escape_filter_value(fonts_dir)}"
        f":force_style='{style}'"
    )

    # No ``-sub_charenc`` here: it is a subtitle *decoder* option and is not
    # valid as an output option. The subtitles filter reads UTF-8 by default.
    cmd = [
        'ffmpeg', '-y',
        '-i', video_path,
        '-vf', subtitle_filter,
        '-c:v', 'libx264', '-crf', '18',
        '-c:a', 'copy',
        output_path,
    ]

    try:
        subprocess.run(cmd, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing or non-executable ffmpeg binary, so both
        # failure modes report the same way instead of one of them raising.
        print(f"FFmpeg error: {e}")
        return None
    return output_path
89
 
90
def process_video(video_path):
    """Run the subtitle pipeline on a video.

    Steps: extract the audio, transcribe it, translate every segment, write
    the SRT file, then burn the subtitles into a copy of the video.

    Args:
        video_path: Path of the input video.

    Returns:
        A ``(result_path, srt_path)`` tuple. ``result_path`` is whatever
        ``burn_subtitles`` returned (``None`` when burning failed);
        ``srt_path`` is the generated subtitle file.
    """
    temp_dir = tempfile.gettempdir()
    file_name = os.path.splitext(os.path.basename(video_path))[0]

    audio_path = extract_audio(video_path)
    try:
        segments = transcribe_audio(audio_path)
    finally:
        # The extracted audio is only needed for transcription.
        if os.path.exists(audio_path):
            os.remove(audio_path)

    # Build fresh dicts instead of setting ``segment.translation``: Whisper
    # returns plain dicts, which do not accept attribute assignment.
    translated_segments = []
    for segment in segments:
        if hasattr(segment, 'get'):  # mapping-style segment
            start = segment.get('start', 0)
            end = segment.get('end', 0)
            text = segment.get('text', '')
        else:  # attribute-style segment
            start, end, text = segment.start, segment.end, segment.text
        translated_segments.append({
            'start': start,
            'end': end,
            'text': text,
            'translation': translate_text(text),
        })

    srt_path = os.path.join(temp_dir, f"{file_name}.srt")
    create_srt(translated_segments, srt_path)

    output_path = os.path.join(temp_dir, f"{file_name}_translated.mp4")
    result_path = burn_subtitles(video_path, srt_path, output_path)

    return result_path, srt_path
112
 
113
+ # API endpoint to process video
114
@app.post("/process_video/")
async def process_video_endpoint(file: UploadFile = File(...)):
    """Accept a video upload and produce a subtitled copy of it.

    Args:
        file: The uploaded video.

    Returns:
        ``{"video_url": ..., "srt_url": ...}`` holding the server-side paths
        returned by ``process_video`` (``video_url`` is ``None`` when burning
        the subtitles failed).
    """
    import re
    import uuid

    from fastapi.concurrency import run_in_threadpool

    # The filename is client-controlled: keep only its last path component and
    # a conservative character set so it cannot escape the temp directory or
    # break the ffmpeg filter string built from it later. The random prefix
    # keeps concurrent uploads with the same name apart.
    base_name = os.path.basename((file.filename or "").replace("\\", "/"))
    safe_name = re.sub(r"[^A-Za-z0-9._-]", "_", base_name) or "upload"
    file_path = os.path.join(
        tempfile.gettempdir(), f"{uuid.uuid4().hex}_{safe_name}"
    )

    try:
        # Copy in chunks so a large video is never held in memory at once.
        with open(file_path, "wb") as f:
            while True:
                chunk = await file.read(1024 * 1024)
                if not chunk:
                    break
                f.write(chunk)

        # The pipeline is blocking (ffmpeg, Whisper, HTTP translation); run it
        # in a worker thread so the event loop keeps serving other requests.
        result_path, srt_path = await run_in_threadpool(process_video, file_path)
    finally:
        # The raw upload is no longer needed once processing has finished.
        if os.path.exists(file_path):
            os.remove(file_path)

    return {"video_url": result_path, "srt_url": srt_path}