File size: 6,884 Bytes
a053ac4
 
 
 
 
 
 
d015882
a053ac4
 
cde87a2
a053ac4
 
 
 
d13204f
a053ac4
dc0e2d8
a053ac4
 
0ffcf0b
a053ac4
 
cde87a2
 
 
 
 
 
 
 
 
 
 
 
dc0e2d8
e0ffe9c
a053ac4
 
60ebaee
 
a053ac4
47ab7c3
be38dd6
dc0e2d8
60ebaee
be38dd6
 
290d660
60ebaee
 
 
 
 
 
 
 
 
 
 
 
 
 
a074d8f
60ebaee
 
 
 
 
 
d015882
60ebaee
 
 
 
 
 
47ab7c3
60ebaee
 
 
 
 
 
 
 
 
a053ac4
 
60ebaee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47ab7c3
60ebaee
 
 
 
 
 
 
 
 
 
 
a053ac4
 
 
91b1ea4
265fc3e
91b1ea4
 
a053ac4
 
 
 
60ebaee
a053ac4
 
06a4adc
 
 
 
d13204f
a053ac4
 
e0ffe9c
a053ac4
5d950e4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import os
import re
import asyncio
import torch
import soundfile as sf
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer
from deep_translator import GoogleTranslator
from textblob import TextBlob
import nltk
from parler_tts import ParlerTTSForConditionalGeneration

# Flask setup
dir_path = os.path.dirname(os.path.realpath(__file__))
app = Flask(__name__, static_folder="static", static_url_path="")
CORS(app)
# Cap intra-op threads to keep CPU usage predictable on small hosts.
torch.set_num_threads(2)

# Paths
# FIX: was '/static/audio' — a filesystem-root absolute path, so generated
# WAVs were written outside the app tree and the relative "audio/..." URLs
# returned by /chat could not be served. Anchor inside the app's static dir.
AUDIO_FOLDER = os.path.join(dir_path, 'static', 'audio')
os.makedirs(AUDIO_FOLDER, exist_ok=True)

# TextBlob noun-phrase extraction needs the Brown corpus.
try:
    nltk.data.find('corpora/brown')
except LookupError:
    nltk.download('brown')

# Sentence/word tokenization needs punkt (and punkt_tab on newer NLTK).
try:
    nltk.data.find('tokenizers/punkt')
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    nltk.download('punkt')
    nltk.download('punkt_tab')

# Warm the HF cache up front so ChatBot.__init__ loads from disk, not network.
ParlerTTSForConditionalGeneration.from_pretrained("doublesizebed/parler-tts-mini-malay", cache_dir=os.getenv("TRANSFORMERS_CACHE"))

