AlserFurma commited on
Commit
bbf823e
·
verified ·
1 Parent(s): 2211843

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -111
app.py CHANGED
@@ -4,82 +4,109 @@ from PIL import Image
4
  import tempfile
5
  from gradio_client import Client, handle_file
6
  import torch
7
- from transformers import VitsModel, AutoTokenizer
8
  import scipy.io.wavfile as wavfile
 
 
 
 
 
 
9
 
10
- # Загрузка обновленной TTS модели при старте
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  print(f"Using device: {device}")
13
 
14
  try:
15
- tts_model = VitsModel.from_pretrained("facebook/mms-tts-rus").to(device)
16
- tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")
17
- print("TTS модель загружена успешно!")
 
 
 
 
 
 
 
 
 
 
18
  except Exception as e:
19
- raise RuntimeError(f"Ошибка загрузки TTS модели: {str(e)}")
 
 
 
 
 
20
 
21
- # Пространство для talking-head
22
  TALKING_HEAD_SPACE = "Skywork/skyreels-a1-talking-head"
23
 
 
 
 
 
 
24
  def inference(image: Image.Image, text: str):
 
25
  error_msg = ""
26
  video_path = None
27
  audio_path = None
28
  img_path = None
29
-
30
  try:
31
- # Валидация входных данных
32
  if image is None:
33
  raise ValueError("Загрузите изображение лектора!")
34
-
35
  if not text or not text.strip():
36
  raise ValueError("Введите текст лекции!")
37
-
38
  if len(text) > 500:
39
- raise ValueError("Текст слишком длинный! Используйте до 500 символов.")
40
-
41
- print(f"Генерация TTS для текста: '{text[:50]}...'")
42
-
43
- # Шаг 1: Генерация аудио через TTS
44
- torch.manual_seed(42)
45
- inputs = tts_tokenizer(text, return_tensors="pt").to(device)
46
-
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  with torch.no_grad():
48
  output = tts_model(**inputs)
49
- waveform = output.waveform.squeeze().cpu().numpy()
50
-
51
- if waveform.size == 0:
52
- raise ValueError("TTS сгенерировал пустое аудио! Попробуйте другой текст.")
53
-
54
- # Конвертация в int16 для WAV
55
  audio = (waveform * 32767).astype("int16")
56
  sampling_rate = tts_model.config.sampling_rate
57
-
58
- # Сохранение аудио
59
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_file:
60
- wavfile.write(audio_file.name, sampling_rate, audio)
61
- audio_path = audio_file.name
62
-
63
- print(f"TTS аудио сохранено: {audio_path} (длина: {len(waveform)/sampling_rate:.1f} сек)")
64
-
65
- # Шаг 2: Сохранение изображения
66
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as img_file:
67
- # Конвертация в RGB если нужно
68
- if image.mode != 'RGB':
69
- image = image.convert('RGB')
70
- image.save(img_file.name, format='PNG')
71
- img_path = img_file.name
72
-
73
- print(f"Изображение сохранено: {img_path}")
74
-
75
- # Шаг 3: Вызов talking-head API
76
- print(f"Подключение к {TALKING_HEAD_SPACE}...")
77
  client = Client(TALKING_HEAD_SPACE)
78
-
79
- # Проверяем доступные API endpoints
80
- print("Доступные API методы:", client.view_api())
81
-
82
- # Вызов API с правильными параметрами
83
  result = client.predict(
84
  image_path=handle_file(img_path),
85
  audio_path=handle_file(audio_path),
@@ -87,87 +114,63 @@ def inference(image: Image.Image, text: str):
87
  steps=10,
88
  api_name="/process_image_audio"
89
  )
90
-
91
- print(f"Результат API: {type(result)}")
92
-
93
- # Обработка результата
94
- if isinstance(result, tuple) and len(result) > 0:
95
- video_data = result[0]
96
- if isinstance(video_data, dict) and 'video' in video_data:
97
- video_path = video_data['video']
98
- elif isinstance(video_data, dict) and 'path' in video_data:
99
- video_path = video_data['path']
100
- elif isinstance(video_data, str):
101
- video_path = video_data
102
- else:
103
- video_path = video_data
104
  else:
