BissakaAI committed on
Commit 4bb2275 · verified · 1 Parent(s): 9b9fdff

first_upload

Files changed (5)
  1. app.py +47 -0
  2. awari-project (1).ipynb +391 -0
  3. dockerfile +11 -0
  4. model.py +156 -0
  5. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,47 @@
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+ from pydub import AudioSegment
+ import librosa
+ import uvicorn
+ import torch
+ import soundfile as sf
+
+ # import the existing model functions from model.py in this repo
+ from model import textonly, speechonly
+
+ app = FastAPI(title="Hamid Speech API", version="1.0.0")
+
+ @app.get("/")
+ def root():
+     return {"message": "Welcome to Hamid AI Speech API"}
+
+ class TextRequest(BaseModel):
+     text: str
+
+ class SpeechRequest(BaseModel):
+     input_audio_path: str
+     wav_output_path: str
+
+
+ @app.post("/textonly")
+ def run_text(req: TextRequest):
+     result = textonly(req.text)
+     return {"response": result}
+
+
+ @app.post("/speechonly")
+ def run_speech(req: SpeechRequest):
+     # convert the input audio to 16 kHz mono WAV
+     audio = AudioSegment.from_file(req.input_audio_path)
+     audio = audio.set_frame_rate(16000).set_channels(1)
+     audio.export(req.wav_output_path, format="wav")
+
+     # load the WAV as a float array for the ASR model
+     speech, sr = librosa.load(req.wav_output_path, sr=16000)
+
+     llm_response, wav_path = speechonly(speech, output_wav_path=req.wav_output_path)
+
+     return {
+         "response": llm_response,
+         "wav_saved": wav_path
+     }
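
Usage note: a minimal client sketch for the two endpoints above, assuming the container is running locally on port 7860 (the port in the Dockerfile CMD) and that the requests package is installed; it is not listed in requirements.txt. The speech endpoint exchanges file paths rather than audio bytes, so both paths below are placeholders that must be readable and writable by the server process.

# client_example.py (hypothetical), exercises the Hamid Speech API
import requests

BASE_URL = "http://localhost:7860"  # assumed host/port, taken from the Dockerfile CMD

# /textonly: send a message, receive the LLM's reply
r = requests.post(f"{BASE_URL}/textonly", json={"text": "Bawo ni?"})
print(r.json()["response"])

# /speechonly: the server converts the input, transcribes it, generates a reply,
# and writes the synthesized speech over the output path
r = requests.post(f"{BASE_URL}/speechonly", json={
    "input_audio_path": "/app/input.m4a",    # placeholder path on the server
    "wav_output_path": "/app/response.wav",  # overwritten with the TTS output
})
print(r.json())
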
awari-project (1).ipynb ADDED
@@ -0,0 +1,391 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {
+     "execution": {
+      "iopub.execute_input": "2025-12-11T15:43:45.235265Z",
+      "iopub.status.busy": "2025-12-11T15:43:45.235029Z",
+      "iopub.status.idle": "2025-12-11T15:43:45.340285Z",
+      "shell.execute_reply": "2025-12-11T15:43:45.339518Z",
+      "shell.execute_reply.started": "2025-12-11T15:43:45.235247Z"
+     },
+     "trusted": true
+    },
+    "outputs": [],
+    "source": [
+     "from kaggle_secrets import UserSecretsClient\n",
+     "user_secrets = UserSecretsClient()\n",
+     "secret_value_0 = user_secrets.get_secret(\"HF_TOKEN\")\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {
+     "execution": {
+      "iopub.execute_input": "2025-12-11T15:43:45.341357Z",
+      "iopub.status.busy": "2025-12-11T15:43:45.341102Z",
+      "iopub.status.idle": "2025-12-11T15:45:04.811675Z",
+      "shell.execute_reply": "2025-12-11T15:45:04.810916Z",
+      "shell.execute_reply.started": "2025-12-11T15:43:45.341333Z"
+     },
+     "trusted": true
+    },
+    "outputs": [],
+    "source": [
+     "%pip install -U bitsandbytes"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from transformers import (\n",
+     "    AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig,\n",
+     "    AutoProcessor, SeamlessM4Tv2ForSpeechToText,\n",
+     "    VitsModel  # TTS\n",
+     ")\n",
+     "import torch\n",
+     "import soundfile as sf\n",
+     "import os\n",
+     "from kaggle_secrets import UserSecretsClient\n",
+     "\n",
+     "\n",
+     "# get the HF token from the Kaggle secrets store\n",
+     "user_secrets = UserSecretsClient()\n",
+     "HF_TOKEN = user_secrets.get_secret(\"HF_TOKEN\")\n",
+     "print(\"hf_token retrieved\")\n",
+     "\n",
+     "\n",
+     "# use bitsandbytes to quantize the model to 8-bit\n",
+     "bnb_config = BitsAndBytesConfig(load_in_8bit=True)\n",
+     "\n",
+     "# set the device to run on\n",
+     "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+     "\n",
+     "# load the N-ATLaS model and tokenizer\n",
+     "tokenizer = AutoTokenizer.from_pretrained(\n",
+     "    \"NCAIR1/N-ATLaS\",\n",
+     "    trust_remote_code=True,\n",
+     "    token=HF_TOKEN\n",
+     ")\n",
+     "\n",
+     "\n",
+     "model = AutoModelForCausalLM.from_pretrained(\n",
+     "    \"NCAIR1/N-ATLaS\",\n",
+     "    quantization_config=bnb_config,\n",
+     "    device_map=\"auto\",\n",
+     "    trust_remote_code=True,\n",
+     "    token=HF_TOKEN\n",
+     ")\n",
+     "\n",
+     "\n",
+     "\n",
+     "# an ASR model to convert speech to text\n",
+     "ASR_MODEL = \"facebook/seamless-m4t-v2-large\"\n",
+     "processor = AutoProcessor.from_pretrained(ASR_MODEL, token=HF_TOKEN)\n",
+     "asr_model = SeamlessM4Tv2ForSpeechToText.from_pretrained(ASR_MODEL, token=HF_TOKEN).to(device)\n",
+     "asr_model.eval()\n",
+     "\n",
+     "\n",
+     "# models to convert text back to speech;\n",
+     "# Yoruba is loaded here, with Igbo and Hausa commented out\n",
+     "tts_models = {}\n",
+     "for lang, tts_name in {\n",
+     "    \"yoruba\": \"facebook/mms-tts-yor\",\n",
+     "    # \"igbo\": \"facebook/mms-tts-ibo\",\n",
+     "    # \"hausa\": \"facebook/mms-tts-hau\",\n",
+     "}.items():\n",
+     "    print(f\"Loading TTS model for {lang}\")\n",
+     "    tts_proc = AutoProcessor.from_pretrained(tts_name, token=HF_TOKEN)\n",
+     "    tts_mod = VitsModel.from_pretrained(tts_name, token=HF_TOKEN).to(device)\n",
+     "    tts_mod.eval()\n",
+     "    tts_models[lang] = {\"processor\": tts_proc, \"model\": tts_mod}\n",
+     "\n",
+     "print(\"All TTS models loaded successfully!\")\n",
+     "\n",
+     "\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {
+     "execution": {
+      "iopub.execute_input": "2025-12-11T15:45:04.813528Z",
+      "iopub.status.busy": "2025-12-11T15:45:04.813289Z",
+      "iopub.status.idle": "2025-12-11T15:49:56.343820Z",
+      "shell.execute_reply": "2025-12-11T15:49:56.343087Z",
+      "shell.execute_reply.started": "2025-12-11T15:45:04.813503Z"
+     },
+     "trusted": true
+    },
+    "outputs": [],
+    "source": [
+     "import torch\n",
+     "import soundfile as sf\n",
+     "\n",
+     "\n",
+     "\n",
+     "# function to handle text-only input\n",
+     "def textonly(user_msg: str):\n",
+     "    def format_prompt(messages):\n",
+     "        return tokenizer.apply_chat_template(\n",
+     "            messages,\n",
+     "            add_generation_prompt=True,\n",
+     "            tokenize=False\n",
+     "        )\n",
+     "\n",
+     "    chat = [\n",
+     "        {\"role\": \"system\", \"content\": \"You are a helpful model trained by Awarri AI Technologies.\"},\n",
+     "        {\"role\": \"user\", \"content\": user_msg}\n",
+     "    ]\n",
+     "\n",
+     "    final_text = format_prompt(chat)\n",
+     "    inputs = tokenizer(final_text, return_tensors=\"pt\").to(model.device)\n",
+     "\n",
+     "    with torch.no_grad():\n",
+     "        output_ids = model.generate(\n",
+     "            **inputs,\n",
+     "            max_new_tokens=200,\n",
+     "            temperature=0.1,\n",
+     "            repetition_penalty=1.12\n",
+     "        )\n",
+     "\n",
+     "    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)\n",
+     "    return response\n",
+     "\n",
+     "\n",
+     "\n",
+     "\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# function to handle speech input\n",
+     "def speechonly(speech, output_wav_path=\"response.wav\"):\n",
+     "    # the speech-to-text part\n",
+     "    inputs = processor(audios=speech, sampling_rate=16000, return_tensors=\"pt\").to(device)\n",
+     "    with torch.no_grad():\n",
+     "        predicted_ids = asr_model.generate(inputs[\"input_features\"], max_new_tokens=300)\n",
+     "    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]\n",
+     "\n",
+     "    print(\"\\nTRANSCRIPTION:\", transcription)\n",
+     "\n",
+     "\n",
+     "    # use the N-ATLaS LLM to generate the response\n",
+     "    def format_prompt(messages):\n",
+     "        return tokenizer.apply_chat_template(\n",
+     "            messages,\n",
+     "            add_generation_prompt=True,\n",
+     "            tokenize=False\n",
+     "        )\n",
+     "\n",
+     "    chat = [\n",
+     "        {\"role\": \"system\", \"content\": \"Respond ONLY in the detected Nigerian language (Yoruba, Igbo, Hausa, Pidgin, English).\"},\n",
+     "        {\"role\": \"user\", \"content\": transcription}\n",
+     "    ]\n",
+     "\n",
+     "    final_text = format_prompt(chat)\n",
+     "    inputs_llm = tokenizer(final_text, return_tensors=\"pt\").to(model.device)\n",
+     "\n",
+     "    with torch.no_grad():\n",
+     "        output_ids = model.generate(\n",
+     "            **inputs_llm,\n",
+     "            max_new_tokens=200,\n",
+     "            temperature=0.1,\n",
+     "            repetition_penalty=1.12\n",
+     "        )\n",
+     "\n",
+     "    llm_response = tokenizer.decode(output_ids[0], skip_special_tokens=True)\n",
+     "    print(\"\\nllm response:\", llm_response)\n",
+     "\n",
+     "\n",
+     "\n",
+     "\n",
+     "    # N-ATLaS is a multilingual model designed for Nigerian languages,\n",
+     "    # so it is expected to understand them well;\n",
+     "    # use it to detect the language of the generated response\n",
+     "    lang_prompt = [\n",
+     "        {\"role\": \"system\", \"content\": \"You are a Nigerian language expert.\"},\n",
+     "        {\"role\": \"user\", \"content\": f\"In which Nigerian language is this text: '{llm_response}'? Reply with only one of these: Yoruba, Igbo, Hausa, Pidgin, English.\"}\n",
+     "    ]\n",
+     "    lang_text = format_prompt(lang_prompt)\n",
+     "    lang_inputs = tokenizer(lang_text, return_tensors=\"pt\").to(model.device)\n",
+     "\n",
+     "    with torch.no_grad():\n",
+     "        lang_output_ids = model.generate(**lang_inputs, max_new_tokens=10)\n",
+     "\n",
+     "    llm_language = tokenizer.decode(lang_output_ids[0], skip_special_tokens=True).strip().lower()\n",
+     "    print(\"\\nLLM DETECTED LANGUAGE:\", llm_language)\n",
+     "\n",
+     "    # pick the TTS model based on the detected language;\n",
+     "    # fall back to Yoruba, the only TTS model loaded above\n",
+     "    if llm_language not in tts_models:\n",
+     "        llm_language = \"yoruba\"\n",
+     "\n",
+     "    tts_processor = tts_models[llm_language][\"processor\"]\n",
+     "    tts_model = tts_models[llm_language][\"model\"]\n",
+     "\n",
+     "\n",
+     "    # generate speech\n",
+     "\n",
+     "    # process the text\n",
+     "    tts_inputs = tts_processor(text=llm_response, return_tensors=\"pt\").to(device)\n",
+     "    with torch.no_grad():\n",
+     "        output = tts_model(**tts_inputs)\n",
+     "    audio_array = output.waveform.squeeze().cpu().numpy()\n",
+     "\n",
+     "    # save the WAV\n",
+     "    sf.write(output_wav_path, audio_array, 16000)\n",
+     "    return llm_response, output_wav_path"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from pydub import AudioSegment\n",
+     "import librosa\n",
+     "# ask the user for the input type\n",
+     "userinput = input(\"Enter 'text' or 'audio': \").lower()\n",
+     "\n",
+     "if userinput == \"text\":\n",
+     "    # call the text function with a typed message\n",
+     "    answer1 = textonly(input(\"Enter your message: \"))\n",
+     "    print(\"\\ntext response:\\n\", answer1)\n",
+     "\n",
+     "else:\n",
+     "    # load and preprocess the audio\n",
+     "    audio_path = \"/kaggle/input/recordings/Recording (3).m4a\"\n",
+     "    audio = AudioSegment.from_file(audio_path)\n",
+     "    audio = audio.set_frame_rate(16000).set_channels(1)\n",
+     "    audio.export(\"/kaggle/working/audio.wav\", format=\"wav\")\n",
+     "\n",
+     "    speech, sr = librosa.load(\"/kaggle/working/audio.wav\", sr=16000)\n",
+     "    print(\"Converted audio loaded.\")\n",
+     "\n",
+     "    # call the speech function\n",
+     "    answer2, wav_path = speechonly(speech)\n",
+     "    print(\"\\nAUDIO RESPONSE saved as:\", wav_path)\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from fastapi import FastAPI\n",
+     "from pydantic import BaseModel\n",
+     "from pydub import AudioSegment\n",
+     "import librosa\n",
+     "import uvicorn\n",
+     "import os\n",
+     "app = FastAPI(title='Simple FastAPI App', version='1.0.0')\n",
+     "\n",
+     "@app.get(\"/\")\n",
+     "def root():\n",
+     "    return {\"Message\": \"Welcome to Healthatlas Application\"}\n",
+     "\n",
+     "\n",
+     "\n",
+     "class TextRequest(BaseModel):\n",
+     "    text: str\n",
+     "\n",
+     "\n",
+     "class SpeechRequest(BaseModel):\n",
+     "    input_audio_path: str\n",
+     "    wav_output_path: str\n",
+     "\n",
+     "\n",
+     "\n",
+     "@app.post(\"/textonly\")\n",
+     "def do_text(request: TextRequest):\n",
+     "    answer1 = textonly(request.text)\n",
+     "    print(\"\\nText response:\\n\", answer1)\n",
+     "    return {\"response\": answer1}\n",
+     "\n",
+     "\n",
+     "@app.post(\"/speechonly\")\n",
+     "def run_speech(request: SpeechRequest):\n",
+     "    audio = AudioSegment.from_file(request.input_audio_path)\n",
+     "    audio = audio.set_frame_rate(16000).set_channels(1)\n",
+     "    audio.export(request.wav_output_path, format=\"wav\")\n",
+     "\n",
+     "    speech, sr = librosa.load(request.wav_output_path, sr=16000)\n",
+     "    print(\"Converted audio loaded.\")\n",
+     "\n",
+     "\n",
+     "    llm_response, wav_path = speechonly(speech, output_wav_path=request.wav_output_path)\n",
+     "\n",
+     "    return {\"response\": llm_response, \"saved_wav\": wav_path}\n",
+     "\n",
+     "if __name__ == '__main__':\n",
+     "    host = os.getenv(\"host\", \"0.0.0.0\")\n",
+     "    port = int(os.getenv(\"port\", \"7860\"))  # default to the Dockerfile port\n",
+     "    uvicorn.run(app, host=host, port=port)"
+    ]
+   }
+  ],
+  "metadata": {
+   "kaggle": {
+    "accelerator": "nvidiaTeslaT4",
+    "dataSources": [
+     {
+      "datasetId": 8987240,
+      "sourceId": 14109383,
+      "sourceType": "datasetVersion"
+     }
+    ],
+    "dockerImageVersionId": 31193,
+    "isGpuEnabled": true,
+    "isInternetEnabled": true,
+    "language": "python",
+    "sourceType": "notebook"
+   },
+   "kernelspec": {
+    "display_name": "zoomcamp-pwCLAhn6",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.12.4"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 4
+ }
dockerfile ADDED
@@ -0,0 +1,11 @@
+ FROM python:3.10
+
+ WORKDIR /app
+
+ COPY requirements.txt .
+ RUN apt-get update && apt-get install -y ffmpeg && rm -rf /var/lib/apt/lists/*  # ffmpeg is required by pydub for non-WAV inputs
+ RUN pip install --upgrade pip && pip install -r requirements.txt
+
+ COPY . .
+
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
model.py ADDED
@@ -0,0 +1,156 @@
+ # model.py
+ from transformers import (
+     AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig,
+     AutoProcessor, SeamlessM4Tv2ForSpeechToText,
+     VitsModel
+ )
+ import torch
+ import soundfile as sf
+ import os
+
+ # --------------------------
+ # Device & config
+ # --------------------------
+ bnb_config = BitsAndBytesConfig(load_in_8bit=True)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # --------------------------
+ # Load LLM
+ # --------------------------
+ HF_TOKEN = os.getenv("HF_TOKEN")  # use an environment variable for Spaces
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     "NCAIR1/N-ATLaS",
+     trust_remote_code=True,
+     token=HF_TOKEN
+ )
+
+ model = AutoModelForCausalLM.from_pretrained(
+     "NCAIR1/N-ATLaS",
+     quantization_config=bnb_config,
+     device_map="auto",
+     trust_remote_code=True,
+     token=HF_TOKEN
+ )
+
+ # --------------------------
+ # Load ASR
+ # --------------------------
+ ASR_MODEL = "facebook/seamless-m4t-v2-large"
+ processor = AutoProcessor.from_pretrained(ASR_MODEL, token=HF_TOKEN)
+ asr_model = SeamlessM4Tv2ForSpeechToText.from_pretrained(ASR_MODEL, token=HF_TOKEN).to(device)
+ asr_model.eval()
+
+ # --------------------------
+ # Load Nigerian TTS models
+ # --------------------------
+ tts_models = {}
+ for lang, tts_name in {
+     "yoruba": "facebook/mms-tts-yor",
+     # "igbo": "facebook/mms-tts-ibo",
+     # "hausa": "facebook/mms-tts-hau",
+ }.items():
+     print(f"Loading TTS model for {lang}...")
+     tts_proc = AutoProcessor.from_pretrained(tts_name, token=HF_TOKEN)
+     tts_mod = VitsModel.from_pretrained(tts_name, token=HF_TOKEN).to(device)
+     tts_mod.eval()
+     tts_models[lang] = {"processor": tts_proc, "model": tts_mod}
+
+ print("✅ All models loaded successfully!")
+
+
+ # --------------------------
+ # TEXT FUNCTION
+ # --------------------------
+ def textonly(user_msg: str) -> str:
+     def format_prompt(messages):
+         return tokenizer.apply_chat_template(
+             messages,
+             add_generation_prompt=True,
+             tokenize=False
+         )
+
+     chat = [
+         {"role": "system", "content": "You are a helpful model trained by Awarri AI Technologies."},
+         {"role": "user", "content": user_msg}
+     ]
+
+     final_text = format_prompt(chat)
+     inputs = tokenizer(final_text, return_tensors="pt").to(model.device)
+
+     with torch.no_grad():
+         output_ids = model.generate(
+             **inputs,
+             max_new_tokens=200,
+             temperature=0.1,
+             repetition_penalty=1.12
+         )
+
+     response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+     return response
+
+
+ # --------------------------
+ # SPEECH FUNCTION
+ # --------------------------
+ def speechonly(speech, output_wav_path="response.wav"):
+     # --- ASR ---
+     inputs = processor(audios=speech, sampling_rate=16000, return_tensors="pt").to(device)
+     with torch.no_grad():
+         predicted_ids = asr_model.generate(inputs["input_features"], max_new_tokens=300)
+     transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+
+     # --- LLM Response ---
+     def format_prompt(messages):
+         return tokenizer.apply_chat_template(
+             messages,
+             add_generation_prompt=True,
+             tokenize=False
+         )
+
+     chat = [
+         {"role": "system", "content": "Respond ONLY in the detected Nigerian language (Yoruba, Igbo, Hausa, Pidgin, English)."},
+         {"role": "user", "content": transcription}
+     ]
+
+     final_text = format_prompt(chat)
+     inputs_llm = tokenizer(final_text, return_tensors="pt").to(model.device)
+
+     with torch.no_grad():
+         output_ids = model.generate(
+             **inputs_llm,
+             max_new_tokens=200,
+             temperature=0.1,
+             repetition_penalty=1.12
+         )
+
+     llm_response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+     # --- Detect language ---
+     lang_prompt = [
+         {"role": "system", "content": "You are a Nigerian language expert."},
+         {"role": "user", "content": f"In which Nigerian language is this text: '{llm_response}'? Reply with only one of these: Yoruba, Igbo, Hausa, Pidgin, English."}
+     ]
+     lang_text = format_prompt(lang_prompt)
+     lang_inputs = tokenizer(lang_text, return_tensors="pt").to(model.device)
+
+     with torch.no_grad():
+         lang_output_ids = model.generate(**lang_inputs, max_new_tokens=10)
+
+     llm_language = tokenizer.decode(lang_output_ids[0], skip_special_tokens=True).strip().lower()
+     if llm_language not in tts_models:
+         llm_language = "yoruba"  # fall back to the one TTS model loaded above
+
+     # --- TTS ---
+     tts_processor = tts_models[llm_language]["processor"]
+     tts_model = tts_models[llm_language]["model"]
+
+     tts_inputs = tts_processor(text=llm_response, return_tensors="pt").to(device)
+     with torch.no_grad():
+         output = tts_model(**tts_inputs)
+
+     # extract the waveform and save it
+     audio_array = output.waveform.squeeze().cpu().numpy()
+     sf.write(output_wav_path, audio_array, 16000)
+
+     return llm_response, output_wav_path
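
Usage note: a minimal sketch of calling model.py directly, for example from a Python shell in the same environment, assuming the model loading that runs at import time has completed. The file names sample.wav and reply.wav are placeholders.

# direct_use_example.py (hypothetical), bypasses the FastAPI layer
import librosa
from model import textonly, speechonly  # importing model.py triggers the model loading above

# text in, text out
print(textonly("What is the capital of Nigeria?"))

# speech in: librosa resamples the placeholder file to the 16 kHz mono the ASR expects
speech, sr = librosa.load("sample.wav", sr=16000)  # sample.wav is a placeholder
reply_text, wav_path = speechonly(speech, output_wav_path="reply.wav")
print(reply_text, "->", wav_path)
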
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ fastapi
+ uvicorn
+ pydub
+ librosa
+ soundfile
+ transformers
+ torch
+ accelerate
+ bitsandbytes