import os
import re
import asyncio
import torch
import soundfile as sf
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer
from deep_translator import GoogleTranslator
from textblob import TextBlob
import nltk
from parler_tts import ParlerTTSForConditionalGeneration

# Flask setup
dir_path = os.path.dirname(os.path.realpath(__file__))
app = Flask(__name__, static_folder="static", static_url_path="")
CORS(app)
torch.set_num_threads(2)

# Paths: generated audio is written under the app's static directory
AUDIO_FOLDER = os.path.join(dir_path, 'static', 'audio')
os.makedirs(AUDIO_FOLDER, exist_ok=True)

# Make sure the NLTK data TextBlob relies on is available
try:
    nltk.data.find('corpora/brown')
except LookupError:
    nltk.download('brown')

try:
    nltk.data.find('tokenizers/punkt')
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    nltk.download('punkt')
    nltk.download('punkt_tab')
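
# Download the Malay Parler-TTS weights up front so the first request does not
# block on the download; the returned object is discarded here and the model is
# loaded again inside ChatBot below.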
ParlerTTSForConditionalGeneration.from_pretrained("doublesizebed/parler-tts-mini-malay", cache_dir=os.getenv("TRANSFORMERS_CACHE"))


class ChatBot:
    def __init__(self):
        self.chat_history_ids = None
        self.bot_input_ids = None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", use_fast=False)
        self.model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

        # Parler-TTS setup: dynamic int8 quantization of the Linear layers keeps
        # memory down; the quantized model runs on CPU.
        self.tts_model = ParlerTTSForConditionalGeneration.from_pretrained("doublesizebed/parler-tts-mini-malay")
        self.tts_model = torch.ao.quantization.quantize_dynamic(self.tts_model, {torch.nn.Linear}, dtype=torch.qint8)
        self.tts_tokenizer = AutoTokenizer.from_pretrained("doublesizebed/parler-tts-mini-malay")
        self.description_tokenizer = AutoTokenizer.from_pretrained(self.tts_model.config.text_encoder._name_or_path)
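
    # get_response: build a few-shot English prompt, generate a reply with
    # TinyLlama, translate it to Malay while protecting noun phrases, then
    # synthesize Malay audio with Parler-TTS.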
    async def get_response(self, user_input, gender):
        def build_prompt(user_question):
            # 1) Mandate at top
            instructions = (
                "Never introduce yourself. "
                "After your concise answer, ask exactly one relevant follow-up question.\n\n"
            )
            # 2) Few-shot examples
            demos = (
                "Q: What is photosynthesis?\n"
                "Answer: Photosynthesis lets plants convert sunlight into energy. Which plants interest you most?\n\n"
                "Q: How do I make tea?\n"
                "Answer: Steep tea leaves in hot water for 3-5 minutes, then serve. Do you prefer green or black tea?\n\n"
            )
            # 3) The actual user query
            query = f"Q: {user_question}\nAnswer:"
            return instructions + demos + query

        full_prompt = build_prompt(user_input)
        prompt_ids = self.tokenizer(full_prompt, return_tensors="pt").input_ids

        if self.chat_history_ids is None:
            self.chat_history_ids = prompt_ids
        else:
            self.chat_history_ids = torch.cat([self.chat_history_ids, prompt_ids], dim=-1)

        self.model.eval()
        output = self.model.generate(
            self.chat_history_ids,
            max_length=self.chat_history_ids.shape[-1] + 128,
            pad_token_id=self.tokenizer.pad_token_id,
            do_sample=True,
            temperature=0.5,
            top_p=0.9,
            top_k=50,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        # update history so next turn continues the convo
        self.chat_history_ids = output

        generated_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
        # Remove the prompt if it's echoed back
        if generated_text.startswith(full_prompt):
            generated_text = generated_text[len(full_prompt):].strip()

        def clean_response(text):
            # Strip any leftover "Q:" / "Answer:" lines the model may have produced
            cleaned_text = re.sub(r"(?m)^(Q:|Answer:).*\n?", "", text)
            return cleaned_text.strip()

        final_text = clean_response(generated_text)
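
        # Mask English noun phrases with numbered placeholders so the translator
        # leaves them untouched, translate the rest to Malay, then restore the
        # original phrases. This assumes the placeholders survive translation intact.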
        blob = TextBlob(final_text)
        nouns = blob.noun_phrases
        masked_sentence = final_text
        for i, noun in enumerate(nouns):
            placeholder = f"<<<noun_{i}>>>"
            masked_sentence = re.sub(re.escape(noun), placeholder, masked_sentence, flags=re.IGNORECASE)

        translated_masked_sentence = GoogleTranslator(source='en', target='ms').translate(masked_sentence)

        def restore_placeholders(text, nouns_list):
            def replacer(match):
                index = int(match.group(1))
                return nouns_list[index]
            return re.sub(r"<<<\s*noun_(\d+)\s*>>>", replacer, text, flags=re.IGNORECASE)

        final_sentence = restore_placeholders(translated_masked_sentence, nouns)
        audio_file_path = await self.text_to_speech(final_sentence, gender)
        return final_sentence, audio_file_path
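
    # text_to_speech: Parler-TTS is conditioned on a natural-language speaker
    # description; the gender argument only switches between two fixed descriptions.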
    async def text_to_speech(self, text, gender):
        if gender.lower() == "male":
            description = "A male speaker delivers a slightly expressive and animated speech with a moderate speed and pitch."
        else:
            description = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch."

        # The dynamically quantized TTS model runs on CPU, so keep its inputs on CPU as well
        desc_inputs = self.description_tokenizer(description, return_tensors="pt", padding=True)
        text_inputs = self.tts_tokenizer(text, return_tensors="pt", padding=True)

        self.tts_model.eval()
        generation = self.tts_model.generate(
            input_ids=desc_inputs.input_ids,
            attention_mask=desc_inputs.attention_mask,
            prompt_input_ids=text_inputs.input_ids,
            prompt_attention_mask=text_inputs.attention_mask
        )
        audio_arr = generation.cpu().numpy().squeeze()

        output_filename = "response.wav"
        output_path = os.path.join(AUDIO_FOLDER, output_filename)
        sf.write(output_path, audio_arr, self.tts_model.config.sampling_rate)
        return output_filename
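

# One shared ChatBot instance for the whole process; chat history is global,
# not per user or session.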
chatbot = ChatBot()


@app.route("/")
def serve_index():
    return send_from_directory(app.static_folder, "index.html")


@app.route("/chat", methods=["POST"])  # route path assumed; adjust to match the front-end
def chat_endpoint():
    data = request.get_json()
    user_text = data.get('message', '')
    gender = data.get('gender', '')
    if not user_text:
        return jsonify({"error": "Empty message"}), 400

    # get_response is a coroutine, so run it on a fresh event loop per request
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    resp_text, wav_name = loop.run_until_complete(chatbot.get_response(user_text, gender))
    loop.close()

    url = f"audio/{wav_name}"
    return jsonify({"response": resp_text, "audiofile": url})


@app.route("/audio/<filename>")
def serve_audio(filename):
    return send_from_directory(AUDIO_FOLDER, filename)


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)