BissakaAI commited on
Commit
f26c5be
·
verified ·
1 Parent(s): f9a96bb

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +156 -156
model.py CHANGED
@@ -1,156 +1,156 @@
1
- # your_model_file.py
2
- from transformers import (
3
- AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig,
4
- AutoProcessor, SeamlessM4Tv2ForSpeechToText,
5
- VitsModel
6
- )
7
- import torch
8
- import soundfile as sf
9
- import os
10
-
11
- # --------------------------
12
- # Device & config
13
- # --------------------------
14
- bnb_config = BitsAndBytesConfig(load_in_8bit=True)
15
- device = "cuda" if torch.cuda.is_available() else "cpu"
16
-
17
- # --------------------------
18
- # Load LLM
19
- # --------------------------
20
- HF_TOKEN = os.getenv("HF_TOKEN") # Use environment variable for Spaces
21
-
22
- tokenizer = AutoTokenizer.from_pretrained(
23
- "NCAIR1/N-ATLaS",
24
- trust_remote_code=True,
25
- token=HF_TOKEN
26
- )
27
-
28
- model = AutoModelForCausalLM.from_pretrained(
29
- "NCAIR1/N-ATLaS",
30
- quantization_config=bnb_config,
31
- device_map="auto",
32
- trust_remote_code=True,
33
- token=HF_TOKEN
34
- )
35
-
36
- # --------------------------
37
- # Load ASR
38
- # --------------------------
39
- ASR_MODEL = "facebook/seamless-m4t-v2-large"
40
- processor = AutoProcessor.from_pretrained(ASR_MODEL, token=HF_TOKEN)
41
- asr_model = SeamlessM4Tv2ForSpeechToText.from_pretrained(ASR_MODEL, token=HF_TOKEN).to(device)
42
- asr_model.eval()
43
-
44
- # --------------------------
45
- # Load Nigerian TTS models
46
- # --------------------------
47
- tts_models = {}
48
- for lang, tts_name in {
49
- "yoruba": "facebook/mms-tts-yor",
50
- # "igbo": "facebook/mms-tts-ibo",
51
- # "hausa": "facebook/mms-tts-hau",
52
- }.items():
53
- print(f"Loading TTS model for {lang}...")
54
- tts_proc = AutoProcessor.from_pretrained(tts_name, token=HF_TOKEN)
55
- tts_mod = VitsModel.from_pretrained(tts_name, token=HF_TOKEN).to(device)
56
- tts_mod.eval()
57
- tts_models[lang] = {"processor": tts_proc, "model": tts_mod}
58
-
59
- print("✅ All models loaded successfully!")
60
-
61
-
62
- # --------------------------
63
- # TEXT FUNCTION
64
- # --------------------------
65
- def textonly(user_msg: str) -> str:
66
- def format_prompt(messages):
67
- return tokenizer.apply_chat_template(
68
- messages,
69
- add_generation_prompt=True,
70
- tokenize=False
71
- )
72
-
73
- chat = [
74
- {"role": "system", "content": "You are a helpful model trained by Awarri AI Technologies."},
75
- {"role": "user", "content": user_msg}
76
- ]
77
-
78
- final_text = format_prompt(chat)
79
- inputs = tokenizer(final_text, return_tensors="pt").to(model.device)
80
-
81
- with torch.no_grad():
82
- output_ids = model.generate(
83
- **inputs,
84
- max_new_tokens=200,
85
- temperature=0.1,
86
- repetition_penalty=1.12
87
- )
88
-
89
- response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
90
- return response
91
-
92
-
93
- # --------------------------
94
- # SPEECH FUNCTION
95
- # --------------------------
96
- def speechonly(speech, output_wav_path="response.wav"):
97
- # --- ASR ---
98
- inputs = processor(audios=speech, sampling_rate=16000, return_tensors="pt").to(device)
99
- with torch.no_grad():
100
- predicted_ids = asr_model.generate(inputs["input_features"], max_new_tokens=300)
101
- transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
102
-
103
- # --- LLM Response ---
104
- def format_prompt(messages):
105
- return tokenizer.apply_chat_template(
106
- messages,
107
- add_generation_prompt=True,
108
- tokenize=False
109
- )
110
-
111
- chat = [
112
- {"role": "system", "content": "Respond ONLY in the detected Nigerian language (Yoruba, Igbo, Hausa, Pidgin, English)."},
113
- {"role": "user", "content": transcription}
114
- ]
115
-
116
- final_text = format_prompt(chat)
117
- inputs_llm = tokenizer(final_text, return_tensors="pt").to(model.device)
118
-
119
- with torch.no_grad():
120
- output_ids = model.generate(
121
- **inputs_llm,
122
- max_new_tokens=200,
123
- temperature=0.1,
124
- repetition_penalty=1.12
125
- )
126
-
127
- llm_response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
128
-
129
- # --- Detect language ---
130
- lang_prompt = [
131
- {"role": "system", "content": "You are a Nigerian language expert."},
132
- {"role": "user", "content": f"In which Nigerian language is this text: '{llm_response}'? Reply with only one of these: Yoruba, Igbo, Hausa, Pidgin, English."}
133
- ]
134
- lang_text = format_prompt(lang_prompt)
135
- lang_inputs = tokenizer(lang_text, return_tensors="pt").to(model.device)
136
-
137
- with torch.no_grad():
138
- lang_output_ids = model.generate(**lang_inputs, max_new_tokens=10)
139
-
140
- llm_language = tokenizer.decode(lang_output_ids[0], skip_special_tokens=True).strip().lower()
141
- if llm_language not in tts_models:
142
- llm_language = "yoruba"
143
-
144
- # --- TTS ---
145
- tts_processor = tts_models[llm_language]["processor"]
146
- tts_model = tts_models[llm_language]["model"]
147
-
148
- tts_inputs = tts_processor(text=llm_response, return_tensors="pt").to(device)
149
- with torch.no_grad():
150
- output = tts_model(**tts_inputs)
151
-
152
- # Extract waveform and save
153
- audio_array = output.waveform.squeeze().cpu().numpy()
154
- sf.write(output_wav_path, audio_array, 16000)
155
-
156
- return llm_response, output_wav_path
 
