import asyncio
import os
import re

import torch
import soundfile as sf
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer
from deep_translator import GoogleTranslator
from textblob import TextBlob
import nltk
from parler_tts import ParlerTTSForConditionalGeneration

# Flask setup
dir_path = os.path.dirname(os.path.realpath(__file__))
app = Flask(__name__, static_folder="static", static_url_path="")
CORS(app)
torch.set_num_threads(2)

# Paths.
# FIX: the original '/static/audio' pointed at the filesystem root, while the
# URLs handed to the client ("audio/<name>") resolve inside the app's static
# directory — generated WAVs were written somewhere they could never be
# served from. Anchor the folder inside <app dir>/static instead.
AUDIO_FOLDER = os.path.join(dir_path, 'static', 'audio')
os.makedirs(AUDIO_FOLDER, exist_ok=True)

# Make sure the NLTK data TextBlob needs is available (download once).
try:
    nltk.data.find('corpora/brown')
except LookupError:
    nltk.download('brown')
try:
    nltk.data.find('tokenizers/punkt')
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    nltk.download('punkt')
    nltk.download('punkt_tab')

# Warm the Hugging Face cache so the ChatBot constructor below loads the TTS
# weights from local files instead of re-downloading.
ParlerTTSForConditionalGeneration.from_pretrained(
    "doublesizebed/parler-tts-mini-malay",
    cache_dir=os.getenv("TRANSFORMERS_CACHE"),
)


