# chatbot/app.py
import os
import re
import asyncio
import torch
import soundfile as sf
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer
from deep_translator import GoogleTranslator
from textblob import TextBlob
import nltk
from parler_tts import ParlerTTSForConditionalGeneration
# Flask setup
dir_path = os.path.dirname(os.path.realpath(__file__))
app = Flask(__name__, static_folder="static", static_url_path="")
CORS(app)
torch.set_num_threads(2)
# Paths: keep generated audio inside the app's served static folder
AUDIO_FOLDER = os.path.join(dir_path, 'static', 'audio')
os.makedirs(AUDIO_FOLDER, exist_ok=True)
# Fetch the NLTK data TextBlob needs, if it isn't cached already
try:
    nltk.data.find('corpora/brown')
except LookupError:
    nltk.download('brown')
try:
    nltk.data.find('tokenizers/punkt')
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    nltk.download('punkt')
    nltk.download('punkt_tab')
# Pre-download the TTS weights into the cache at import time so the first request doesn't stall
ParlerTTSForConditionalGeneration.from_pretrained("doublesizebed/parler-tts-mini-malay", cache_dir=os.getenv("TRANSFORMERS_CACHE"))
class ChatBot:
    def __init__(self):
        self.chat_history_ids = None
        self.bot_input_ids = None
        # Dynamic quantization (applied to the TTS model below) runs on CPU only,
        # and neither model is moved to a GPU, so keep inference on CPU
        self.device = 'cpu'
        self.tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", use_fast=False)
        self.model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
        # Parler-TTS setup: quantize the linear layers to int8 to cut memory and latency
        self.tts_model = ParlerTTSForConditionalGeneration.from_pretrained("doublesizebed/parler-tts-mini-malay")
        self.tts_model = torch.ao.quantization.quantize_dynamic(self.tts_model, {torch.nn.Linear}, dtype=torch.qint8)
        self.tts_tokenizer = AutoTokenizer.from_pretrained("doublesizebed/parler-tts-mini-malay")
        # The voice description is tokenized with the text encoder's own tokenizer
        self.description_tokenizer = AutoTokenizer.from_pretrained(self.tts_model.config.text_encoder._name_or_path)
    async def get_response(self, user_input, gender):
        def build_prompt(user_question):
            # 1) Mandate at top
            instructions = (
                "Never introduce yourself. "
                "After your concise answer, ask exactly one relevant follow-up question.\n\n"
            )
            # 2) Few-shot examples
            demos = (
                "Q: What is photosynthesis?\n"
                "Answer: Photosynthesis lets plants convert sunlight into energy. Which plants interest you most?\n\n"
                "Q: How do I make tea?\n"
                "Answer: Steep tea leaves in hot water for 3-5 minutes, then serve. Do you prefer green or black tea?\n\n"
            )
            # 3) The actual user query
            query = f"Q: {user_question}\nAnswer:"
            return instructions + demos + query
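        # For a question like "What is gravity?", build_prompt returns the mandate,
        # the two demos, then "Q: What is gravity?\nAnswer:", so the model completes
        # the final Answer line in the same style as the demos.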
        full_prompt = build_prompt(user_input)
        prompt_ids = self.tokenizer(full_prompt, return_tensors="pt").input_ids
        if self.chat_history_ids is None:
            self.chat_history_ids = prompt_ids
        else:
            self.chat_history_ids = torch.cat([self.chat_history_ids, prompt_ids], dim=-1)
        self.model.eval()
        # Llama-style tokenizers may not define a pad token; fall back to EOS
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id
        input_len = self.chat_history_ids.shape[-1]
        output = self.model.generate(
            self.chat_history_ids,
            max_length=input_len + 128,
            pad_token_id=pad_id,
            do_sample=True,
            temperature=0.5,
            top_p=0.9,
            top_k=50,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        # Update history so the next turn continues the conversation
        self.chat_history_ids = output
        # Decode only the newly generated tokens; decoding all of output[0] would
        # echo the prompt and, on later turns, the entire accumulated history
        generated_text = self.tokenizer.decode(output[0][input_len:], skip_special_tokens=True).strip()
        def clean_response(text):
            # Strip any stray "Q:" / "Answer:" lines the model generated beyond its answer
            cleaned_text = re.sub(r"(?m)^(Q:|Answer:).*\n?", "", text)
            return cleaned_text.strip()
        final_text = clean_response(generated_text)
        # Mask English noun phrases with placeholders so the translator leaves them intact
        blob = TextBlob(final_text)
        nouns = blob.noun_phrases
        masked_sentence = final_text
        for i, noun in enumerate(nouns):
            placeholder = f"<<<noun_{i}>>>"
            masked_sentence = re.sub(re.escape(noun), placeholder, masked_sentence, flags=re.IGNORECASE)
        translated_masked_sentence = GoogleTranslator(source='en', target='ms').translate(masked_sentence)
        def restore_placeholders(text, nouns_list):
            # Re-insert the original noun phrases; tolerate the whitespace and case
            # changes the translator sometimes introduces inside the markers
            def replacer(match):
                index = int(match.group(1))
                # Leave any marker the translator mangled beyond recognition as-is
                return nouns_list[index] if index < len(nouns_list) else match.group(0)
            return re.sub(r"<<<\s*noun_(\d+)\s*>>>", replacer, text, flags=re.IGNORECASE)
        final_sentence = restore_placeholders(translated_masked_sentence, nouns)
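        # Illustrative round trip (assumed translator output; actual text will vary):
        #   final_text : "Photosynthesis lets plants convert sunlight into energy."
        #   masked     : "<<<noun_0>>> lets plants convert sunlight into energy."
        #   translated : "<<<noun_0>>> membolehkan tumbuhan menukar cahaya matahari kepada tenaga."
        #   restored   : "photosynthesis membolehkan tumbuhan menukar cahaya matahari kepada tenaga."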
        audio_file_path = await self.text_to_speech(final_sentence, gender)
        return final_sentence, audio_file_path
    async def text_to_speech(self, text, gender):
        # Parler-TTS conditions the voice on a natural-language description
        if gender.lower() == "male":
            description = "A male speaker delivers a slightly expressive and animated speech with a moderate speed and pitch."
        else:
            description = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch."
        desc_inputs = self.description_tokenizer(description, return_tensors="pt", padding=True).to(self.device)
        text_inputs = self.tts_tokenizer(text, return_tensors="pt", padding=True).to(self.device)
        self.tts_model.eval()
        generation = self.tts_model.generate(
            input_ids=desc_inputs.input_ids,
            attention_mask=desc_inputs.attention_mask,
            prompt_input_ids=text_inputs.input_ids,
            prompt_attention_mask=text_inputs.attention_mask
        )
        audio_arr = generation.cpu().numpy().squeeze()
        # Fixed filename: each request overwrites the previous clip
        output_filename = "response.wav"
        output_path = os.path.join(AUDIO_FOLDER, output_filename)
        sf.write(output_path, audio_arr, self.tts_model.config.sampling_rate)
        return output_filename
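    # If overwriting becomes a problem under concurrent requests, a unique name
    # such as f"response_{uuid.uuid4().hex}.wav" (with `import uuid`) would keep
    # each clip distinct; that is a hypothetical variant, not what runs here.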
chatbot = ChatBot()
@app.route("/")
def serve_index():
return send_from_directory(app.static_folder, "index.html")
@app.route('/chat', methods=['POST'])
def chat_endpoint():
    data = request.get_json()
    user_text = data.get('message', '')
    gender = data.get('gender', '')
    if not user_text:
        return jsonify({"error": "Empty message"}), 400
    # Run the async chatbot to completion from Flask's synchronous view
    resp_text, wav_name = asyncio.run(chatbot.get_response(user_text, gender))
    url = f"audio/{wav_name}"
    return jsonify({"response": resp_text, "audiofile": url})
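# Example request (illustrative; the generated text and audio will vary):
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "What is photosynthesis?", "gender": "female"}'
#   -> {"audiofile": "audio/response.wav", "response": "..."}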
@app.route("/static/audio/<path:filename>")
def serve_audio(filename):
return send_from_directory(AUDIO_FOLDER, filename)
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)