File size: 4,383 Bytes
a260f81
 
38be666
 
 
 
 
8a5ea2d
a260f81
8a5ea2d
 
 
 
 
 
 
 
 
 
 
 
 
d64103c
8a5ea2d
38be666
8a5ea2d
d64103c
8a5ea2d
 
38be666
a260f81
8a5ea2d
 
 
d64103c
8a5ea2d
d64103c
2488489
8a5ea2d
d64103c
 
 
8a5ea2d
d64103c
8a5ea2d
d64103c
 
 
 
8a5ea2d
d64103c
8a5ea2d
d64103c
a260f81
8a5ea2d
38be666
d64103c
8a5ea2d
 
 
 
a260f81
8a5ea2d
 
 
 
 
 
 
 
 
 
 
 
 
d64103c
8a5ea2d
d64103c
 
 
8a5ea2d
d64103c
 
 
 
 
 
8a5ea2d
 
 
d64103c
a260f81
8a5ea2d
38be666
8a5ea2d
38be666
 
 
 
8a5ea2d
2488489
38be666
 
 
 
2488489
8a5ea2d
a260f81
8a5ea2d
2488489
8a5ea2d
 
38be666
2488489
 
38be666
2488489
 
8a5ea2d
 
a260f81
 
 
38be666
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import gradio as gr
import torch
from transformers import pipeline, GPT2LMHeadModel, GPT2Tokenizer
from gtts import gTTS
import numpy as np
import tempfile
import os
import google.generativeai as genai

# --- Google Generative AI configuration -------------------------------------
# Read the API key from the environment instead of hard-coding it: a key
# committed to source control is effectively public and must be rotated.
# Export GOOGLE_API_KEY before launching the app.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
    raise RuntimeError(
        "GOOGLE_API_KEY environment variable is not set. "
        "Export it before starting the app."
    )
genai.configure(api_key=GOOGLE_API_KEY)

# Load GenAI model (primary response generator; the GPT-2 model loaded below
# serves as a local fallback when the API call fails).
print("Loading Google Generative AI model...")
gen_model = genai.GenerativeModel("gemini-1.5-pro")


# Load ASR
print("Loading ASR model...")
speech_to_text_pipeline = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")

# Load GPT-2
print("Loading GPT-2 model...")
response_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
response_model = GPT2LMHeadModel.from_pretrained("gpt2")
response_model.eval()

# Main logic
def process_input(emotion, audio_input, text_input):
    """Generate an emotion-aware reply from speech and/or text input.

    Args:
        emotion: User-selected emotional state from the Radio input
            ("positive" / "neutral" / "negative").
        audio_input: Gradio numpy audio tuple ``(sample_rate, audio_data)``
            or None if no audio was recorded.
        text_input: Free-form text from the textbox, possibly empty.

    Returns:
        Tuple of (response text, path to an MP3 of the spoken response, or
        None if TTS failed).
    """
    print(f"\n---\nEmotion: {emotion}")

    # --- Speech-to-text ------------------------------------------------------
    audio_text = ""
    if audio_input is not None:
        print("Audio input detected. Transcribing...")
        try:
            sample_rate, audio_data = audio_input
            if len(audio_data) == 0 or np.all(audio_data == 0):
                print("Silent or empty audio.")
            else:
                # Gradio typically delivers int16 PCM; convert to float32 and
                # peak-normalize to [-1, 1] before handing it to the pipeline.
                audio_data = audio_data.astype(np.float32)
                audio_data = audio_data / np.max(np.abs(audio_data))
                audio_text = speech_to_text_pipeline({
                    "sampling_rate": sample_rate,
                    "array": audio_data
                })["text"]
                print(f"Audio transcription: {audio_text}")
        except Exception as e:
            # Best-effort: a transcription failure degrades to text-only input.
            print(f"Speech-to-text error: {e}")
            audio_text = ""

    # --- Combine typed and transcribed input ---------------------------------
    combined_input_text = ((text_input or "") + " " + (audio_text or "")).strip()
    print(f"User input: {combined_input_text}")

    if not combined_input_text:
        return "Please provide text or audio input.", None

    # Prepend the emotional context so the model can tailor its tone.
    prompt = f"The user feels {emotion}. Respond supportively: {combined_input_text}"
    print(f"Final prompt to model: {prompt}")

    # --- Primary: Google Gemini; fallback: local GPT-2 -----------------------
    try:
        gen_response = gen_model.generate_content(prompt)
        text_output = gen_response.text.strip()
        print(f"Google GenAI response: {text_output}")
    except Exception as e:
        print(f"GenAI Error: {e}")
        print("Falling back to GPT-2...")
        try:
            # Keep only the most recent 512 tokens to fit GPT-2's context.
            input_ids = response_tokenizer.encode(prompt, return_tensors='pt')[:, -512:]
            with torch.no_grad():
                # Deterministic beam search; `temperature` is ignored unless
                # do_sample=True, so it is intentionally omitted here.
                output = response_model.generate(
                    input_ids=input_ids,
                    max_length=input_ids.shape[1] + 50,
                    num_beams=3,
                    no_repeat_ngram_size=2,
                    early_stopping=True,
                    # GPT-2 defines no pad token; reuse EOS to silence the
                    # generate() warning without changing output.
                    pad_token_id=response_tokenizer.eos_token_id,
                )
            text_output = response_tokenizer.decode(output[0], skip_special_tokens=True)
            print(f"GPT-2 fallback response: {text_output}")
        except Exception as gpt_error:
            print(f"GPT-2 Error: {gpt_error}")
            text_output = "Sorry, I couldn't generate a response."

    # --- Text-to-speech ------------------------------------------------------
    try:
        print("Generating speech...")
        tts = gTTS(text_output)
        # Close the temp-file handle before writing: gTTS reopens the path
        # itself, and the original dangling handle leaked an fd per request
        # (and blocks the write entirely on Windows).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
            audio_output_path = temp_file.name
        tts.save(audio_output_path)
        print(f"TTS audio saved at: {audio_output_path}")
    except Exception as e:
        print(f"TTS Error: {e}")
        audio_output_path = None

    return text_output, audio_output_path

# Gradio Interface
# The three inputs map positionally onto process_input(emotion, audio_input,
# text_input); the two outputs map onto its (text, audio path) return tuple.
iface = gr.Interface(
    fn=process_input,
    inputs=[
        gr.Radio(["positive", "neutral", "negative"], label="Your Emotion"),
        # type="numpy" delivers a (sample_rate, array) tuple — exactly what
        # process_input unpacks.
        gr.Audio(type="numpy", label="Speak..."),
        gr.Textbox(label="Text Input", placeholder="Or type here..."),
    ],
    outputs=[
        gr.Textbox(label="AI Response"),
        # Plays the MP3 file path returned by the TTS step (None on failure).
        gr.Audio(label="Spoken Response"),
    ],
    title="Emotion-Aware Multimodal AI Assistant",
    description="Choose your emotional state, then talk or type to the AI assistant. It responds based on your emotional context.",
)

# Launch the UI only when executed as a script, not when imported.
if __name__ == "__main__":
    iface.launch()