mi55th committed on
Commit
ebbacfe
·
verified ·
1 Parent(s): 2cd068f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -0
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import (
4
+ AutoProcessor,
5
+ BlipForConditionalGeneration,
6
+ pipeline,
7
+ SpeechT5Processor,
8
+ SpeechT5ForTextToSpeech,
9
+ SpeechT5HifiGan
10
+ )
11
+ from PIL import Image
12
+
13
# Compute device: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ---------------------------------------------------------
# 1) IMAGE → CAPTION (BLIP)
# ---------------------------------------------------------
# The processor and the model are both loaded from the same checkpoint name;
# the model is then moved to the selected device.
caption_model_name = "Salesforce/blip-image-captioning-base"
caption_processor = AutoProcessor.from_pretrained(caption_model_name)
caption_model = BlipForConditionalGeneration.from_pretrained(caption_model_name)
caption_model = caption_model.to(device)
22
+
23
def generate_caption(image: Image.Image) -> str:
    """Describe ``image`` with a short caption produced by the BLIP model.

    Args:
        image: The picture to caption.

    Returns:
        The decoded caption text with special tokens removed.
    """
    model_inputs = caption_processor(images=image, return_tensors="pt")
    model_inputs = model_inputs.to(device)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        generated = caption_model.generate(**model_inputs, max_length=30)

    # Only the first generated sequence is decoded.
    first_sequence = generated[0]
    return caption_processor.decode(first_sequence, skip_special_tokens=True)
29
+
30
# ---------------------------------------------------------
# 2) CAPTION → FAIRY TALE (Flan-T5)
# ---------------------------------------------------------
# Uses flan-t5-base; flan-t5-large is an alternative if enough memory is available.
# On CUDA the pipeline runs on device 0 in float16, otherwise on the CPU (-1) in float32.
story_model = pipeline(
    task="text2text-generation",
    model="google/flan-t5-base",
    device=-1 if device != "cuda" else 0,
    torch_dtype=torch.float32 if device != "cuda" else torch.float16,
    max_new_tokens=180,
)
41
+
42
def generate_fairy_tale(caption: str) -> str:
    """Turn an image caption into a short fairy tale via the story pipeline.

    Args:
        caption: Text description of the drawing; it is embedded in the prompt.

    Returns:
        The generated text with leading and trailing whitespace stripped.
    """
    prompt_parts = [
        "You are a kind storyteller for young children. ",
        "Based on the following description, create a short, gentle, and imaginative fairy tale (3–4 sentences):\n\n",
        f"Image description: {caption}\n\n",
        "Fairy tale:",
    ]
    prompt = "".join(prompt_parts)

    # Sampling options passed to the pipeline call.
    sampling_options = {"temperature": 0.9, "top_p": 0.92, "do_sample": True}
    outputs = story_model(prompt, **sampling_options)

    # The first result's "generated_text" field holds the story.
    story = outputs[0]["generated_text"]
    return story.strip()
56
+
57
# ---------------------------------------------------------
# 3) FAIRY TALE → SPEECH (SpeechT5 + HiFi-GAN)
# ---------------------------------------------------------
tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)

# A fixed speaker embedding keeps the voice stable between requests.
# For this demo it is a random (1, 512) vector drawn with a fixed seed; a real
# embedding could be loaded from a dataset instead.
# The seed is applied to a private generator rather than torch.manual_seed(),
# so importing this module does not reseed the global RNG that the sampled
# story generation (do_sample=True) relies on.
_speaker_rng = torch.Generator().manual_seed(42)
speaker_embedding = torch.randn(1, 512, generator=_speaker_rng).to(device)
68
+
69
def _clip_to_word_boundary(text: str, limit: int) -> str:
    """Return at most ``limit`` characters of ``text`` without splitting a word."""
    if len(text) <= limit:
        return text

    clipped = text[:limit]
    # The cut is mid-word only if the characters on both sides of it are non-space.
    cut_mid_word = (
        bool(clipped)
        and not clipped[-1].isspace()
        and not text[limit].isspace()
    )
    if cut_mid_word:
        pieces = clipped.rsplit(maxsplit=1)
        # With no whitespace at all there is nothing to back up to; keep the hard cut.
        if len(pieces) == 2:
            clipped = pieces[0]
    return clipped.rstrip()


def text_to_speech(text: str, max_chars: int = 200):
    """Synthesize speech for ``text`` with SpeechT5 and the HiFi-GAN vocoder.

    Args:
        text: The text to narrate.
        max_chars: Maximum number of characters to narrate (expected to be
            positive). Longer text is shortened, backing up to the previous
            word boundary so the narration does not end on a word fragment.

    Returns:
        A ``(sample_rate, audio)`` tuple where ``sample_rate`` is 16000 and
        ``audio`` is the generated waveform as a NumPy array.
    """
    # Limit the input length to avoid overflowing the model.
    text = _clip_to_word_boundary(text, max_chars)

    inputs = tts_processor(text=text, return_tensors="pt").to(device)
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        speech = tts_model.generate_speech(
            inputs["input_ids"],
            speaker_embedding,
            vocoder=vocoder,
        )

    # Move the waveform to the CPU before converting it to NumPy.
    audio = speech.cpu().numpy()
    sample_rate = 16000
    return (sample_rate, audio)
82
+
83
+ # ---------------------------------------------------------
84
+ # FULL PIPELINE
85
+ # ---------------------------------------------------------
86
def process_drawing(image):
    """Run the full pipeline: drawing → caption → fairy tale → narration.

    Args:
        image: The uploaded drawing, or ``None`` when nothing was uploaded.

    Returns:
        A ``(caption, tale, audio)`` tuple, where ``audio`` is whatever
        ``text_to_speech`` returns for the tale.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is not None:
        description = generate_caption(image)
        story = generate_fairy_tale(description)
        narration = text_to_speech(story)
        return description, story, narration

    raise gr.Error("Please upload a drawing.")
95
+
96
# ---------------------------------------------------------
# GRADIO INTERFACE
# ---------------------------------------------------------
# Components are created in the same order as before, since creation order
# inside the Blocks context determines the page layout.
with gr.Blocks(title="Fairy Tale from Child's Drawing") as app:
    gr.Markdown("""
    ## 🌈 Magic Storyteller for Kids
    Upload a child's drawing → Get a short fairy tale → Listen to it!
    """)

    # NOTE(review): the row is assumed to hold the image input and the audio
    # output (grouped together before the first blank line) — confirm.
    with gr.Row():
        img_input = gr.Image(type="pil", label="Child's Drawing")
        audio_output = gr.Audio(label="Narrated Fairy Tale")

    caption_output = gr.Textbox(label="AI Description of the Drawing")
    tale_output = gr.Textbox(label="Generated Fairy Tale", lines=4)

    generate_btn = gr.Button("✨ Create Story")

    # Clicking the button runs the whole pipeline on the uploaded image and
    # fills the caption, tale and audio components with its three results.
    generate_btn.click(
        process_drawing,
        img_input,
        [caption_output, tale_output, audio_output],
    )

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    app.launch()