105
- video_path = result
106
-
107
- print(f"Видео сгенерировано: {video_path}")
108
- error_msg = "✅ Видео успешно сгенерировано!"
109
-
110
  except Exception as e:
111
  error_msg = f"❌ Ошибка: {str(e)}"
112
- print(f"ОШИБКА: {error_msg}")
113
- import traceback
114
  traceback.print_exc()
115
-
116
  finally:
117
- # Очистка временных файлов
118
- if audio_path and os.path.exists(audio_path):
119
- try:
120
- os.remove(audio_path)
121
- print(f"Удален временный файл: {audio_path}")
122
- except:
123
- pass
124
-
125
- if img_path and os.path.exists(img_path):
126
- try:
127
- os.remove(img_path)
128
- print(f"Удален временный файл: {img_path}")
129
- except:
130
- pass
131
-
132
  return video_path, error_msg
133
 
134
- # Интерфейс Gradio
135
- title = "Видео-лектор с TTS (Русский)"
 
 
 
 
 
136
  description = """
137
- Загрузите фото лектора и введите текст лекции.
138
- Система сгенерирует видео, где лектор "произносит" ваш текст!
139
- **Требования:**
140
- - Фото: фронтальное изображение лица
141
- - Текст: до 500 символов на русском языке
142
- """
143
 
144
- examples = [
145
- [
146
- "example_image.png",
147
- "Добрый день! Сегодня мы рассмотрим основы машинного обучения."
148
- ]
149
- ]
150
 
151
  iface = gr.Interface(
152
  fn=inference,
153
  inputs=[
154
- gr.Image(type="pil", label="📸 Фото лектора"),
155
  gr.Textbox(
156
- lines=5,
157
- placeholder="Введите текст лекции на русском языке (до 500 символов)...",
158
- label="📝 Текст лекции"
159
  )
160
  ],
161
  outputs=[
162
- gr.Video(label="🎬 Готовое видео"),
163
- gr.Textbox(label="ℹ️ Статус", interactive=False)
164
  ],
165
  title=title,
166
  description=description,
167
- flagging_mode="never",
168
- examples=None, # Добавьте примеры, если есть тестовые изображения
169
- cache_examples=False
170
  )
171
 
172
  if __name__ == "__main__":
173
- iface.launch()
 
4
  import tempfile
5
  from gradio_client import Client, handle_file
6
  import torch
7
+ from transformers import VitsModel, AutoTokenizer, pipeline
8
  import scipy.io.wavfile as wavfile
9
+ import traceback
10
+
11
+
12
+ # =========================
13
+ # Загрузка моделей
14
+ # =========================
15
 
 
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
  print(f"Using device: {device}")
18
 
19
  try:
20
+ # TTS модель казахского языка
21
+ tts_model = VitsModel.from_pretrained("facebook/mms-tts-kaz").to(device)
22
+ tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-kaz")
23
+
24
+ # Модель перевода ru -> kk
25
+ translator = pipeline(
26
+ "translation",
27
+ model="facebook/nllb-200-distilled-600M",
28
+ device=0 if device == "cuda" else -1
29
+ )
30
+
31
+ print("✅ Все модели успешно загружены!")
32
+
33
  except Exception as e:
34
+ raise RuntimeError(f"Ошибка загрузки моделей: {str(e)}")
35
+
36
+
37
+ # =========================
38
+ # Talking Head Space
39
+ # =========================
40
 
 
41
  TALKING_HEAD_SPACE = "Skywork/skyreels-a1-talking-head"
42
 
43
+
44
+ # =========================
45
+ # Основная функция
46
+ # =========================
47
+
48
  def inference(image: Image.Image, text: str):
49
+
50
  error_msg = ""
51
  video_path = None
52
  audio_path = None
53
  img_path = None
54
+
55
  try:
56
+ # Проверки
57
  if image is None:
