drrobot9 commited on
Commit
7e7d098
·
1 Parent(s): 6e2c1f0

Update app/agents/crew_pipeline.py

Browse files
Files changed (1) hide show
  1. app/agents/crew_pipeline.py +127 -40
app/agents/crew_pipeline.py CHANGED
@@ -1,4 +1,4 @@
1
- #farmlingua/app/agents/crew_pipeline.pymemorysection
2
  import os
3
  import sys
4
  import re
@@ -10,13 +10,12 @@ import numpy as np
10
  import torch
11
  import fasttext
12
  from huggingface_hub import hf_hub_download
13
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
14
  from sentence_transformers import SentenceTransformer
15
  from app.utils import config
16
- from app.utils.memory import memory_store # memory module
17
  from typing import List
18
 
19
-
20
  hf_cache = "/models/huggingface"
21
  os.environ["HF_HOME"] = hf_cache
22
  os.environ["TRANSFORMERS_CACHE"] = hf_cache
@@ -29,13 +28,11 @@ if BASE_DIR not in sys.path:
29
 
30
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
31
 
32
-
33
  try:
34
  classifier = joblib.load(config.CLASSIFIER_PATH)
35
  except Exception:
36
  classifier = None
37
 
38
-
39
  print(f"Loading expert model ({config.EXPERT_MODEL_NAME})...")
40
  tokenizer = AutoTokenizer.from_pretrained(config.EXPERT_MODEL_NAME, use_fast=False)