1
+ # your_model_file.py
2
+ from transformers import (
3
+ AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig,
4
+ AutoProcessor, SeamlessM4Tv2ForSpeechToText,
5
+ VitsModel
6
+ )
7
+ import torch
8
+ import soundfile as sf
9
+ import os
10
+
11
+ # --------------------------
12
+ # Device & config
13
+ # --------------------------
14
+ bnb_config = BitsAndBytesConfig(load_in_8bit=True)
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+
17
+ # --------------------------
18
+ # Load LLM
19
+ # --------------------------
20
+ HF_TOKEN = os.getenv("HF_TOKEN") # Use environment variable for Spaces
21
+
22
+ tokenizer = AutoTokenizer.from_pretrained(
23
+ "NCAIR1/N-ATLaS",
24
+ trust_remote_code=True,
25
+ token=HF_TOKEN
26
+ )
27
+
28
+ model = AutoModelForCausalLM.from_pretrained(
29
+ "NCAIR1/N-ATLaS",
30
+ quantization_config=bnb_config,
31
+ device_map="auto",
32
+ trust_remote_code=True,
33
+ token=HF_TOKEN
34
+ )
35
+
36
+ # --------------------------
37
+ # Load ASR
38
+ # --------------------------
39
+ ASR_MODEL = "facebook/seamless-m4t-v2-large"
40
+ processor = AutoProcessor.from_pretrained(ASR_MODEL, token=HF_TOKEN)
41
+ asr_model = SeamlessM4Tv2ForSpeechToText.from_pretrained(ASR_MODEL, token=HF_TOKEN, use_fast=False).to(device)
42
+ asr_model.eval()
43
+
44
+ # --------------------------
45
+ # Load Nigerian TTS models
46
+ # --------------------------
47
+ tts_models = {}
48
+ for lang, tts_name in {
49
+ "yoruba": "facebook/mms-tts-yor",
50
+ # "igbo": "facebook/mms-tts-ibo",
51
+ # "hausa": "facebook/mms-tts-hau",
52
+ }.items():
53
+ print(f"Loading TTS model for {lang}...")
54
+ tts_proc = AutoProcessor.from_pretrained(tts_name, token=HF_TOKEN)
55
+ tts_mod = VitsModel.from_pretrained(tts_name, token=HF_TOKEN).to(device)
56
+ tts_mod.eval()
57
+ tts_models[lang] = {"processor": tts_proc, "model": tts_mod}
58
+
59
+ print("✅ All models loaded successfully!")
60
+
61
+
62
+ # --------------------------
63
+ # TEXT FUNCTION
64
+ # --------------------------
65
+ def textonly(user_msg: str) -> str:
66
+ def format_prompt(messages):
67
+ return tokenizer.apply_chat_template(
68
+ messages,
69
+ add_generation_prompt=True,
70
+ tokenize=False
71
+ )
72
+
73
+ chat = [
74
+ {"role": "system", "content": "You are a helpful model trained by Awarri AI Technologies."},
75
+ {"role": "user", "content": user_msg}
76
+ ]
77
+
78
+ final_text = format_prompt(chat)
79
+ inputs = tokenizer(final_text, return_tensors="pt").to(model.device)
80
+
81
+ with torch.no_grad():
82
+ output_ids = model.generate(
83
+ **inputs,
84
+ max_new_tokens=200,
85
+ temperature=0.1,
86
+ repetition_penalty=1.12
87
+ )
88
+
89
+ response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
90
+ return response
91
+
92
+
93
+ # --------------------------
94
+ # SPEECH FUNCTION
95
+ # --------------------------
96
+ def speechonly(speech, output_wav_path="response.wav"):
97
+ # --- ASR ---
98
+ inputs = processor(audios=speech, sampling_rate=16000, return_tensors="pt").to(device)
99
+ with torch.no_grad():
100
+ predicted_ids = asr_model.generate(inputs["input_features"], max_new_tokens=300)
101
+ transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
102
+
103
+ # --- LLM Response ---
104
+ def format_prompt(messages):
105
+ return tokenizer.apply_chat_template(
106
+ messages,
107
+ add_generation_prompt=True,
108
+ tokenize=False
109
+ )
110
+
111
+ chat = [
112
+ {"role": "system", "content": "Respond ONLY in the detected Nigerian language (Yoruba, Igbo, Hausa, Pidgin, English)."},
113
+ {"role": "user", "content": transcription}
114
+ ]
115
+
116
+ final_text = format_prompt(chat)
117
+ inputs_llm = tokenizer(final_text, return_tensors="pt").to(model.device)
118
+
119
+ with torch.no_grad():
120
+ output_ids = model.generate(
121
+ **inputs_llm,
122
+ max_new_tokens=200,
123
+ temperature=0.1,
124
+ repetition_penalty=1.12
125
+ )
126
+
127
+ llm_response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
128
+
129
+ # --- Detect language ---
130
+ lang_prompt = [
131
+ {"role": "system", "content": "You are a Nigerian language expert."},
132
+ {"role": "user", "content": f"In which Nigerian language is this text: '{llm_response}'? Reply with only one of these: Yoruba, Igbo, Hausa, Pidgin, English."}
133
+ ]
134
+ lang_text = format_prompt(lang_prompt)
135
+ lang_inputs = tokenizer(lang_text, return_tensors="pt").to(model.device)
136
+
137
+ with torch.no_grad():
138
+ lang_output_ids = model.generate(**lang_inputs, max_new_tokens=10)
139
+
140
+ llm_language = tokenizer.decode(lang_output_ids[0], skip_special_tokens=True).strip().lower()
141
+ if llm_language not in tts_models:
142
+ llm_language = "yoruba"
143
+
144
+ # --- TTS ---
145
+ tts_processor = tts_models[llm_language]["processor"]
146
+ tts_model = tts_models[llm_language]["model"]
147
+
148
+ tts_inputs = tts_processor(text=llm_response, return_tensors="pt").to(device)
149
+ with torch.no_grad():
150
+ output = tts_model(**tts_inputs)
151
+
152
+ # Extract waveform and save
153
+ audio_array = output.waveform.squeeze().cpu().numpy()
154
+ sf.write(output_wav_path, audio_array, 16000)
155
+
156
+ return llm_response, output_wav_path