AlserFurma commited on
Commit
444e569
·
verified ·
1 Parent(s): 1873d97

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -135
app.py CHANGED
@@ -8,24 +8,20 @@ from transformers import VitsModel, AutoTokenizer, pipeline
8
  import scipy.io.wavfile as wavfile
9
  import traceback
10
  import random
11
- import json
12
- import re
13
 
14
  # =========================
15
  # Параметры
16
  # =========================
17
-
18
  TALKING_HEAD_SPACE = "Skywork/skyreels-a1-talking-head"
19
 
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
- print(f"Device set to use {device}")
22
 
23
  # =========================
24
  # Загрузка моделей
25
  # =========================
26
-
27
  try:
28
- # TTS (казахский)
29
  tts_model = VitsModel.from_pretrained("facebook/mms-tts-kaz").to(device)
30
  tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-kaz")
31
 
@@ -36,86 +32,51 @@ try:
36
  device=0 if device == "cuda" else -1
37
  )
38
 
39
- # Генерация учебных вопросов (стабильная CPU-модель)
40
  qa_model = pipeline(
41
  "text2text-generation",
42
- model="t5-base", # <-- ВАЖНО: существующая стабильная модель!
43
  device=0 if device == "cuda" else -1
44
  )
45
 
46
- print("Models loaded successfully!")
47
 
48
  except Exception as e:
49
- raise RuntimeError(f"Model loading error: {str(e)}")
50
 
51
 
52
  # =========================
53
- # Генерация учебного вопроса
54
  # =========================
55
-
56
  def generate_quiz(text: str):
 
57
  prompt = (
58
- "Сгенерируй учебный вопрос по тексту и дай один правильный и один неправильный ответ. "
59
- "Верни ТОЛЬКО JSON без комментариев:\n"
60
- "{\n"
61
- " \"question\": \"...\",\n"
62
- " \"correct\": \"...\",\n"
63
- " \"wrong\": \"...\"\n"
64
- "}\n"
65
- f"TEXT: {text}"
66
  )
67
-
68
- # 1. Генерация
69
- out = qa_model(prompt, max_new_tokens=200)[0]["generated_text"].strip()
70
-
71
- # 2. Повторная попытка при пустом выводе
72
- if not out:
73
- out = qa_model(prompt, max_new_tokens=200)[0]["generated_text"].strip()
74
- if not out:
75
- raise ValueError("Модель дважды вернула пустой ответ.")
76
-
77
- # 3. Извлечение JSON
78
  try:
79
- json_str = out[out.index("{"): out.rindex("}") + 1]
80
- except Exception:
81
- # fallback
82
- q = re.search(r'"?question"?\s*[:=]\s*[\'"](.+?)[\'"]', out)
83
- c = re.search(r'"?correct"?\s*[:=]\s*[\'"](.+?)[\'"]', out)
84
- w = re.search(r'"?wrong"?\s*[:=]\s*[\'"](.+?)[\'"]', out)
85
- if q and c and w:
86
- json_str = json.dumps({
87
- "question": q.group(1),
88
- "correct": c.group(1),
89
- "wrong": w.group(1)
90
- })
91
- else:
92
  raise ValueError(f"Модель вывела неподходящий формат:\n{out}")
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
- json_str = json_str.replace("\n", "")
95
-
96
- try:
97
- data = json.loads(json_str)
98
- except Exception:
99
- data = json.loads(json_str.replace("'", "\""))
100
-
101
- question = data.get("question", "").strip()
102
- correct = data.get("correct", "").strip()
103
- wrong = data.get("wrong", "").strip()
104
-
105
- if not (question and correct and wrong):
106
- raise ValueError("JSON не содержит нужных полей")
107
-
108
- options = [correct, wrong]
109
- random.shuffle(options)
110
-
111
- return question, options, correct
112
-
113
-
114
- # =========================
115
- # Синтез речи
116
- # =========================
117
 
118
  def synthesize_audio(text_ru: str):
 
119
  translation = translator(text_ru, src_lang="rus_Cyrl", tgt_lang="kaz_Cyrl")
120
  text_kk = translation[0]["translation_text"]
121
 
@@ -124,20 +85,17 @@ def synthesize_audio(text_ru: str):
124
  output = tts_model(**inputs)
125
 
126
  waveform = output.waveform.squeeze().cpu().numpy()
127
- audio = (waveform * 32767).astype("int16")
128
- sr = getattr(tts_model.config, "sampling_rate", 22050)
129
 