class ChatBot:
    """English Q&A chatbot (TinyLlama) whose replies are translated to Malay
    (noun phrases kept in English) and spoken aloud via Parler-TTS."""

    def __init__(self):
        self.chat_history_ids = None
        self.bot_input_ids = None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.tokenizer = AutoTokenizer.from_pretrained(
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0", use_fast=False)
        self.model = AutoModelForCausalLM.from_pretrained(
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0")
        # Parler-TTS setup. Dynamic int8 quantization of the Linear layers
        # keeps CPU inference affordable.
        # NOTE(review): dynamically-quantized models run on CPU only, yet
        # text_to_speech() moves its inputs to self.device — on a CUDA host
        # that mismatch would fail. Confirm deployment is CPU-only.
        self.tts_model = ParlerTTSForConditionalGeneration.from_pretrained(
            "doublesizebed/parler-tts-mini-malay")
        self.tts_model = torch.ao.quantization.quantize_dynamic(
            self.tts_model, {torch.nn.Linear}, dtype=torch.qint8)
        self.tts_tokenizer = AutoTokenizer.from_pretrained(
            "doublesizebed/parler-tts-mini-malay")
        self.description_tokenizer = AutoTokenizer.from_pretrained(
            self.tts_model.config.text_encoder._name_or_path)

    async def get_response(self, user_input, gender):
        """Answer *user_input* in English, translate the answer to Malay while
        preserving its English noun phrases, synthesize speech for it, and
        return ``(final_sentence, audio_filename)``."""

        def build_prompt(user_question):
            # 1) Mandate at top
            instructions = (
                "Never introduce yourself. "
                "After your concise answer, ask exactly one relevant follow-up question.\n\n"
            )
            # 2) Few-shot examples
            demos = (
                "Q: What is photosynthesis?\n"
                "Answer: Photosynthesis lets plants convert sunlight into energy. Which plants interest you most?\n\n"
                "Q: How do I make tea?\n"
                "Answer: Steep tea leaves in hot water for 3-5 minutes, then serve. Do you prefer green or black tea?\n\n"
            )
            # 3) The actual user query
            query = f"Q: {user_question}\nAnswer:"
            return instructions + demos + query

        full_prompt = build_prompt(user_input)
        prompt_ids = self.tokenizer(full_prompt, return_tensors="pt").input_ids

        # Append this turn to the running conversation history.
        if self.chat_history_ids is None:
            self.chat_history_ids = prompt_ids
        else:
            self.chat_history_ids = torch.cat(
                [self.chat_history_ids, prompt_ids], dim=-1)

        self.model.eval()
        # FIX: TinyLlama's tokenizer defines no pad token, so pad_token_id
        # used to be passed as None; fall back to EOS in that case.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id
        output = self.model.generate(
            self.chat_history_ids,
            max_length=self.chat_history_ids.shape[-1] + 128,
            pad_token_id=pad_id,
            do_sample=True,
            temperature=0.5,
            top_p=0.9,
            top_k=50,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        # Update history so the next turn continues the conversation.
        self.chat_history_ids = output

        generated_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
        # Remove the prompt if it's echoed back.
        if generated_text.startswith(full_prompt):
            generated_text = generated_text[len(full_prompt):].strip()

        def clean_response(text):
            # Strip any leftover "Q:" / "Answer:" scaffold lines.
            cleaned_text = re.sub(r"(?m)^(Q:|Answer:).*\n?", "", text)
            return cleaned_text.strip()

        final_text = clean_response(generated_text)

        # Mask noun phrases before translation so Google Translate leaves
        # them untouched, then restore them afterwards.
        blob = TextBlob(final_text)
        nouns = blob.noun_phrases
        masked_sentence = final_text
        for i, noun in enumerate(nouns):
            # FIX: the original wrote f"<<>>" (the index was dropped), so
            # every noun collapsed into the same literal marker and
            # restore_placeholders() below could never match it. Emit the
            # <<<noun_i>>> form its regex expects.
            placeholder = f"<<<noun_{i}>>>"
            masked_sentence = re.sub(
                re.escape(noun), placeholder, masked_sentence,
                flags=re.IGNORECASE)

        translated_masked_sentence = GoogleTranslator(
            source='en', target='ms').translate(masked_sentence)

        def restore_placeholders(text, nouns_list):
            # Swap each <<<noun_i>>> marker back for its original phrase.
            def replacer(match):
                index = int(match.group(1))
                return nouns_list[index]
            return re.sub(r"<<<\s*noun_(\d+)\s*>>>", replacer, text,
                          flags=re.IGNORECASE)

        final_sentence = restore_placeholders(translated_masked_sentence, nouns)
        audio_file_path = await self.text_to_speech(final_sentence, gender)
        return final_sentence, audio_file_path

    async def text_to_speech(self, text, gender):
        """Synthesize *text* with a male or female Malay voice and return the
        WAV filename (relative to AUDIO_FOLDER)."""
        if gender.lower() == "male":
            description = ("A male speaker delivers a slightly expressive and "
                           "animated speech with a moderate speed and pitch.")
        else:
            description = ("A female speaker delivers a slightly expressive and "
                           "animated speech with a moderate speed and pitch.")

        desc_inputs = self.description_tokenizer(
            description, return_tensors="pt", padding=True).to(self.device)
        text_inputs = self.tts_tokenizer(
            text, return_tensors="pt", padding=True).to(self.device)

        self.tts_model.eval()
        generation = self.tts_model.generate(
            input_ids=desc_inputs.input_ids,
            attention_mask=desc_inputs.attention_mask,
            prompt_input_ids=text_inputs.input_ids,
            prompt_attention_mask=text_inputs.attention_mask,
        )
        audio_arr = generation.cpu().numpy().squeeze()

        # NOTE(review): a fixed filename means concurrent requests overwrite
        # each other's audio — acceptable for single-user use; verify.
        output_filename = "response.wav"
        output_path = os.path.join(AUDIO_FOLDER, output_filename)
        sf.write(output_path, audio_arr, self.tts_model.config.sampling_rate)
        return output_filename


chatbot = ChatBot()


@app.route("/")
def serve_index():
    return send_from_directory(app.static_folder, "index.html")


@app.route('/chat', methods=['POST'])
def chat_endpoint():
    """Accept {"message", "gender"} JSON; return the Malay reply text and the
    URL of the synthesized audio."""
    data = request.get_json()
    user_text = data.get('message', '')
    gender = data.get('gender', '')
    if not user_text:
        return jsonify({"error": "Empty message"}), 400
    # asyncio.run replaces the manual new_event_loop/run_until_complete/close
    # dance with identical per-request semantics.
    resp_text, wav_name = asyncio.run(chatbot.get_response(user_text, gender))
    url = f"audio/{wav_name}"
    return jsonify({"response": resp_text, "audiofile": url})


# FIX: the original route "/static/audio/" declared no URL variable, so Flask
# could never supply the `filename` argument and every request failed.
@app.route("/static/audio/<path:filename>")
def serve_audio(filename):
    return send_from_directory(AUDIO_FOLDER, filename)


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)