class ChatBot:
    """Stateful chat + Malay TTS pipeline.

    Per turn: generate an English answer with TinyLlama, mask English noun
    phrases, machine-translate the rest to Malay, restore the nouns, then
    synthesize the Malay sentence to a WAV with Parler-TTS.
    """

    def __init__(self):
        # Running token history across turns (None until the first turn).
        self.chat_history_ids = None
        self.bot_input_ids = None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", use_fast=False)
        self.model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

        # Parler-TTS setup. Dynamic int8 quantization shrinks the Linear
        # layers; dynamically-quantized modules execute on CPU only, so TTS
        # inputs must stay on CPU (see text_to_speech) even when CUDA exists.
        self.tts_model = ParlerTTSForConditionalGeneration.from_pretrained("doublesizebed/parler-tts-mini-malay")
        self.tts_model = torch.ao.quantization.quantize_dynamic(self.tts_model, {torch.nn.Linear}, dtype=torch.qint8)
        self.tts_tokenizer = AutoTokenizer.from_pretrained("doublesizebed/parler-tts-mini-malay")
        self.description_tokenizer = AutoTokenizer.from_pretrained(self.tts_model.config.text_encoder._name_or_path)

    async def get_response(self, user_input, gender):
        """Answer *user_input*, translate to Malay, and synthesize audio.

        Returns (final_sentence, audio_filename).
        """
        def build_prompt(user_question):
            # 1) Behavioral mandate at the top of every turn.
            instructions = (
                "Never introduce yourself. "
                "After your concise answer, ask exactly one relevant follow-up question.\n\n"
            )
            # 2) Few-shot examples establishing the Q:/Answer: format.
            demos = (
                "Q: What is photosynthesis?\n"
                "Answer: Photosynthesis lets plants convert sunlight into energy. Which plants interest you most?\n\n"
                "Q: How do I make tea?\n"
                "Answer: Steep tea leaves in hot water for 3-5 minutes, then serve. Do you prefer green or black tea?\n\n"
            )
            # 3) The actual user query.
            query = f"Q: {user_question}\nAnswer:"
            return instructions + demos + query

        full_prompt = build_prompt(user_input)
        prompt_ids = self.tokenizer(full_prompt, return_tensors="pt").input_ids

        if self.chat_history_ids is None:
            self.chat_history_ids = prompt_ids
        else:
            self.chat_history_ids = torch.cat([self.chat_history_ids, prompt_ids], dim=-1)

        # Remember the prompt length so we can decode only the new tokens.
        input_len = self.chat_history_ids.shape[-1]

        # FIX: this tokenizer may define no pad token; fall back to EOS so
        # generate() receives a valid id instead of None.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id

        self.model.eval()
        with torch.no_grad():  # inference only — skip autograd bookkeeping
            output = self.model.generate(
                self.chat_history_ids,
                max_length=input_len + 128,
                pad_token_id=pad_id,
                do_sample=True,
                temperature=0.5,
                top_p=0.9,
                top_k=50,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Update history so the next turn continues the conversation.
        self.chat_history_ids = output

        # FIX: decode only the newly generated tokens instead of string-
        # stripping the echoed prompt (robust to tokenizer round-trip drift).
        generated_text = self.tokenizer.decode(output[0, input_len:], skip_special_tokens=True)

        def clean_response(text):
            # Drop any extra "Q:"/"Answer:" lines the model hallucinates.
            cleaned_text = re.sub(r"(?m)^(Q:|Answer:).*\n?", "", text)
            return cleaned_text.strip()

        final_text = clean_response(generated_text)

        # Mask English noun phrases so the translator leaves them verbatim.
        blob = TextBlob(final_text)
        nouns = blob.noun_phrases

        masked_sentence = final_text
        for i, noun in enumerate(nouns):
            placeholder = f"<<<noun_{i}>>>"
            masked_sentence = re.sub(re.escape(noun), placeholder, masked_sentence, flags=re.IGNORECASE)

        translated_masked_sentence = GoogleTranslator(source='en', target='ms').translate(masked_sentence)
        # FIX: the translator can return None (e.g. empty input); fall back
        # to the untranslated masked text rather than crashing below.
        if not translated_masked_sentence:
            translated_masked_sentence = masked_sentence

        def restore_placeholders(text, nouns_list):
            def replacer(match):
                index = int(match.group(1))
                # FIX: translation can mangle placeholder digits — keep the
                # raw match instead of raising IndexError on a bad index.
                if 0 <= index < len(nouns_list):
                    return nouns_list[index]
                return match.group(0)
            return re.sub(r"<<<\s*noun_(\d+)\s*>>>", replacer, text, flags=re.IGNORECASE)

        final_sentence = restore_placeholders(translated_masked_sentence, nouns)

        audio_file_path = await self.text_to_speech(final_sentence, gender)

        return final_sentence, audio_file_path

    async def text_to_speech(self, text, gender):
        """Synthesize *text* with a male/female voice; return the WAV filename."""
        if gender.lower() == "male":
            description = "A male speaker delivers a slightly expressive and animated speech with a moderate speed and pitch."
        else:
            description = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch."

        # FIX: the dynamically-quantized TTS model runs on CPU only, so the
        # inputs must stay on CPU as well — moving them to self.device
        # crashed with a device mismatch whenever CUDA was available.
        desc_inputs = self.description_tokenizer(description, return_tensors="pt", padding=True)
        text_inputs = self.tts_tokenizer(text, return_tensors="pt", padding=True)

        self.tts_model.eval()
        with torch.no_grad():
            generation = self.tts_model.generate(
                input_ids=desc_inputs.input_ids,
                attention_mask=desc_inputs.attention_mask,
                prompt_input_ids=text_inputs.input_ids,
                prompt_attention_mask=text_inputs.attention_mask,
            )
        audio_arr = generation.cpu().numpy().squeeze()
        output_filename = "response.wav"  # fixed name; each request overwrites it
        output_path = os.path.join(AUDIO_FOLDER, output_filename)
        sf.write(output_path, audio_arr, self.tts_model.config.sampling_rate)
        return output_filename

chatbot = ChatBot()

@app.route("/")
def serve_index():
    """Serve the single-page UI from the app's static folder."""
    static_dir = app.static_folder
    return send_from_directory(static_dir, "index.html")

@app.route('/chat', methods=['POST'])
def chat_endpoint():
    """POST JSON {message, gender} -> JSON {response, audiofile}.

    Returns 400 when the message is missing or empty.
    """
    # FIX: silent=True yields None (instead of raising / returning None and
    # then AttributeError-ing) on a missing/invalid JSON body.
    data = request.get_json(silent=True) or {}
    user_text = data.get('message', '')
    gender = data.get('gender', '')
    if not user_text:
        return jsonify({"error": "Empty message"}), 400
    # FIX: asyncio.run creates AND always closes a fresh event loop; the old
    # new_event_loop()/close() pair leaked the loop if get_response raised.
    resp_text, wav_name = asyncio.run(chatbot.get_response(user_text, gender))
    url = f"audio/{wav_name}"
    return jsonify({"response": resp_text, "audiofile": url})

@app.route("/static/audio/<path:filename>")
def serve_audio(filename):
    """Stream a generated WAV file from the audio output directory."""
    audio_dir = AUDIO_FOLDER
    return send_from_directory(audio_dir, filename)

if __name__ == '__main__':
    # Bind to 0.0.0.0 so the server is reachable from outside a container;
    # port 7860 — presumably chosen for Hugging Face Spaces convention.
    app.run(host='0.0.0.0', port=7860)