File size: 4,417 Bytes
ebbacfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import gradio as gr
import torch
from transformers import (
    AutoProcessor,
    BlipForConditionalGeneration,
    pipeline,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan
)
from PIL import Image

# Compute device: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ---------------------------------------------------------
# 1) IMAGE → CAPTION (BLIP)
# ---------------------------------------------------------
# The processor and the model are both loaded from the same checkpoint name;
# the model is then moved to the selected device.
caption_model_name = "Salesforce/blip-image-captioning-base"
caption_processor = AutoProcessor.from_pretrained(caption_model_name)
caption_model = BlipForConditionalGeneration.from_pretrained(caption_model_name)
caption_model = caption_model.to(device)

def generate_caption(image: Image.Image) -> str:
    """Return a caption for *image* generated by the BLIP captioning model.

    Args:
        image: The picture to describe.

    Returns:
        The decoded caption with special tokens removed. Generation is
        limited to ``max_length=30``.
    """
    model_inputs = caption_processor(images=image, return_tensors="pt")
    model_inputs = model_inputs.to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        token_ids = caption_model.generate(**model_inputs, max_length=30)[0]

    return caption_processor.decode(token_ids, skip_special_tokens=True)

# ---------------------------------------------------------
# 2) CAPTION → FAIRY TALE (Flan-T5)
# ---------------------------------------------------------
# Use flan-t5-base, or flan-t5-large if there is enough memory.
# On CUDA the pipeline is given device index 0 and float16 weights;
# otherwise it is given device -1 and float32 weights.
story_model = pipeline(
    "text2text-generation",
    model="google/flan-t5-base",
    max_new_tokens=180,
    device=-1 if device != "cuda" else 0,
    torch_dtype=torch.float32 if device != "cuda" else torch.float16,
)

def generate_fairy_tale(caption: str) -> str:
    """Turn an image caption into a short fairy tale using ``story_model``.

    Args:
        caption: Description of the drawing; it is embedded in the prompt.

    Returns:
        The generated text of the first result, stripped of surrounding
        whitespace.
    """
    instruction = (
        "You are a kind storyteller for young children. "
        "Based on the following description, create a short, gentle, and imaginative fairy tale (3–4 sentences):"
    )
    # The three prompt sections are separated by blank lines.
    prompt = "\n\n".join(
        [instruction, f"Image description: {caption}", "Fairy tale:"]
    )

    # Sampling is enabled, so repeated calls may give different stories.
    sampling_options = {"temperature": 0.9, "top_p": 0.92, "do_sample": True}
    outputs = story_model(prompt, **sampling_options)

    return outputs[0]["generated_text"].strip()

# ---------------------------------------------------------
# 3) FAIRY TALE → SPEECH (SpeechT5 + HiFi-GAN)
# ---------------------------------------------------------
# Text processor, acoustic model and vocoder used for speech synthesis.
# The model and the vocoder are moved to the selected device.
tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)

# Use a fixed speaker embedding for stable output
# (it could be loaded from a dataset, but for the demo a random vector
# with a fixed seed is used).
# NOTE(review): torch.manual_seed seeds the global RNG, not just this tensor,
# so it also affects later random sampling in this process — confirm that
# is intended.
torch.manual_seed(42)
speaker_embedding = torch.randn(1, 512).to(device)

def text_to_speech(text: str, max_chars: int = 200):
    """Synthesize speech for *text* with SpeechT5 and the HiFi-GAN vocoder.

    The text is capped at *max_chars* characters before synthesis. When the
    cap falls in the middle of a word, that partial word is dropped so the
    narration does not end on a fragment. A single word longer than the cap
    is still cut hard, and text that already fits is passed through as is.

    Args:
        text: Text to narrate.
        max_chars: Maximum number of characters passed to the TTS model.
            Must be non-negative. Defaults to 200.

    Returns:
        A ``(sample_rate, audio)`` tuple where ``sample_rate`` is 16000 and
        ``audio`` is the generated waveform as a NumPy array.
    """
    # Limit the length to avoid overflowing the model input.
    if len(text) > max_chars:
        clipped = text[:max_chars]
        # The cut is mid-word only when the characters on both sides of it
        # are non-whitespace.
        mid_word = bool(clipped) and not (
            clipped[-1].isspace() or text[max_chars].isspace()
        )
        if mid_word:
            pieces = clipped.rsplit(None, 1)
            # With no earlier whitespace there is nothing to fall back to,
            # so the hard cut is kept.
            if len(pieces) == 2:
                clipped = pieces[0]
        text = clipped

    inputs = tts_processor(text=text, return_tensors="pt").to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        speech = tts_model.generate_speech(
            inputs["input_ids"],
            speaker_embedding,
            vocoder=vocoder,
        )

    sample_rate = 16000
    return (sample_rate, speech.cpu().numpy())

# ---------------------------------------------------------
# FULL PIPELINE
# ---------------------------------------------------------
def process_drawing(image):
    """Run the full pipeline: drawing → caption → fairy tale → narration.

    Args:
        image: The uploaded drawing, or ``None`` when nothing was uploaded.

    Returns:
        A ``(caption, tale, audio)`` tuple, where ``audio`` is the value
        returned by ``text_to_speech``.

    Raises:
        gr.Error: If *image* is ``None``.
    """
    if image is not None:
        description = generate_caption(image)
        story = generate_fairy_tale(description)
        return description, story, text_to_speech(story)

    raise gr.Error("Please upload a drawing.")

# ---------------------------------------------------------
# GRADIO INTERFACE
# ---------------------------------------------------------
# Build the UI. Components are created in the order they should appear.
with gr.Blocks(title="Fairy Tale from Child's Drawing") as app:
    # Page heading and a one-line description of the workflow.
    gr.Markdown("""
    ## 🌈 Magic Storyteller for Kids  
    Upload a child's drawing → Get a short fairy tale → Listen to it!
    """)

    # One row holding the drawing input and the narration output.
    with gr.Row():
        img_input = gr.Image(type="pil", label="Child's Drawing")
        audio_output = gr.Audio(label="Narrated Fairy Tale")

    # Text outputs for the generated caption and the generated story.
    caption_output = gr.Textbox(label="AI Description of the Drawing")
    tale_output = gr.Textbox(label="Generated Fairy Tale", lines=4)

    generate_btn = gr.Button("✨ Create Story")

    # On click, run process_drawing on the image; the outputs are listed in
    # the same order as its return value (caption, tale, audio).
    generate_btn.click(
        fn=process_drawing,
        inputs=[img_input],
        outputs=[caption_output, tale_output, audio_output]
    )

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    app.launch()