130
- tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
131
- wavfile.write(tmp.name, sr, audio)
132
- tmp.close()
133
- return tmp.name
134
 
135
 
136
- # =========================
137
- # Talking Head
138
- # =========================
139
-
140
  def make_talking_head(image_path: str, audio_path: str):
 
141
  client = Client(TALKING_HEAD_SPACE)
142
  result = client.predict(
143
  image_path=handle_file(image_path),
@@ -146,107 +104,119 @@ def make_talking_head(image_path: str, audio_path: str):
146
  steps=10,
147
  api_name="/process_image_audio"
148
  )
149
- if isinstance(result, tuple):
150
- return result[0]
151
- return result
 
 
 
 
152
 
153
 
154
  # =========================
155
- # Шаг 1 — старт урока
156
  # =========================
157
-
158
  def start_lesson(image: Image.Image, text: str, state):
159
- if image is None:
160
- return None, "Загрузите фото", [], state
161
- if not text:
162
- return None, "Введите текст", [], state
163
- if len(text) > 500:
164
- return None, "Текст слишком длинный", [], state
165
 
166
  try:
167
  question, options, correct = generate_quiz(text)
 
 
168
 
169
- quiz_text = f"Вопрос: {question}. Варианты: 1) {options[0]} 2) {options[1]}"
170
- audio_path = synthesize_audio(quiz_text)
171
-
172
- tmpimg = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
173
- if image.mode != "RGB":
174
- image = image.convert("RGB")
175
  image.save(tmpimg.name)
176
  tmpimg.close()
 
177
 
178
- video_path = make_talking_head(tmpimg.name, audio_path)
179
 
180
- state = {
181
- "image_path": tmpimg.name,
182
- "correct": correct,
183
- "options": options
184
- }
185
 
186
- return video_path, question, options, state, state
 
 
 
 
187
 
188
  except Exception as e:
189
  traceback.print_exc()
190
- return None, f"Ошибка: {e}", [], state
191
 
192
 
193
- # =========================
194
- # Шаг 2 — реакция
195
- # =========================
196
-
197
  def answer_selected(selected_option: str, state):
 
198
  if not state:
199
- return None, "Ошибка: урок не запущен"
 
 
 
 
 
 
 
 
 
 
 
200
 
201
- correct = state["correct"]
202
- image_path = state["image_path"]
203
 
204
- if selected_option == correct:
205
- reply_ru = "Молодец!"
206
- reply_ui = "Дұрыс!"
207
- else:
208
- reply_ru = f"Неправильно. Правильный ответ: {correct}"
209
- reply_ui = f"Қате. Дұрыс жауап: {correct}"
210
 
211
- audio_path = synthesize_audio(reply_ru)
212
- video_path = make_talking_head(image_path, audio_path)
213
 
214
- return video_path, reply_ui
 
 
215
 
216
 
217
  # =========================
218
- # Интерфейс
219
  # =========================
 
 
 
 
 
 
220
 
221
  with gr.Blocks() as demo:
222
- gr.Markdown("# 🎓 Интерактивный бейне-лектор")
223
 
224
  with gr.Row():
225
- with gr.Column():
226
- inp_image = gr.Image(type="pil", label="Фото лектора")
227
- inp_text = gr.Textbox(lines=4, label="Текст лекции (рус.)")
228
  btn_start = gr.Button("Запустить урок")
229
 
230
- with gr.Column():
231
- out_video = gr.Video(label="Видео лектора")
232
- out_question = gr.Markdown(label="Вопрос")
233
  btn_opt1 = gr.Button("Вариант 1")
234
  btn_opt2 = gr.Button("Вариант 2")
235
- out_react = gr.Video(label="Реакция")
236
- out_status = gr.Textbox(label="Статус")
237
 
238
- state = gr.State({})
239
 
 
240
  btn_start.click(
241
- start_lesson,
242
- [inp_image, inp_text, state],
243
- [out_video, out_question, btn_opt1, btn_opt2, state]
244
  )
245
 
246
- btn_opt1.click(answer_selected, [btn_opt1, state], [out_react, out_status])
247
- btn_opt2.click(answer_selected, [btn_opt2, state], [out_react, out_status])
248
 
249
  demo.load(lambda: "Готово", outputs=out_status)
250
 
251
- if __name__ == "__main__":
252
  demo.launch()
 
8
  import scipy.io.wavfile as wavfile