58
  raise ValueError("Загрузите изображение лектора!")
59
+
60
  if not text or not text.strip():
61
  raise ValueError("Введите текст лекции!")
62
+
63
  if len(text) > 500:
64
+ raise ValueError("Текст превышает 500 символов!")
65
+
66
+ print("Ввод (RU):", text)
67
+
68
+ # =========================
69
+ # Шаг 1 — Перевод
70
+ # =========================
71
+ translation = translator(
72
+ text,
73
+ src_lang="rus_Cyrl",
74
+ tgt_lang="kaz_Cyrl"
75
+ )
76
+
77
+ translated_text = translation[0]["translation_text"]
78
+ print("Перевод (KK):", translated_text)
79
+
80
+ # =========================
81
+ # Шаг 2 — Озвучка
82
+ # =========================
83
+ inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)
84
+
85
  with torch.no_grad():
86
  output = tts_model(**inputs)
87
+
88
+ waveform = output.waveform.squeeze().cpu().numpy()
 
 
 
 
89
  audio = (waveform * 32767).astype("int16")
90
  sampling_rate = tts_model.config.sampling_rate
91
+
92
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
93
+ wavfile.write(f.name, sampling_rate, audio)
94
+ audio_path = f.name
95
+
96
+ # =========================
97
+ # Шаг 3 — Сохранение изображения
98
+ # =========================
99
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
100
+ if image.mode != "RGB":
101
+ image = image.convert("RGB")
102
+ image.save(f.name)
103
+ img_path = f.name
104
+
105
+ # =========================
106
+ # Шаг 4 — Генерация видео
107
+ # =========================
 
 
 
108
  client = Client(TALKING_HEAD_SPACE)
109
+
 
 
 
 
110
  result = client.predict(
111
  image_path=handle_file(img_path),
112
  audio_path=handle_file(audio_path),
 
114
  steps=10,
115
  api_name="/process_image_audio"
116
  )
117
+
118
+ if isinstance(result, tuple):
119
+ video_path = result[0]
 
 
 
 
 
 
 
 
 
 
 
120
  else:
121
+ raise ValueError("Видео не получено!")
122
+
123
+ error_msg = "Видео успешно создано!"
124
+
 
125
  except Exception as e:
126
  error_msg = f"❌ Ошибка: {str(e)}"
 
 
127
  traceback.print_exc()
128
+
129
  finally:
130
+ for p in [audio_path, img_path]:
131
+ if p and os.path.exists(p):
132
+ try:
133
+ os.remove(p)
134
+ except:
135
+ pass
136
+
 
 
 
 
 
 
 
 
137
  return video_path, error_msg
138
 
139
+
140
+ # =========================
141
+ # Gradio Интерфейс
142
+ # =========================
143
+
144
+ title = "Бейне Оқытушы"
145
+
146
  description = """
147
+ Суретіңізді жүктеп, дәріс мәтінін орыс тілінде енгізіңіз.
148
+ Жүйе автоматты түрде қазақ тіліне аударады және бейне жасайды!
 
 
 
 
149
 
150
+ **Талаптар:**
151
+ - Фото: бет анық көрінетін
152
+ - Мәтін: орыс тілінде (500 таңбаға дейін)
153
+ """
 
 
154
 
155
  iface = gr.Interface(
156
  fn=inference,
157
  inputs=[
158
+ gr.Image(type="pil", label="📸 Фото дәріскер"),
159
  gr.Textbox(
160
+ lines=5,
161
+ label="📝 Дәріс мәтінірыс тілінде)",
162
+ placeholder="500 таңбаға дейін..."
163
  )
164
  ],
165
  outputs=[
166
+ gr.Video(label="🎬 Дайын бейне"),
167
+ gr.Textbox(label="ℹ️ Мәртебе")
168
  ],
169
  title=title,
170
  description=description,
171
+ cache_examples=False,
172
+ flagging_mode="never"
 
173
  )
174
 
175
  if __name__ == "__main__":
176
+ iface.launch()