AlserFurma commited on
Commit
af30315
·
verified ·
1 Parent(s): 09477a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -47
app.py CHANGED
@@ -6,15 +6,17 @@ from gradio_client import Client, handle_file
6
  import torch
7
  from transformers import VitsModel, AutoTokenizer
8
  import scipy.io.wavfile as wavfile
 
9
 
10
- # Загружаем TTS модель при старте
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  print(f"Using device: {device}")
13
 
14
  try:
15
- tts_model = VitsModel.from_pretrained("facebook/mms-tts-rus").to(device)
16
- tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")
17
- print("TTS модель загружена успешно!")
 
18
  except Exception as e:
19
  raise RuntimeError(f"Ошибка загрузки TTS модели: {str(e)}")
20
 
@@ -26,42 +28,31 @@ def inference(image: Image.Image, text: str):
26
  video_path = None
27
  audio_path = None
28
  img_path = None
29
-
30
  try:
31
  # Валидация входных данных
32
  if image is None:
33
  raise ValueError("Загрузите изображение лектора!")
34
-
35
  if not text or not text.strip():
36
  raise ValueError("Введите текст лекции!")
37
-
38
  if len(text) > 500:
39
  raise ValueError("Текст слишком длинный! Используйте до 500 символов.")
40
-
41
  print(f"Генерация TTS для текста: '{text[:50]}...'")
42
-
43
  # Шаг 1: Генерация аудио через TTS
44
  torch.manual_seed(42)
45
  inputs = tts_tokenizer(text, return_tensors="pt").to(device)
46
-
47
  with torch.no_grad():
48
  output = tts_model(**inputs)
49
  waveform = output.waveform.squeeze().cpu().numpy()
50
-
51
  if waveform.size == 0:
52
  raise ValueError("TTS сгенерировал пустое аудио! Попробуйте другой текст.")
53
-
54
  # Конвертация в int16 для WAV
55
  audio = (waveform * 32767).astype("int16")
56
  sampling_rate = tts_model.config.sampling_rate
57
-
58
  # Сохранение аудио
59
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_file:
60
  wavfile.write(audio_file.name, sampling_rate, audio)
61
  audio_path = audio_file.name
62
-
63
  print(f"TTS аудио сохранено: {audio_path} (длина: {len(waveform)/sampling_rate:.1f} сек)")
64
-
65
  # Шаг 2: Сохранение изображения
66
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as img_file:
67
  # Конвертация в RGB если нужно
@@ -69,27 +60,21 @@ def inference(image: Image.Image, text: str):
69
  image = image.convert('RGB')
70
  image.save(img_file.name, format='PNG')
71
  img_path = img_file.name
72
-
73
  print(f"Изображение сохранено: {img_path}")
74
-
75
  # Шаг 3: Вызов talking-head API
76
  print(f"Подключение к {TALKING_HEAD_SPACE}...")
77
  client = Client(TALKING_HEAD_SPACE)
78
-
79
  # Проверяем доступные API endpoints
80
  print("Доступные API методы:", client.view_api())
81
-
82
  # Вызов API с правильными параметрами
83
  result = client.predict(
84
- image=handle_file(img_path),
85
- audio=handle_file(audio_path),
86
  guidance_scale=3.0,
87
- inference_steps=10,
88
  api_name="/process_image_audio"
89
  )
90
-
91
  print(f"Результат API: {type(result)}")
92
-
93
  # Обработка результата
94
  if isinstance(result, tuple) and len(result) > 0:
95
  video_data = result[0]
@@ -103,16 +88,12 @@ def inference(image: Image.Image, text: str):
103
  video_path = video_data
104
  else:
105
  video_path = result
106
-
107
  print(f"Видео сгенерировано: {video_path}")
108
  error_msg = "✅ Видео успешно сгенерировано!"
109
-
110
  except Exception as e:
111
  error_msg = f"❌ Ошибка: {str(e)}"
112
  print(f"ОШИБКА: {error_msg}")
113
- import traceback
114
  traceback.print_exc()
115
-
116
  finally:
117
  # Очистка временных файлов
118
  if audio_path and os.path.exists(audio_path):
@@ -121,41 +102,28 @@ def inference(image: Image.Image, text: str):
121
  print(f"Удален временный файл: {audio_path}")
122
  except:
123
  pass
124
-
125
  if img_path and os.path.exists(img_path):
126
  try:
127
  os.remove(img_path)
128
  print(f"Удален временный файл: {img_path}")
129
  except:
130
  pass
131
-
132
  return video_path, error_msg
133
 
134
  # Интерфейс Gradio
135
  title = "🎓 Видео-лектор с TTS (Русский)"
136
- description = """
137
- Загрузите фото лектора и введите текст лекции.
138
- Система сгенерирует видео, где лектор "произносит" ваш текст!
139
-
140
  **Требования:**
141
  - Фото: фронтальное изображение лица
142
- - Текст: до 500 символов на русском языке
143
- """
144
-
145
- examples = [
146
- [
147
- "example_image.png",
148
- "Добрый день! Сегодня мы рассмотрим основы машинного обучения."
149
- ]
150
- ]
151
 
