Spaces:
Paused
Paused
File size: 6,884 Bytes
a053ac4 d015882 a053ac4 cde87a2 a053ac4 d13204f a053ac4 dc0e2d8 a053ac4 0ffcf0b a053ac4 cde87a2 dc0e2d8 e0ffe9c a053ac4 60ebaee a053ac4 47ab7c3 be38dd6 dc0e2d8 60ebaee be38dd6 290d660 60ebaee a074d8f 60ebaee d015882 60ebaee 47ab7c3 60ebaee a053ac4 60ebaee 47ab7c3 60ebaee a053ac4 91b1ea4 265fc3e 91b1ea4 a053ac4 60ebaee a053ac4 06a4adc d13204f a053ac4 e0ffe9c a053ac4 5d950e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
import os
import re
import asyncio
import torch
import soundfile as sf
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer
from deep_translator import GoogleTranslator
from textblob import TextBlob
import nltk
from parler_tts import ParlerTTSForConditionalGeneration
# Flask setup
dir_path = os.path.dirname(os.path.realpath(__file__))
app = Flask(__name__, static_folder="static", static_url_path="")
CORS(app)
# Keep PyTorch from oversubscribing the small CPU allotment.
torch.set_num_threads(2)

# Paths.
# FIX: was '/static/audio' (filesystem root). Generated audio landed outside
# the app's static folder, so the "audio/<name>" URL returned by /chat could
# never resolve through Flask's static file handling. Anchor the folder to
# the script directory instead (this is what the otherwise-unused dir_path
# was evidently computed for).
AUDIO_FOLDER = os.path.join(dir_path, 'static', 'audio')
os.makedirs(AUDIO_FOLDER, exist_ok=True)

# NLTK corpora required by TextBlob's noun-phrase extraction.
try:
    nltk.data.find('corpora/brown')
except LookupError:
    nltk.download('brown')
try:
    nltk.data.find('tokenizers/punkt')
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    nltk.download('punkt')
    nltk.download('punkt_tab')

