doublesizebed committed on
Commit
60ebaee
·
1 Parent(s): a8bcefb

Initial Docker Space

Browse files
Files changed (2) hide show
  1. app.py +114 -52
  2. requirements.txt +1 -1
app.py CHANGED
@@ -24,62 +24,124 @@ CORS(app)
24
  AUDIO_FOLDER = os.path.join(dir_path, 'static', 'audio')
25
  os.makedirs(AUDIO_FOLDER, exist_ok=True)
26
 
27
- # Load language detection model
28
- lid_model = fasttext.load_model(
29
- hf_hub_download("doublesizebed/predict_malay_en", "lid_ms_en.bin")
30
- )
31
-
32
- def tokenize(text):
33
- tokens = text.lower().split()
34
- return [t.strip(string.punctuation) for t in tokens if t.strip(string.punctuation)]
35
-
36
- def detect_lang(token):
37
- label, _ = lid_model.predict(token)
38
- return label[0].replace("__label__", "").upper()
39
-
40
- # G2P models
41
- g2p_ms_tokenizer = AutoTokenizer.from_pretrained("doublesizebed/G2P_malay")
42
- g2p_ms_model = AutoModelForSeq2SeqLM.from_pretrained("doublesizebed/G2P_malay").to('cuda' if torch.cuda.is_available() else 'cpu')
43
- g2p_eng = make_g2p("eng", "eng-ipa")
44
-
45
- def predict_phonemes(word, lang):
46
- if lang == "MS":
47
- inputs = g2p_ms_tokenizer(word, return_tensors="pt", padding=True, truncation=True)
48
- inputs = inputs.to(g2p_ms_model.device)
49
- outputs = g2p_ms_model.generate(**inputs)
50
- return g2p_ms_tokenizer.decode(outputs[0], skip_special_tokens=True)
51
- else:
52
- tg = g2p_eng(word)
53
- return ' '.join(tg.to_sequence())
54
-
55
- # Chatbot setup
56
  class ChatBot:
57
  def __init__(self):
 
 
58
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
59
- # Load conversation model\ self.tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
60
- self.model = AutoModelForCausalLM.from_pretrained(
61
- "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
62
- ).to(self.device)
63
- self.chat_history = None
64
- # Parler TTS
65
- self.tts_model = ParlerTTSForConditionalGeneration.from_pretrained(
66
- "doublesizebed/parler-tts-mini-malay"
67
- ).to(self.device)
68
- self.tts_text_tokenizer = AutoTokenizer.from_pretrained(
69
- self.tts_model.config.text_encoder._name_or_path
70
- )
71
- self.tts_desc_tokenizer = AutoTokenizer.from_pretrained(
72
- self.tts_model.config.text_encoder._name_or_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  )
74
- # NLTK\ nltk.download('brown')
75
- nltk.download('punkt')
76
- nltk.download('averaged_perceptron_tagger')
77
 
78
- async def chat(self, user_input, gender):
79
- # Build prompt ... (same as original)
80
- # Generate response\ # Translate & mask nouns\ # TTS generation...
81
- # Save WAV in static/audio and return filename
82
- return "Translated text", "response.wav"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  chatbot = ChatBot()
85
 
@@ -87,7 +149,7 @@ chatbot = ChatBot()
87
  def chat_endpoint():
88
  data = request.get_json()
89
  user_text = data.get('message', '')
90
- gender = data.get('gender', 'male')
91
  if not user_text:
92
  return jsonify({"error": "Empty message"}), 400
93
  resp_text, wav_name = asyncio.run(chatbot.chat(user_text, gender))
 
24
  AUDIO_FOLDER = os.path.join(dir_path, 'static', 'audio')
25
  os.makedirs(AUDIO_FOLDER, exist_ok=True)
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  class ChatBot:
28
  def __init__(self):