9
  import traceback
10
  import random
 
 
11
 
12
  # =========================
13
  # Параметры
14
  # =========================
 
15
  TALKING_HEAD_SPACE = "Skywork/skyreels-a1-talking-head"
16
 
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
+ print(f"Using device: {device}")
19
 
20
  # =========================
21
  # Загрузка моделей
22
  # =========================
 
23
  try:
24
+ # TTS модель (казахский)
25
  tts_model = VitsModel.from_pretrained("facebook/mms-tts-kaz").to(device)
26
  tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-kaz")
27
 
 
32
  device=0 if device == "cuda" else -1
33
  )
34
 
35
+ # Модель для генерации вопросов
36
  qa_model = pipeline(
37
  "text2text-generation",
38
+ model="google/flan-t5-small",
39
  device=0 if device == "cuda" else -1
40
  )
41
 
42
+ print(" Все модели успешно загружены!")
43
 
44
  except Exception as e:
45
+ raise RuntimeError(f" Ошибка загрузки моделей: {str(e)}")
46
 
47
 
48
  # =========================
49
+ # Вспомогательные функции
50
  # =========================
 
51
  def generate_quiz(text: str):
52
+ """Генерирует один вопрос и два варианта (correct, wrong) на русском языке."""
53
  prompt = (
54
+ "Сгенерируй один учебный вопрос по этому тексту и дай 1 правильный и 1 неправильный вариант ответа. "
55
+ "Формат вывода JSON: {\"question\": \"...\", \"correct\": \"...\", \"wrong\": \"...\"}. Текст: " + text
 
 
 
 
 
 
56
  )
 
 
 
 
 
 
 
 
 
 
 
57
  try:
58
+ out = qa_model(prompt, max_length=256)[0]["generated_text"]
59
+ # Пытаемся найти JSON в выводе модели
60
+ json_start = out.find("{")
61
+ json_end = out.rfind("}")
62
+ if json_start == -1 or json_end == -1:
 
 
 
 
 
 
 
 
63
  raise ValueError(f"Модель вывела неподходящий формат:\n{out}")
64
+ import json
65
+ data = json.loads(out[json_start: json_end+1])
66
+ question = data.get("question", "").strip()
67
+ correct = data.get("correct", "").strip()
68
+ wrong = data.get("wrong", "").strip()
69
+ if not (question and correct and wrong):
70
+ raise ValueError(f"Неполные данные:\n{out}")
71
+ options = [correct, wrong]
72
+ random.shuffle(options)
73
+ return question, options, correct
74
+ except Exception as e:
75
+ raise ValueError(f"Ошибка генерации вопроса:\n{str(e)}")
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  def synthesize_audio(text_ru: str):
79
+ """Переводит русскую строку на казахский, синтезирует аудио и возвращает путь к файлу .wav"""
80
  translation = translator(text_ru, src_lang="rus_Cyrl", tgt_lang="kaz_Cyrl")
81
  text_kk = translation[0]["translation_text"]
82
 
 
85
  output = tts_model(**inputs)
86
 
87
  waveform = output.waveform.squeeze().cpu().numpy()
88
+ audio = (waveform * 32767).astype('int16')
89
+ sampling_rate = getattr(tts_model.config, 'sampling_rate', 22050)
90
 
91
+ tmpf = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
92
+ wavfile.write(tmpf.name, sampling_rate, audio)
93
+ tmpf.close()
94
+ return tmpf.name
95
 
96
 
 
 
 
 
97
  def make_talking_head(image_path: str, audio_path: str):
98
+ """Вызывает SkyReels/Talking Head space и возвращает путь или URL видео."""
99
  client = Client(TALKING_HEAD_SPACE)
100
  result = client.predict(
101
  image_path=handle_file(image_path),
 
104
  steps=10,
105
  api_name="/process_image_audio"
106
  )
107
+
108
+ if isinstance(result, dict) and "video" in result:
109
+ return result["video"]
110
+ elif isinstance(result, str):
111
+ return result
112
+ else:
113
+ raise ValueError(f"Unexpected talking head result: {type(result)}")
114
 
115
 
116
  # =========================
117
+ # Основные обработчики для Gradio
118
  # =========================
 
119
  def start_lesson(image: Image.Image, text: str, state):
120
+ """Шаг 1: генерируем видео-лекцию с вопросом и вариантами ответа."""
121
+ if image is None or not text.strip() or len(text) > 500:
122
+ return None, "", [], [], state
 
 
 