41
  model = AutoModelForCausalLM.from_pretrained(
@@ -44,10 +41,9 @@ model = AutoModelForCausalLM.from_pretrained(
44
  device_map="auto"
45
  )
46
 
47
-
48
  embedder = SentenceTransformer(config.EMBEDDING_MODEL)
49
 
50
- # language detector
51
  print(f"Loading FastText language identifier ({config.LANG_ID_MODEL_REPO})...")
52
  lang_model_path = hf_hub_download(
53
  repo_id=config.LANG_ID_MODEL_REPO,
@@ -62,15 +58,29 @@ def detect_language(text: str, top_k: int = 1):
62
  labels, probs = lang_identifier.predict(clean_text, k=top_k)
63
  return [(l.replace("__label__", ""), float(p)) for l, p in zip(labels, probs)]
64
 
65
- # Translation model
66
  print(f"Loading translation model ({config.TRANSLATION_MODEL_NAME})...")
67
- translation_pipeline = pipeline(
68
- "translation_en_to_fr",
69
- model=config.TRANSLATION_MODEL_NAME,
70
- device=0 if DEVICE == "cuda" else -1,
71
- max_new_tokens=400,
 
 
72
  )
73
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  SUPPORTED_LANGS = {
75
  "eng_Latn": "English",
76
  "ibo_Latn": "Igbo",
@@ -102,16 +112,109 @@ def chunk_text(text: str, max_len: int = 400) -> List[str]:
102
  return chunks
103
 
104
  def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int = 400) -> str:
105
- if not text.strip():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  return text
 
107
  chunks = chunk_text(text, max_len=max_chunk_len)
108
  translated_parts = []
 
109
  for chunk in chunks:
110
- res = translation_pipeline(chunk, src_lang=src_lang, tgt_lang=tgt_lang)
111
- translated_parts.append(res[0]["translation_text"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  return " ".join(translated_parts).strip()
113
 
114
- # RAG retrieval
115
  def retrieve_docs(query: str, vs_path: str):
116
  if not vs_path or not os.path.exists(vs_path):
117
  return None
@@ -131,7 +234,6 @@ def retrieve_docs(query: str, vs_path: str):
131
  return "\n\n".join(docs) if docs else None
132
  return None
133
 
134
-
135
  def get_weather(state_name: str) -> str:
136
  url = "http://api.weatherapi.com/v1/current.json"
137
  params = {"key": config.WEATHER_API_KEY, "q": f"{state_name}, Nigeria", "aqi": "no"}
@@ -147,7 +249,6 @@ def get_weather(state_name: str) -> str:
147
  f"- Wind: {data['current']['wind_kph']} kph"
148
  )
149
 
150
-
151
  def detect_intent(query: str):
152
  q_lower = (query or "").lower()
153
  if any(word in q_lower for word in ["weather", "temperature", "rain", "forecast"]):
@@ -170,7 +271,6 @@ def detect_intent(query: str):
170
  pass
171
  return "normal", None
172
 
173
- # expert runner
174
  def run_qwen(messages: List[dict], max_new_tokens: int = 1300) -> str:
175
  text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
176
  inputs = tokenizer([text], return_tensors="pt").to(model.device)
@@ -183,7 +283,6 @@ def run_qwen(messages: List[dict], max_new_tokens: int = 1300) -> str:
183
  output_ids = generated_ids[0][len(inputs.input_ids[0]):].tolist()
184
  return tokenizer.decode(output_ids, skip_special_tokens=True).strip()
185
 
186
- # Memory
187
  MAX_HISTORY_MESSAGES = getattr(config, "MAX_HISTORY_MESSAGES", 30)
188
 
189
  def build_messages_from_history(history: List[dict], system_prompt: str) -> List[dict]:
@@ -191,11 +290,7 @@ def build_messages_from_history(history: List[dict], system_prompt: str) -> List
191
  msgs.extend(history)
192
  return msgs
193
 
194
-
195
  def strip_markdown(text: str) -> str:
196
- """
197
- Remove Markdown formatting like **bold**, *italic*, and `inline code`.
198
- """
199
  if not text:
200
  return ""
201
  text = re.sub(r'\*\*(.*?)\*\*', r'\1', text)
@@ -204,20 +299,16 @@ def strip_markdown(text: str) -> str:
204
  text = re.sub(r'^#+\s+', '', text, flags=re.MULTILINE)
205
  return text
206
 
207
- # Main pipeline
208
  def run_pipeline(user_query: str, session_id: str = None):
209
- """
210
- Run FarmLingua pipeline with per-session memory.
211
- Each session_id keeps its own history.
212
- """
213
  if session_id is None:
214
- session_id = str(uuid.uuid4()) # fallback unique session
215
 
216
  # Language detection
217
  lang_label, prob = detect_language(user_query, top_k=1)[0]
218
  if lang_label not in SUPPORTED_LANGS:
219
  lang_label = "eng_Latn"
220
 
 
221
  translated_query = (
222
  translate_text(user_query, src_lang=lang_label, tgt_lang="eng_Latn")
223
  if lang_label != "eng_Latn"
@@ -226,25 +317,21 @@ def run_pipeline(user_query: str, session_id: str = None):
226
 
227
  intent, extra = detect_intent(translated_query)
228
 
229
- # Load conversation history
230
  history = memory_store.get_history(session_id) or []
231
  if len(history) > MAX_HISTORY_MESSAGES:
232
  history = history[-MAX_HISTORY_MESSAGES:]
233
 
234
-
235
  history.append({"role": "user", "content": translated_query})
236
 
237
-
238
  system_prompt = (
239
  "You are FarmLingua, an AI assistant for Nigerian farmers. "
240
  "Answer directly without repeating the question. "
241
  "Use clear farmer-friendly English with emojis . "
242
  "Avoid jargon and irrelevant details. "
243
  "If asked who built you, say: 'KawaFarm LTD developed me to help farmers.'"
244
-
245
  )
246
 
247
-
248
  if intent == "weather" and extra:
249
  weather_text = get_weather(extra)
250
  history.append({"role": "user", "content": f"Rewrite this weather update simply for farmers:\n{weather_text}"})
@@ -263,13 +350,13 @@ def run_pipeline(user_query: str, session_id: str = None):
263
  messages_for_qwen = build_messages_from_history(history, system_prompt)
264
  english_answer = run_qwen(messages_for_qwen, max_new_tokens=700)
265
 
266
- # Save assistant reply
267
  history.append({"role": "assistant", "content": english_answer})
268
  if len(history) > MAX_HISTORY_MESSAGES:
269
  history = history[-MAX_HISTORY_MESSAGES:]
270
  memory_store.save_history(session_id, history)
271
 
272
- # Translate back if needed
273
  final_answer = (
274
  translate_text(english_answer, src_lang="eng_Latn", tgt_lang=lang_label)
275
  if lang_label != "eng_Latn"
@@ -280,4 +367,4 @@ def run_pipeline(user_query: str, session_id: str = None):
280
  "session_id": session_id,
281
  "detected_language": SUPPORTED_LANGS.get(lang_label, "Unknown"),
282
  "answer": final_answer
283
- }
 
1
+ # farmlingua/app/agents/crew_pipeline.py
2
  import os
3
  import sys
4
  import re
 
10
  import torch
11
  import fasttext
12
  from huggingface_hub import hf_hub_download
13
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
14
  from sentence_transformers import SentenceTransformer
15
  from app.utils import config
16
+ from app.utils.memory import memory_store
17
  from typing import List
18
 
 
19
  hf_cache = "/models/huggingface"
20
  os.environ["HF_HOME"] = hf_cache
21
  os.environ["TRANSFORMERS_CACHE"] = hf_cache
 
28
 
29
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
30
 
 
31
  try:
32
  classifier = joblib.load(config.CLASSIFIER_PATH)
33
  except Exception:
34
  classifier = None
35
 
 
36
  print(f"Loading expert model ({config.EXPERT_MODEL_NAME})...")
37
  tokenizer = AutoTokenizer.from_pretrained(config.EXPERT_MODEL_NAME, use_fast=False)
38
  model = AutoModelForCausalLM.from_pretrained(
 
41
  device_map="auto"
42
  )
43
 
 
44
  embedder = SentenceTransformer(config.EMBEDDING_MODEL)
45
 
46
+ # Language detector
47
  print(f"Loading FastText language identifier ({config.LANG_ID_MODEL_REPO})...")
48
  lang_model_path = hf_hub_download(
49
  repo_id=config.LANG_ID_MODEL_REPO,
 
58
  labels, probs = lang_identifier.predict(clean_text, k=top_k)
59
  return [(l.replace("__label__", ""), float(p)) for l, p in zip(labels, probs)]
60
 
61
+
62
# Translation model: load the NLLB-style seq2seq checkpoint directly
# (tokenizer + model) so generation can force a per-language BOS token,
# which the generic `pipeline` helper does not expose cleanly.
print(f"Loading translation model ({config.TRANSLATION_MODEL_NAME})...")

translation_tokenizer = AutoTokenizer.from_pretrained(config.TRANSLATION_MODEL_NAME)

# Half precision + automatic device placement on GPU; full fp32 on CPU.
_on_gpu = DEVICE == "cuda"
translation_model = AutoModelForSeq2SeqLM.from_pretrained(
    config.TRANSLATION_MODEL_NAME,
    torch_dtype=torch.float16 if _on_gpu else torch.float32,
    device_map="auto" if _on_gpu else None,
)
71
 
72
+
73
# Map fastText language-ID labels to NLLB tokenizer language codes.
# Keys are the labels detect_language() can return; values are the codes
# the NLLB translation tokenizer understands.
LANG_CODE_MAP = {
    "eng_Latn": "eng_Latn",  # English
    "ibo_Latn": "ibo_Latn",  # Igbo
    "yor_Latn": "yor_Latn",  # Yoruba
    "hau_Latn": "hau_Latn",  # Hausa
    "swh_Latn": "swh_Latn",  # Swahili
    # Bug fix: Amharic is written in the Ethiopic script, so both the
    # fastText label and the NLLB code are "amh_Ethi", not "amh_Latn".
    "amh_Ethi": "amh_Ethi",  # Amharic
    # Legacy/incorrect label kept for backward compatibility; it now maps
    # to the valid NLLB code instead of a token the tokenizer lacks.
    "amh_Latn": "amh_Ethi",
}
81
+
82
+
83
+
84
  SUPPORTED_LANGS = {
85
  "eng_Latn": "English",
86
  "ibo_Latn": "Igbo",
 
112
  return chunks
113
 
114
def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int = 400) -> str:
    """
    Translate ``text`` from ``src_lang`` to ``tgt_lang`` with the NLLB model.

    The text is split into chunks of at most ``max_chunk_len`` characters
    (via ``chunk_text``), each chunk is translated with the target language
    forced as the decoder BOS token, and the results are re-joined with
    spaces.

    Args:
        text: Input text; returned unchanged if empty or src == tgt.
        src_lang: fastText/NLLB label of the source language.
        tgt_lang: fastText/NLLB label of the target language.
        max_chunk_len: Maximum chunk length passed to ``chunk_text``.

    Returns:
        The translated text. On a per-chunk failure the original chunk is
        kept, so the caller always receives a usable string.
    """
    if not text.strip() or src_lang == tgt_lang:
        return text

    # Unknown labels fall back to English rather than crashing.
    src_code = LANG_CODE_MAP.get(src_lang, "eng_Latn")
    tgt_code = LANG_CODE_MAP.get(tgt_lang, "eng_Latn")

    # Resolve the token id that forces the decoder to emit the target
    # language. Newer transformers releases removed the tokenizer's
    # `lang_code_to_id` mapping, so fall back to convert_tokens_to_ids()
    # before giving up; the un-forced simple path is a last resort only,
    # because it cannot control the output language.
    forced_bos_token_id = None
    if hasattr(translation_tokenizer, "lang_code_to_id"):
        code_map = translation_tokenizer.lang_code_to_id
        if src_code not in code_map:
            print(f"Warning: Source language code '{src_code}' not found in tokenizer")
            src_code = "eng_Latn"
        if tgt_code not in code_map:
            print(f"Warning: Target language code '{tgt_code}' not found in tokenizer")
            tgt_code = "eng_Latn"
        forced_bos_token_id = code_map[tgt_code]
    else:
        tok_id = translation_tokenizer.convert_tokens_to_ids(tgt_code)
        # convert_tokens_to_ids returns the <unk> id for unknown tokens.
        if tok_id is not None and tok_id != translation_tokenizer.unk_token_id:
            forced_bos_token_id = tok_id

    if forced_bos_token_id is None:
        print("Warning: Tokenizer doesn't have lang_code_to_id attribute")
        print(f"Available tokenizer special tokens: {translation_tokenizer.special_tokens_map}")
        return translate_text_simple(text, src_lang, tgt_lang, max_chunk_len)

    # Tell the tokenizer what language the *input* is in.
    translation_tokenizer.src_lang = src_code

    chunks = chunk_text(text, max_len=max_chunk_len)
    translated_parts = []

    for chunk in chunks:
        try:
            inputs = translation_tokenizer(chunk, return_tensors="pt", truncation=True, max_length=512)
            if DEVICE == "cuda":
                inputs = {k: v.to(translation_model.device) for k, v in inputs.items()}

            # Generate translation with the target language forced.
            generated_tokens = translation_model.generate(
                **inputs,
                forced_bos_token_id=forced_bos_token_id,
                max_new_tokens=400,
                num_beams=4,
                early_stopping=True,
            )

            result = translation_tokenizer.batch_decode(
                generated_tokens,
                skip_special_tokens=True,
            )[0]
            translated_parts.append(result)
        except Exception as e:
            # Best-effort: keep the untranslated chunk instead of failing
            # the whole request.
            print(f"Translation error ({src_code}->{tgt_code}): {e}")
            translated_parts.append(chunk)

    return " ".join(translated_parts).strip()
181
+
182
def translate_text_simple(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int = 400) -> str:
    """
    Fallback translation used when the main path cannot resolve language
    codes: translates chunk by chunk without forcing a target-language BOS
    token, keeping any chunk that fails untranslated.
    """
    if src_lang == tgt_lang or not text.strip():
        return text

    pieces = []
    for segment in chunk_text(text, max_len=max_chunk_len):
        try:
            encoded = translation_tokenizer(
                segment, return_tensors="pt", truncation=True, max_length=512
            )
            if DEVICE == "cuda":
                encoded = {
                    name: tensor.to(translation_model.device)
                    for name, tensor in encoded.items()
                }

            output_ids = translation_model.generate(**encoded, max_new_tokens=400)
            decoded = translation_tokenizer.batch_decode(
                output_ids, skip_special_tokens=True
            )
            pieces.append(decoded[0])
        except Exception as e:
            # Keep the original segment so the overall result stays usable.
            print(f"Simple translation error: {e}")
            pieces.append(segment)

    return " ".join(pieces).strip()
216
 
217
+ # RAG retrieval
218
  def retrieve_docs(query: str, vs_path: str):
219
  if not vs_path or not os.path.exists(vs_path):
220
  return None
 
234
  return "\n\n".join(docs) if docs else None
235
  return None
236
 
 
237
  def get_weather(state_name: str) -> str:
238
  url = "http://api.weatherapi.com/v1/current.json"
239
  params = {"key": config.WEATHER_API_KEY, "q": f"{state_name}, Nigeria", "aqi": "no"}
 
249
  f"- Wind: {data['current']['wind_kph']} kph"
250
  )
251
 
 
252
  def detect_intent(query: str):
253
  q_lower = (query or "").lower()
254
  if any(word in q_lower for word in ["weather", "temperature", "rain", "forecast"]):
 
271
  pass
272
  return "normal", None
273
 
 
274
  def run_qwen(messages: List[dict], max_new_tokens: int = 1300) -> str:
275
  text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
276
  inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
283
  output_ids = generated_ids[0][len(inputs.input_ids[0]):].tolist()
284
  return tokenizer.decode(output_ids, skip_special_tokens=True).strip()
285
 
 
286
  MAX_HISTORY_MESSAGES = getattr(config, "MAX_HISTORY_MESSAGES", 30)
287
 
288
  def build_messages_from_history(history: List[dict], system_prompt: str) -> List[dict]:
 
290
  msgs.extend(history)
291
  return msgs
292
 
 
293
  def strip_markdown(text: str) -> str:
 
 
 
294
  if not text:
295
  return ""
296
  text = re.sub(r'\*\*(.*?)\*\*', r'\1', text)
 
299
  text = re.sub(r'^#+\s+', '', text, flags=re.MULTILINE)
300
  return text
301
 
 
302
  def run_pipeline(user_query: str, session_id: str = None):
 
 
 
 
303
  if session_id is None:
304
+ session_id = str(uuid.uuid4())
305
 
306
  # Language detection
307
  lang_label, prob = detect_language(user_query, top_k=1)[0]
308
  if lang_label not in SUPPORTED_LANGS:
309
  lang_label = "eng_Latn"
310
 
311
+
312
  translated_query = (
313
  translate_text(user_query, src_lang=lang_label, tgt_lang="eng_Latn")
314
  if lang_label != "eng_Latn"
 
317
 
318
  intent, extra = detect_intent(translated_query)
319
 
320
+ # Load conversation history
321
  history = memory_store.get_history(session_id) or []
322
  if len(history) > MAX_HISTORY_MESSAGES:
323
  history = history[-MAX_HISTORY_MESSAGES:]
324
 
 
325
  history.append({"role": "user", "content": translated_query})
326
 
 
327
  system_prompt = (
328
  "You are FarmLingua, an AI assistant for Nigerian farmers. "
329
  "Answer directly without repeating the question. "
330
  "Use clear farmer-friendly English with emojis . "
331
  "Avoid jargon and irrelevant details. "
332
  "If asked who built you, say: 'KawaFarm LTD developed me to help farmers.'"
 
333
  )
334
 
 
335
  if intent == "weather" and extra:
336
  weather_text = get_weather(extra)
337
  history.append({"role": "user", "content": f"Rewrite this weather update simply for farmers:\n{weather_text}"})
 
350
  messages_for_qwen = build_messages_from_history(history, system_prompt)
351
  english_answer = run_qwen(messages_for_qwen, max_new_tokens=700)
352
 
353
+
354
  history.append({"role": "assistant", "content": english_answer})
355
  if len(history) > MAX_HISTORY_MESSAGES:
356
  history = history[-MAX_HISTORY_MESSAGES:]
357
  memory_store.save_history(session_id, history)
358
 
359
+
360
  final_answer = (
361
  translate_text(english_answer, src_lang="eng_Latn", tgt_lang=lang_label)
362
  if lang_label != "eng_Latn"
 
367
  "session_id": session_id,
368
  "detected_language": SUPPORTED_LANGS.get(lang_label, "Unknown"),
369
  "answer": final_answer
370
+ }