29
+ self.chat_history_ids = None
30
+ self.bot_input_ids = None
31
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
32
+ self.tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
33
+ self.model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").to(self.device)
34
+
35
+ try:
36
+ nltk.data.find('corpora/brown')
37
+ except LookupError:
38
+ nltk.download('brown')
39
+
40
+ try:
41
+ nltk.data.find('tokenizers/punkt')
42
+ nltk.data.find('tokenizers/punkt_tab')
43
+ except LookupError:
44
+ nltk.download('punkt')
45
+ nltk.download('punkt_tab')
46
+
47
+ # Parler-TTS Setup
48
+ self.tts_model = ParlerTTSForConditionalGeneration.from_pretrained("doublesizebed/parler-tts-mini-malay").to(self.device)
49
+ self.tts_tokenizer = AutoTokenizer.from_pretrained("C:/Users/Honor/app/model")
50
+ self.description_tokenizer = AutoTokenizer.from_pretrained(self.tts_model.config.text_encoder._name_or_path)
51
+
52
+ async def get_response(self, user_input, gender):
53
+ def build_prompt(user_question):
54
+ # 1) Mandate at top
55
+ instructions = (
56
+ "Never introduce yourself. "
57
+ "After your concise answer, ask exactly one relevant follow-up question.\n\n"
58
+ )
59
+ # 2) Few‑shot examples
60
+ demos = (
61
+ "Q: What is photosynthesis?\n"
62
+ "Answer: Photosynthesis lets plants convert sunlight into energy. Which plants interest you most?\n\n"
63
+ "Q: How do I make tea?\n"
64
+ "Answer: Steep tea leaves in hot water for 3–5 minutes, then serve. Do you prefer green or black tea?\n\n"
65
+ )
66
+ # 3) The actual user query
67
+ query = f"Q: {user_question}\nAnswer:"
68
+ return instructions + demos + query
69
+
70
+ full_prompt = build_prompt(user_input)
71
+ prompt_ids = self.tokenizer(full_prompt, return_tensors="pt").input_ids.to(self.device)
72
+
73
+ if self.chat_history_ids is None:
74
+ self.chat_history_ids = prompt_ids
75
+ else:
76
+ self.chat_history_ids = torch.cat([self.chat_history_ids, prompt_ids], dim=-1)
77
+
78
+ output = self.model.generate(
79
+ self.chat_history_ids,
80
+ max_length=self.chat_history_ids.shape[-1] + 128,
81
+ pad_token_id=self.tokenizer.pad_token_id,
82
+ do_sample=True,
83
+ temperature=0.5,
84
+ top_p=0.9,
85
+ top_k=50,
86
+ eos_token_id=self.tokenizer.eos_token_id,
87
  )
 
 
 
88
 
89
+ # update history so next turn continues the convo
90
+ self.chat_history_ids = output
91
+
92
+ generated_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
93
+ # Remove the prompt if it's echoed back
94
+ if generated_text.startswith(full_prompt):
95
+ generated_text = generated_text[len(full_prompt):].strip()
96
+
97
+ def clean_response(text):
98
+ cleaned_text = re.sub(r"(?m)^(Q:|Answer:).*\n?", "", text)
99
+ return cleaned_text.strip()
100
+
101
+ final_text = clean_response(generated_text)
102
+
103
+ blob = TextBlob(final_text)
104
+ nouns = blob.noun_phrases
105
+
106
+ masked_sentence = final_text
107
+ for i, noun in enumerate(nouns):
108
+ placeholder = f"<<<noun_{i}>>>"
109
+ masked_sentence = re.sub(re.escape(noun), placeholder, masked_sentence, flags=re.IGNORECASE)
110
+
111
+ translated_masked_sentence = GoogleTranslator(source='en', target='ms').translate(masked_sentence)
112
+
113
+ def restore_placeholders(text, nouns_list):
114
+ def replacer(match):
115
+ index = int(match.group(1))
116
+ return nouns_list[index]
117
+ return re.sub(r"<<<\s*noun_(\d+)\s*>>>", replacer, text, flags=re.IGNORECASE)
118
+
119
+ final_sentence = restore_placeholders(translated_masked_sentence, nouns)
120
+
121
+ audio_file_path = await self.text_to_speech(final_sentence, gender)
122
+
123
+ return final_sentence, audio_file_path
124
+
125
+ async def text_to_speech(self, text, gender):
126
+ if gender.lower() == "male":
127
+ description = "A male speaker delivers a slightly expressive and animated speech with a moderate speed and pitch."
128
+ else:
129
+ description = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch."
130
+
131
+ desc_inputs = self.description_tokenizer(description, return_tensors="pt", padding=True).to(self.device)
132
+ text_inputs = self.tts_tokenizer(text, return_tensors="pt", padding=True).to(self.device)
133
+
134
+ generation = self.tts_model.generate(
135
+ input_ids=desc_inputs.input_ids,
136
+ attention_mask=desc_inputs.attention_mask,
137
+ prompt_input_ids=text_inputs.input_ids,
138
+ prompt_attention_mask=text_inputs.attention_mask
139
+ )
140
+ audio_arr = generation.cpu().numpy().squeeze()
141
+ output_filename = f"response.wav"
142
+ output_path = os.path.join(AUDIO_FOLDER, output_filename)
143
+ sf.write(output_path, audio_arr, self.tts_model.config.sampling_rate)
144
+ return output_filename
145
 
146
  chatbot = ChatBot()
147
 
 
149
  def chat_endpoint():
150
  data = request.get_json()
151
  user_text = data.get('message', '')
152
+ gender = data.get('gender', '')
153
  if not user_text:
154
  return jsonify({"error": "Empty message"}), 400
155
  resp_text, wav_name = asyncio.run(chatbot.chat(user_text, gender))
requirements.txt CHANGED
@@ -5,7 +5,7 @@ transformers>=4.30
5
  torch
6
  fasttext
7
  deep-translator
8
- textblob
9
  parler-tts
10
  soundfile
11
  nltk
 
5
  torch
6
  fasttext
7
  deep-translator
8
+ textblob==0.17.1
9
  parler-tts
10
  soundfile
11
  nltk