AlserFurma commited on
Commit
4c1bc35
·
verified ·
1 Parent(s): 6d84114

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -73
app.py CHANGED
@@ -2,93 +2,173 @@ import gradio as gr
2
  import os
3
  from PIL import Image
4
  import tempfile
 
5
  import torch
6
  from transformers import VitsModel, AutoTokenizer
7
  import scipy.io.wavfile as wavfile
8
- from gradio_client import Client, handle_file
9
- import traceback
10
-
11
- # Только CPU
12
- os.environ["CUDA_VISIBLE_DEVICES"] = ""
13
- torch.set_num_threads(4)
14
-
15
- TALKING_HEAD = "Skywork/skyreels-a1-talking-head"
16
- model = None
17
- tokenizer = None
18
-
19
- def load_tts():
20
- global model, tokenizer
21
- if model is None:
22
- print("Загружаем TTS (каз)…")
23
- model = VitsModel.from_pretrained("facebook/mms-tts-kaz")
24
- tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-kaz")
25
- print("TTS готова")
26
- return True
27
 
28
- def ru_to_kz_simple(text: str) -> str:
29
- rep = {
30
- "привет": "сәлем", "здравствуйте": "сәлеметсіз бе", "спасибо": "рахмет",
31
- "да": "иә", "нет": "жоқ", "сегодня": "бүгін", "завтра": "ертең",
32
- "урок": "сабақ", "лекция": "дәріс", "учитель": "мұғалім", "школа": "мектеп"
33
- }
34
- for ru, kz in rep.items():
35
- text = text.replace(ru, kz).replace(ru.capitalize(), kz.capitalize())
36
- return text
37
 
38
- def create_video(image: Image.Image, text: str):
39
- if not image or not text.strip():
40
- return None, "Загрузите фото и введите текст!"
 
 
 
41
 
42
- load_tts()
43
- text_kz = ru_to_kz_simple(text.strip())
44
 
 
 
 
 
 
 
45
  try:
46
- # TTS
47
- inputs = tokenizer(text_kz, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  with torch.no_grad():
49
- waveform = model(**inputs).waveform.squeeze().cpu().numpy()
50
-
51
- rate = model.config.sampling_rate
52
- audio_path = "/tmp/audio.wav"
53
- wavfile.write(audio_path, rate, (waveform * 32767).astype("int16"))
54
-
55
- # Изображение
56
- if image.mode != "RGB":
57
- image = image.convert("RGB")
58
- img_path = "/tmp/img.png"
59
- image.save(img_path)
60
-
61
- # Talking head
62
- client = Client(TALKING_HEAD)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  result = client.predict(
64
- image_path=handle_file(img_path),
65
- audio_path=handle_file(audio_path),
66
- guidance_scale=2.0,
67
- steps=8,
68
  api_name="/process_image_audio"
69
  )
70
-
71
- video_path = result[0] if isinstance(result, (list, tuple)) else result
72
- return video_path, "Бейне дайын!"
73
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  except Exception as e:
 
 
 
75
  traceback.print_exc()
76
- return None, f"Қате: {e}"
77
-
78
- # === Интерфейс ===
79
- with gr.Blocks(title="Бейне-лектор қазақша") as app:
80
- gr.Markdown("# Бейне-лектор қазақша\nФото + текст → говорящий видео-лектор")
 
 
 
 
 
 
 
 
 
 
 
81
 
82
- with gr.Row():
83
- with gr.Column():
84
- img_in = gr.Image(label="Фото лектора", type="pil")
85
- txt_in = gr.Textbox(label="Текст лекции (русский)", lines=6, placeholder="Привет! Сегодня мы изучаем математику…")
86
- btn = gr.Button("Сделать видео", variant="primary")
87
-
88
- with gr.Column():
89
- video_out = gr.Video(label="Готовое видео")
90
- status = gr.Textbox(label="Статус", interactive=False)
91
 
92
- btn.click(create_video, [img_in, txt_in], [video_out, status])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
- app.launch(server_name="0.0.0.0", server_port=7860)
 
 
2
  import os
3
  from PIL import Image
4
  import tempfile
5
+ from gradio_client import Client, handle_file
6
  import torch
7
  from transformers import VitsModel, AutoTokenizer
8
  import scipy.io.wavfile as wavfile
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ # Загружаем TTS модель при старте
11
+ device = "cuda" if torch.cuda.is_available() else "cpu"
12
+ print(f"Using device: {device}")
 
 
 
 
 
 
13
 
14
+ try:
15
+ tts_model = VitsModel.from_pretrained("facebook/mms-tts-rus").to(device)
16
+ tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")
17
+ print("TTS модель загружена успешно!")
18
+ except Exception as e:
19
+ raise RuntimeError(f"Ошибка загрузки TTS модели: {str(e)}")
20
 
21
+ # Space для talking-head
22
+ TALKING_HEAD_SPACE = "Skywork/skyreels-a1-talking-head"
23
 
24
+ def inference(image: Image.Image, text: str):
25
+ error_msg = ""
26
+ video_path = None
27
+ audio_path = None
28
+ img_path = None
29
+
30
  try:
31
+ # Валидация входных данных
32
+ if image is None:
33
+ raise ValueError("Загрузите изображение лектора!")
34
+
35
+ if not text or not text.strip():
36
+ raise ValueError("Введите текст лекции!")
37
+
38
+ if len(text) > 500:
39
+ raise ValueError("Текст слишком длинный! Используйте до 500 символов.")
40
+
41
+ print(f"Генерация TTS для текста: '{text[:50]}...'")
42
+
43
+ # Шаг 1: Генерация аудио через TTS
44
+ torch.manual_seed(42)
45
+ inputs = tts_tokenizer(text, return_tensors="pt").to(device)
46
+
47
  with torch.no_grad():
48
+ output = tts_model(**inputs)
49
+ waveform = output.waveform.squeeze().cpu().numpy()
50
+
51
+ if waveform.size == 0:
52
+ raise ValueError("TTS сгенерировал пустое аудио! Попробуйте другой текст.")
53
+
54
+ # Конвертация в int16 для WAV
55
+ audio = (waveform * 32767).astype("int16")
56
+ sampling_rate = tts_model.config.sampling_rate
57
+
58
+ # Сохранение аудио
59
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_file:
60
+ wavfile.write(audio_file.name, sampling_rate, audio)
61
+ audio_path = audio_file.name
62
+
63
+ print(f"TTS аудио сохранено: {audio_path} (длина: {len(waveform)/sampling_rate:.1f} сек)")
64
+
65
+ # Шаг 2: Сохранение изображения
66
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as img_file:
67
+ # Конвертация в RGB если нужно
68
+ if image.mode != 'RGB':
69
+ image = image.convert('RGB')
70
+ image.save(img_file.name, format='PNG')
71
+ img_path = img_file.name
72
+
73
+ print(f"Изображение сохранено: {img_path}")
74
+
75
+ # Шаг 3: Вызов talking-head API
76
+ print(f"Подключение к {TALKING_HEAD_SPACE}...")
77
+ client = Client(TALKING_HEAD_SPACE)
78
+
79
+ # Проверяем доступные API endpoints
80
+ print("Доступные API методы:", client.view_api())
81
+
82
+ # Вызов API с правильными параметрами
83
  result = client.predict(
84
+ image=handle_file(img_path),
85
+ audio=handle_file(audio_path),
86
+ guidance_scale=3.0,
87
+ inference_steps=10,
88
  api_name="/process_image_audio"
89
  )
90
+
91
+ print(f"Результат API: {type(result)}")
92
+
93
+ # Обработка результата
94
+ if isinstance(result, tuple) and len(result) > 0:
95
+ video_data = result[0]
96
+ if isinstance(video_data, dict) and 'video' in video_data:
97
+ video_path = video_data['video']
98
+ elif isinstance(video_data, dict) and 'path' in video_data:
99
+ video_path = video_data['path']
100
+ elif isinstance(video_data, str):
101
+ video_path = video_data
102
+ else:
103
+ video_path = video_data
104
+ else:
105
+ video_path = result
106
+
107
+ print(f"Видео сгенерировано: {video_path}")
108
+ error_msg = "✅ Видео успешно сгенерировано!"
109
+
110
  except Exception as e:
111
+ error_msg = f"❌ Ошибка: {str(e)}"
112
+ print(f"ОШИБКА: {error_msg}")
113
+ import traceback
114
  traceback.print_exc()
115
+
116
+ finally:
117
+ # Очистка временных файлов
118
+ if audio_path and os.path.exists(audio_path):
119
+ try:
120
+ os.remove(audio_path)
121
+ print(f"Удален временный файл: {audio_path}")
122
+ except:
123
+ pass
124
+
125
+ if img_path and os.path.exists(img_path):
126
+ try:
127
+ os.remove(img_path)
128
+ print(f"Удален временный файл: {img_path}")
129
+ except:
130
+ pass
131
 
132
+ return video_path, error_msg
 
 
 
 
 
 
 
 
133
 
134
+ # Интерфейс Gradio
135
+ title = "🎓 Видео-лектор с TTS (Русский)"
136
+ description = """
137
+ Загрузите фото лектора и введите текст лекции.
138
+ Система сгенерирует видео, где лектор "произносит" ваш текст!
139
+
140
+ **Требования:**
141
+ - Фото: фронтальное изображение лица
142
+ - Текст: до 500 символов на русском языке
143
+ """
144
+
145
+ examples = [
146
+ [
147
+ "example_image.png",
148
+ "Добрый день! Сегодня мы рассмотрим основы машинного обучения."
149
+ ]
150
+ ]
151
+
152
+ iface = gr.Interface(
153
+ fn=inference,
154
+ inputs=[
155
+ gr.Image(type="pil", label="📸 Фото лектора"),
156
+ gr.Textbox(
157
+ lines=5,
158
+ placeholder="Введите текст лекции на русском языке (до 500 символов)...",
159
+ label="📝 Текст лекции"
160
+ )
161
+ ],
162
+ outputs=[
163
+ gr.Video(label="🎬 Готовое видео"),
164
+ gr.Textbox(label="ℹ️ Статус", interactive=False)
165
+ ],
166
+ title=title,
167
+ description=description,
168
+ flagging_mode="never",
169
+ examples=None, # Добавьте примеры, если есть тестовые изображения
170
+ cache_examples=False
171
+ )
172
 
173
+ if __name__ == "__main__":
174
+ iface.launch()