123
 
124
  try:
125
  question, options, correct = generate_quiz(text)
126
+ quiz_ru = f"Вопрос: {question} Варианты: 1) {options[0]} 2) {options[1]}"
127
+ audio_path = synthesize_audio(quiz_ru)
128
 
129
+ tmpimg = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
130
+ if image.mode != 'RGB':
131
+ image = image.convert('RGB')
 
 
 
132
  image.save(tmpimg.name)
133
  tmpimg.close()
134
+ image_path = tmpimg.name
135
 
136
+ video_path = make_talking_head(image_path, audio_path)
137
 
138
+ state_data = {'image_path': image_path, 'correct': correct, 'options': options}
 
 
 
 
139
 
140
+ # удаляем временный аудио файл
141
+ try: os.remove(audio_path)
142
+ except: pass
143
+
144
+ return video_path, question, options, state_data, state_data
145
 
146
  except Exception as e:
147
  traceback.print_exc()
148
+ return None, f"Ошибка: {e}", [], [], state
149
 
150
 
 
 
 
 
151
  def answer_selected(selected_option: str, state):
152
+ """Шаг 2: пользователь выбирает вариант — генерируем реакцию лектора."""
153
  if not state:
154
+ return None, "Ошибка: отсутствует состояние урока. Сначала нажмите апустить урок'."
155
+ try:
156
+ correct = state.get('correct')
157
+ image_path = state.get('image_path')
158
+ options = state.get('options', [])
159
+
160
+ if selected_option == correct:
161
+ reaction_ru = "Молодец!"
162
+ display_message = "Дұрыс!"
163
+ else:
164
+ reaction_ru = f"Неправильно. Правильный ответ: {correct}"
165
+ display_message = f"Қате. Дұрыс жауап: {correct}"
166
 
167
+ audio_path = synthesize_audio(reaction_ru)
168
+ reaction_video = make_talking_head(image_path, audio_path)
169
 
170
+ try: os.remove(audio_path)
171
+ except: pass
 
 
 
 
172
 
173
+ return reaction_video, display_message
 
174
 
175
+ except Exception as e:
176
+ traceback.print_exc()
177
+ return None, f"Ошибка: {e}"
178
 
179
 
180
  # =========================
181
+ # Gradio UI
182
  # =========================
183
+ title = "🎓 Интерактивный бейне-лектор"
184
+ description = (
185
+ "Загрузите фото лектора и текст лекции (русский, до 500 символов). "
186
+ "Система создаст видео-лектора, задаст вопрос и предложит 2 варианта ответа. "
187
+ "Нажмите на один из вариантов — лектор коротко отреагирует (қазақша)."
188
+ )
189
 
190
  with gr.Blocks() as demo:
191
+ gr.Markdown(f"# {title}\n{description}")
192
 
193
  with gr.Row():
194
+ with gr.Column(scale=1):
195
+ inp_image = gr.Image(type='pil', label='📸 Фото лектора')
196
+ inp_text = gr.Textbox(lines=5, label='📝 Текст лекции (рус.)', placeholder='Введите текст...')
197
  btn_start = gr.Button("Запустить урок")
198
 
199
+ with gr.Column(scale=1):
200
+ out_video = gr.Video(label='🎬 Видео лектора')
201
+ out_question = gr.Markdown(label='Вопрос')
202
  btn_opt1 = gr.Button("Вариант 1")
203
  btn_opt2 = gr.Button("Вариант 2")
204
+ out_reaction_video = gr.Video(label='🎥 Реакция лектора')
205
+ out_status = gr.Textbox(label='ℹ️ Статус', interactive=False)
206
 
207
+ lesson_state = gr.State({})
208
 
209
+ # Привязки
210
  btn_start.click(
211
+ fn=start_lesson,
212
+ inputs=[inp_image, inp_text, lesson_state],
213
+ outputs=[out_video, out_question, btn_opt1, btn_opt2, lesson_state]
214
  )
215
 
216
+ btn_opt1.click(fn=answer_selected, inputs=[btn_opt1, lesson_state], outputs=[out_reaction_video, out_status])
217
+ btn_opt2.click(fn=answer_selected, inputs=[btn_opt2, lesson_state], outputs=[out_reaction_video, out_status])
218
 
219
  demo.load(lambda: "Готово", outputs=out_status)
220
 
221
+ if __name__ == '__main__':
222
  demo.launch()