# Warm the Hugging Face cache so ChatBot.__init__ loads the TTS weights
# from disk rather than re-downloading them.
ParlerTTSForConditionalGeneration.from_pretrained("doublesizebed/parler-tts-mini-malay", cache_dir=os.getenv("TRANSFORMERS_CACHE"))
class ChatBot:
    """Conversational pipeline: TinyLlama generates an English answer,
    Google Translate renders it into Malay (noun phrases are masked so
    they stay in English), and Parler-TTS synthesizes Malay speech.

    One instance keeps a running token history so consecutive calls to
    get_response continue the same conversation.
    """

    def __init__(self):
        # Running transcript of token ids across turns (grows every call).
        self.chat_history_ids = None
        self.bot_input_ids = None  # kept for interface compatibility; unused here
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", use_fast=False)
        # FIX: the chat model was never moved to the selected device, so on a
        # CUDA host it stayed on CPU while inputs could be sent to CUDA.
        self.model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").to(self.device)
        # Parler-TTS setup. NOTE: dynamic int8 quantization is CPU-only in
        # PyTorch, so the TTS pipeline must run on CPU regardless of
        # self.device (see text_to_speech).
        self.tts_model = ParlerTTSForConditionalGeneration.from_pretrained("doublesizebed/parler-tts-mini-malay")
        self.tts_model = torch.ao.quantization.quantize_dynamic(self.tts_model, {torch.nn.Linear}, dtype=torch.qint8)
        self.tts_tokenizer = AutoTokenizer.from_pretrained("doublesizebed/parler-tts-mini-malay")
        self.description_tokenizer = AutoTokenizer.from_pretrained(self.tts_model.config.text_encoder._name_or_path)

    async def get_response(self, user_input, gender):
        """Answer *user_input*, translate the answer to Malay and synthesize it.

        Returns (malay_text, wav_filename) where wav_filename is the file
        written into AUDIO_FOLDER by text_to_speech.
        """
        def build_prompt(user_question):
            # 1) Mandate at top.
            instructions = (
                "Never introduce yourself. "
                "After your concise answer, ask exactly one relevant follow-up question.\n\n"
            )
            # 2) Few-shot examples pinning the Q:/Answer: format.
            demos = (
                "Q: What is photosynthesis?\n"
                "Answer: Photosynthesis lets plants convert sunlight into energy. Which plants interest you most?\n\n"
                "Q: How do I make tea?\n"
                "Answer: Steep tea leaves in hot water for 3-5 minutes, then serve. Do you prefer green or black tea?\n\n"
            )
            # 3) The actual user query.
            query = f"Q: {user_question}\nAnswer:"
            return instructions + demos + query

        full_prompt = build_prompt(user_input)
        prompt_ids = self.tokenizer(full_prompt, return_tensors="pt").input_ids.to(self.device)
        if self.chat_history_ids is None:
            self.chat_history_ids = prompt_ids
        else:
            self.chat_history_ids = torch.cat([self.chat_history_ids, prompt_ids], dim=-1)
        self.model.eval()
        # FIX: TinyLlama's tokenizer defines no pad token, so pad_token_id was
        # None before; fall back to EOS as is conventional for causal LMs.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id
        with torch.no_grad():  # inference only — skip autograd bookkeeping
            output = self.model.generate(
                self.chat_history_ids,
                # FIX: pass an explicit attention mask (no padding in this
                # single-sequence input, so all-ones is correct).
                attention_mask=torch.ones_like(self.chat_history_ids),
                max_length=self.chat_history_ids.shape[-1] + 128,
                pad_token_id=pad_id,
                do_sample=True,
                temperature=0.5,
                top_p=0.9,
                top_k=50,
                eos_token_id=self.tokenizer.eos_token_id,
            )
        # Update history so the next turn continues the conversation.
        self.chat_history_ids = output
        generated_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
        # Remove the prompt if it's echoed back.
        if generated_text.startswith(full_prompt):
            generated_text = generated_text[len(full_prompt):].strip()

        def clean_response(text):
            # Strip any leftover "Q:"/"Answer:" scaffold lines the model emits.
            return re.sub(r"(?m)^(Q:|Answer:).*\n?", "", text).strip()

        final_text = clean_response(generated_text)
        # Mask noun phrases with placeholders so the translator leaves the
        # (often technical) English terms intact.
        nouns = TextBlob(final_text).noun_phrases
        masked_sentence = final_text
        for i, noun in enumerate(nouns):
            placeholder = f"<<<noun_{i}>>>"
            masked_sentence = re.sub(re.escape(noun), placeholder, masked_sentence, flags=re.IGNORECASE)
        # FIX: guard against a None result (e.g. empty input to the translator).
        translated_masked_sentence = GoogleTranslator(source='en', target='ms').translate(masked_sentence) or ""

        def restore_placeholders(text, nouns_list):
            def replacer(match):
                index = int(match.group(1))
                # FIX: the translator occasionally mangles placeholder digits;
                # leave an out-of-range placeholder untouched instead of
                # raising IndexError.
                if index < len(nouns_list):
                    return nouns_list[index]
                return match.group(0)
            # \s* tolerates spaces the translator may insert inside markers.
            return re.sub(r"<<<\s*noun_(\d+)\s*>>>", replacer, text, flags=re.IGNORECASE)

        final_sentence = restore_placeholders(translated_masked_sentence, nouns)
        audio_file_path = await self.text_to_speech(final_sentence, gender)
        return final_sentence, audio_file_path

    async def text_to_speech(self, text, gender):
        """Synthesize *text* as Malay speech and return the wav filename
        (relative to AUDIO_FOLDER). Voice is chosen by *gender*; anything
        other than "male" selects the female voice.
        """
        if gender.lower() == "male":
            description = "A male speaker delivers a slightly expressive and animated speech with a moderate speed and pitch."
        else:
            description = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch."
        # FIX: the dynamically-quantized TTS model is CPU-only, but these
        # inputs were previously moved to self.device — a device mismatch
        # crash on any CUDA host. Keep everything on CPU.
        desc_inputs = self.description_tokenizer(description, return_tensors="pt", padding=True)
        text_inputs = self.tts_tokenizer(text, return_tensors="pt", padding=True)
        self.tts_model.eval()
        with torch.no_grad():
            generation = self.tts_model.generate(
                input_ids=desc_inputs.input_ids,
                attention_mask=desc_inputs.attention_mask,
                prompt_input_ids=text_inputs.input_ids,
                prompt_attention_mask=text_inputs.attention_mask,
            )
        audio_arr = generation.cpu().numpy().squeeze()
        # Fixed name: each response overwrites the previous recording (was an
        # f-string with no placeholder).
        output_filename = "response.wav"
        output_path = os.path.join(AUDIO_FOLDER, output_filename)
        sf.write(output_path, audio_arr, self.tts_model.config.sampling_rate)
        return output_filename
# Single shared chatbot instance for all requests.
chatbot = ChatBot()


@app.route("/")
def serve_index():
    """Serve the front-end entry page from the static folder."""
    index_page = "index.html"
    return send_from_directory(app.static_folder, index_page)
@app.route('/chat', methods=['POST'])
def chat_endpoint():
    """POST JSON {"message": str, "gender": str} -> {"response": malay_text,
    "audiofile": relative_url}. Returns 400 on an empty or missing message.
    """
    # FIX: get_json() returns None for a missing/invalid JSON body, which
    # previously crashed on .get() with a 500; silent=True + fallback keeps
    # the clean 400 path.
    data = request.get_json(silent=True) or {}
    user_text = data.get('message', '')
    gender = data.get('gender', '')
    if not user_text:
        return jsonify({"error": "Empty message"}), 400
    # FIX: asyncio.run replaces the manual new_event_loop/set_event_loop/close
    # dance, which leaked the loop on exceptions and left a stale loop bound
    # to this worker thread.
    resp_text, wav_name = asyncio.run(chatbot.get_response(user_text, gender))
    url = f"audio/{wav_name}"
    return jsonify({"response": resp_text, "audiofile": url})
@app.route("/static/audio/<path:filename>")
def serve_audio(filename):
return send_from_directory(AUDIO_FOLDER, filename)
if __name__ == '__main__':
    # Bind to all interfaces on port 7860 (the conventional Hugging Face
    # Spaces port).
    app.run(host='0.0.0.0', port=7860)