152
  iface = gr.Interface(
153
  fn=inference,
154
  inputs=[
155
  gr.Image(type="pil", label="📸 Фото лектора"),
156
  gr.Textbox(
157
- lines=5,
158
- placeholder="Введите текст лекции на русском языке (до 500 символов)...",
159
  label="📝 Текст лекции"
160
  )
161
  ],
@@ -165,8 +133,6 @@ iface = gr.Interface(
165
  ],
166
  title=title,
167
  description=description,
168
- flagging_mode="never",
169
- examples=None, # Добавьте примеры, если есть тестовые изображения
170
  cache_examples=False
171
  )
172
 
 
6
  import torch
7
  from transformers import VitsModel, AutoTokenizer
8
  import scipy.io.wavfile as wavfile
9
+ import traceback
10
 
11
+ # Загрузка моделей при старте
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
  print(f"Using device: {device}")
14
 
15
  try:
16
+ # TTS модель для казахского (исправлено с rus на kaz)
17
+ tts_model = VitsModel.from_pretrained("facebook/mms-tts-kaz").to(device)
18
+ tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-kaz")
19
+ print("TTS модель (kaz) загружена успешно!")
20
  except Exception as e:
21
  raise RuntimeError(f"Ошибка загрузки TTS модели: {str(e)}")
22
 
 
28
  video_path = None
29
  audio_path = None
30
  img_path = None
 
31
  try:
32
  # Валидация входных данных
33
  if image is None:
34
  raise ValueError("Загрузите изображение лектора!")
 
35
  if not text or not text.strip():
36
  raise ValueError("Введите текст лекции!")
 
37
  if len(text) > 500:
38
  raise ValueError("Текст слишком длинный! Используйте до 500 символов.")
 
39
  print(f"Генерация TTS для текста: '{text[:50]}...'")
 
40
  # Шаг 1: Генерация аудио через TTS
41
  torch.manual_seed(42)
42
  inputs = tts_tokenizer(text, return_tensors="pt").to(device)
 
43
  with torch.no_grad():
44
  output = tts_model(**inputs)
45
  waveform = output.waveform.squeeze().cpu().numpy()
 
46
  if waveform.size == 0:
47
  raise ValueError("TTS сгенерировал пустое аудио! Попробуйте другой текст.")
 
48
  # Конвертация в int16 для WAV
49
  audio = (waveform * 32767).astype("int16")
50
  sampling_rate = tts_model.config.sampling_rate
 
51
  # Сохранение аудио
52
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_file:
53
  wavfile.write(audio_file.name, sampling_rate, audio)
54
  audio_path = audio_file.name
 
55
  print(f"TTS аудио сохранено: {audio_path} (длина: {len(waveform)/sampling_rate:.1f} сек)")
 
56
  # Шаг 2: Сохранение изображения
57
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as img_file:
58
  # Конвертация в RGB если нужно
 
60
  image = image.convert('RGB')
61
  image.save(img_file.name, format='PNG')
62
  img_path = img_file.name
 
63
  print(f"Изображение сохранено: {img_path}")
 
64
  # Шаг 3: Вызов talking-head API
65
  print(f"Подключение к {TALKING_HEAD_SPACE}...")
66
  client = Client(TALKING_HEAD_SPACE)
 
67
  # Проверяем доступные API endpoints
68
  print("Доступные API методы:", client.view_api())
 
69
  # Вызов API с правильными параметрами
70
  result = client.predict(
71
+ image_path=handle_file(img_path),
72
+ audio_path=handle_file(audio_path),
73
  guidance_scale=3.0,
74
+ steps=10,
75
  api_name="/process_image_audio"
76
  )
 
77
  print(f"Результат API: {type(result)}")
 
78
  # Обработка результата
79
  if isinstance(result, tuple) and len(result) > 0:
80
  video_data = result[0]
 
88
  video_path = video_data
89
  else:
90
  video_path = result
 
91
  print(f"Видео сгенерировано: {video_path}")
92
  error_msg = "✅ Видео успешно сгенерировано!"
 
93
  except Exception as e:
94
  error_msg = f"❌ Ошибка: {str(e)}"
95
  print(f"ОШИБКА: {error_msg}")
 
96
  traceback.print_exc()
 
97
  finally:
98
  # Очистка временных файлов
99
  if audio_path and os.path.exists(audio_path):
 
102
  print(f"Удален временный файл: {audio_path}")
103
  except:
104
  pass
 
105
  if img_path and os.path.exists(img_path):
106
  try:
107
  os.remove(img_path)
108
  print(f"Удален временный файл: {img_path}")
109
  except:
110
  pass
 
111
  return video_path, error_msg
112
 
113
  # Интерфейс Gradio
114
  title = "🎓 Видео-лектор с TTS (Русский)"
115
+ description = """Загрузите фото лектора и введите текст лекции. Система сгенерирует видео, где лектор "произносит" ваш текст!
 
 
 
116
  **Требования:**
117
  - Фото: фронтальное изображение лица
118
+ - Текст: до 500 символов на русском языке"""
 
 
 
 
 
 
 
 
119
 
120
  iface = gr.Interface(
121
  fn=inference,
122
  inputs=[
123
  gr.Image(type="pil", label="📸 Фото лектора"),
124
  gr.Textbox(
125
+ lines=5,
126
+ placeholder="Введите текст лекции на русском языке (до 500 символов)...",
127
  label="📝 Текст лекции"
128
  )
129
  ],
 
133
  ],
134
  title=title,
135
  description=description,
 
 
136
  cache_examples=False
137
  )
138