drrobot9 commited on
Commit
4480348
·
1 Parent(s): 527b3c5

Update app/agents/crew_pipeline.py

Browse files
Files changed (1) hide show
  1. app/agents/crew_pipeline.py +44 -158
app/agents/crew_pipeline.py CHANGED
@@ -1,4 +1,4 @@
1
- # farmlingua/app/agents/crew_pipeline.py
2
  import os
3
  import sys
4
  import re
@@ -10,12 +10,13 @@ import numpy as np
10
  import torch
11
  import fasttext
12
  from huggingface_hub import hf_hub_download
13
- from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
14
  from sentence_transformers import SentenceTransformer
15
  from app.utils import config
16
- from app.utils.memory import memory_store
17
  from typing import List
18
 
 
19
  hf_cache = "/models/huggingface"
20
  os.environ["HF_HOME"] = hf_cache
21
  os.environ["TRANSFORMERS_CACHE"] = hf_cache
@@ -28,11 +29,13 @@ if BASE_DIR not in sys.path:
28
 
29
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
30
 
 
31
  try:
32
  classifier = joblib.load(config.CLASSIFIER_PATH)
33
  except Exception:
34
  classifier = None
35
 
 
36
  print(f"Loading expert model ({config.EXPERT_MODEL_NAME})...")
37
  tokenizer = AutoTokenizer.from_pretrained(config.EXPERT_MODEL_NAME, use_fast=False)
38
  model = AutoModelForCausalLM.from_pretrained(
@@ -41,8 +44,10 @@ model = AutoModelForCausalLM.from_pretrained(
41
  device_map="auto"
42
  )
43
 
 
44
  embedder = SentenceTransformer(config.EMBEDDING_MODEL)
45
 
 
46
  print(f"Loading FastText language identifier ({config.LANG_ID_MODEL_REPO})...")
47
  lang_model_path = hf_hub_download(
48
  repo_id=config.LANG_ID_MODEL_REPO,
@@ -57,56 +62,14 @@ def detect_language(text: str, top_k: int = 1):
57
  labels, probs = lang_identifier.predict(clean_text, k=top_k)
58
  return [(l.replace("__label__", ""), float(p)) for l, p in zip(labels, probs)]
59
 
 
60
  print(f"Loading translation model ({config.TRANSLATION_MODEL_NAME})...")
61
-
62
- NLLB_MODEL = "facebook/nllb-200-distilled-600M"
63
- print(f"Using model: {NLLB_MODEL}")
64
-
65
- try:
66
- translation_tokenizer = AutoTokenizer.from_pretrained(NLLB_MODEL)
67
- translation_model = AutoModelForSeq2SeqLM.from_pretrained(
68
- NLLB_MODEL,
69
- torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
70
- device_map="auto" if DEVICE == "cuda" else None
71
- )
72
-
73
- print(f"✓ Translation model loaded successfully")
74
-
75
- # DEBUG: Check tokenizer properties
76
- print(f"Tokenizer type: {type(translation_tokenizer).__name__}")
77
- print(f"Has lang_code_to_id: {hasattr(translation_tokenizer, 'lang_code_to_id')}")
78
-
79
- if hasattr(translation_tokenizer, 'lang_code_to_id'):
80
- print(f"Sample language codes: {list(translation_tokenizer.lang_code_to_id.keys())[:10]}")
81
- else:
82
-
83
- from transformers import AutoConfig
84
- config_model = AutoConfig.from_pretrained(NLLB_MODEL)
85
- print(f"Model config: {config_model}")
86
-
87
- except Exception as e:
88
- print(f"✗ Error loading translation model: {e}")
89
- raise
90
-
91
- # Language code mapping
92
- LANG_CODE_MAP = {
93
- "eng_Latn": "eng_Latn", # English
94
- "ibo_Latn": "ibo_Latn", # Igbo
95
- "yor_Latn": "yor_Latn", # Yoruba
96
- "hau_Latn": "hau_Latn", # Hausa
97
- "swh_Latn": "swa_Latn", # Swahili
98
- "amh_Latn": "amh_Ethi", # Amharic
99
- }
100
-
101
- # Alternative mapping k
102
- LANG_CODE_MAP_ALT = {
103
- "eng_Latn": "en", # English
104
- "ibo_Latn": "ig", # Igbo
105
- "yor_Latn": "yo", # Yoruba
106
- "hau_Latn": "ha", # Hausa
107
- "swh_Latn": "sw", # Swahili
108
- "amh_Latn": "am", # Amharic
109
- }
110
 
111
  SUPPORTED_LANGS = {
112
  "eng_Latn": "English",
@@ -117,6 +80,7 @@ SUPPORTED_LANGS = {
117
  "amh_Latn": "Amharic",
118
  }
119
 
 
120
  _SENTENCE_SPLIT_RE = re.compile(r'(?<=[.!?])\s+')
121
 
122
  def chunk_text(text: str, max_len: int = 400) -> List[str]:
@@ -138,108 +102,16 @@ def chunk_text(text: str, max_len: int = 400) -> List[str]:
138
  return chunks
139
 
140
  def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int = 400) -> str:
141
- print(f"\n[TRANSLATION] {src_lang} → {tgt_lang}")
142
- print(f" Input: {text[:100]}...")
143
-
144
- if not text.strip() or src_lang == tgt_lang:
145
- print(" No translation needed (same language)")
146
  return text
147
-
148
-
149
- src_code = LANG_CODE_MAP.get(src_lang, "eng_Latn")
150
- tgt_code = LANG_CODE_MAP.get(tgt_lang, "eng_Latn")
151
-
152
- print(f" Using codes: {src_code} → {tgt_code}")
153
-
154
-
155
- if not hasattr(translation_tokenizer, 'lang_code_to_id'):
156
- print(" WARNING: Tokenizer doesn't have lang_code_to_id")
157
- print(" Trying alternative method...")
158
-
159
-
160
- src_code_alt = LANG_CODE_MAP_ALT.get(src_lang, "en")
161
- tgt_code_alt = LANG_CODE_MAP_ALT.get(tgt_lang, "en")
162
-
163
-
164
- try:
165
- from transformers import pipeline
166
- translator = pipeline(
167
- "translation",
168
- model=translation_model,
169
- tokenizer=translation_tokenizer,
170
- src_lang=src_code_alt,
171
- tgt_lang=tgt_code_alt,
172
- device=0 if DEVICE == "cuda" else -1,
173
- max_length=400
174
- )
175
-
176
- chunks = chunk_text(text, max_len=max_chunk_len)
177
- translated_parts = []
178
-
179
- for chunk in chunks:
180
- result = translator(chunk)
181
- translated_parts.append(result[0]["translation_text"])
182
-
183
- return " ".join(translated_parts).strip()
184
-
185
- except Exception as e:
186
- print(f" Pipeline translation failed: {e}")
187
- return text
188
-
189
-
190
- if src_code not in translation_tokenizer.lang_code_to_id:
191
- print(f" WARNING: Source code {src_code} not found, trying alternatives...")
192
-
193
- src_code = LANG_CODE_MAP_ALT.get(src_lang, "eng_Latn")
194
-
195
- if tgt_code not in translation_tokenizer.lang_code_to_id:
196
- print(f" WARNING: Target code {tgt_code} not found, trying alternatives...")
197
- tgt_code = LANG_CODE_MAP_ALT.get(tgt_lang, "eng_Latn")
198
-
199
-
200
- translation_tokenizer.src_lang = src_code
201
-
202
-
203
- forced_bos_token_id = translation_tokenizer.lang_code_to_id[tgt_code]
204
-
205
  chunks = chunk_text(text, max_len=max_chunk_len)
206
  translated_parts = []
207
-
208
- for i, chunk in enumerate(chunks):
209
- try:
210
- inputs = translation_tokenizer(
211
- chunk,
212
- return_tensors="pt",
213
- truncation=True,
214
- max_length=512
215
- )
216
-
217
- if DEVICE == "cuda":
218
- inputs = {k: v.to(translation_model.device) for k, v in inputs.items()}
219
-
220
- generated_tokens = translation_model.generate(
221
- **inputs,
222
- forced_bos_token_id=forced_bos_token_id,
223
- max_new_tokens=400,
224
- num_beams=4,
225
- early_stopping=True
226
- )
227
-
228
- result = translation_tokenizer.batch_decode(
229
- generated_tokens,
230
- skip_special_tokens=True
231
- )[0]
232
-
233
- print(f" Chunk {i+1}: '{chunk[:30]}...' → '{result[:30]}...'")
234
- translated_parts.append(result.strip())
235
-
236
- except Exception as e:
237
- print(f" Chunk {i+1} error: {e}")
238
- translated_parts.append(chunk)
239
-
240
  return " ".join(translated_parts).strip()
241
 
242
-
243
  def retrieve_docs(query: str, vs_path: str):
244
  if not vs_path or not os.path.exists(vs_path):
245
  return None
@@ -259,6 +131,7 @@ def retrieve_docs(query: str, vs_path: str):
259
  return "\n\n".join(docs) if docs else None
260
  return None
261
 
 
262
  def get_weather(state_name: str) -> str:
263
  url = "http://api.weatherapi.com/v1/current.json"
264
  params = {"key": config.WEATHER_API_KEY, "q": f"{state_name}, Nigeria", "aqi": "no"}
@@ -274,6 +147,7 @@ def get_weather(state_name: str) -> str:
274
  f"- Wind: {data['current']['wind_kph']} kph"
275
  )
276
 
 
277
  def detect_intent(query: str):
278
  q_lower = (query or "").lower()
279
  if any(word in q_lower for word in ["weather", "temperature", "rain", "forecast"]):
@@ -296,6 +170,7 @@ def detect_intent(query: str):
296
  pass
297
  return "normal", None
298
 
 
299
  def run_qwen(messages: List[dict], max_new_tokens: int = 1300) -> str:
300
  text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
301
  inputs = tokenizer([text], return_tensors="pt").to(model.device)
@@ -308,6 +183,7 @@ def run_qwen(messages: List[dict], max_new_tokens: int = 1300) -> str:
308
  output_ids = generated_ids[0][len(inputs.input_ids[0]):].tolist()
309
  return tokenizer.decode(output_ids, skip_special_tokens=True).strip()
310
 
 
311
  MAX_HISTORY_MESSAGES = getattr(config, "MAX_HISTORY_MESSAGES", 30)
312
 
313
  def build_messages_from_history(history: List[dict], system_prompt: str) -> List[dict]:
@@ -315,7 +191,11 @@ def build_messages_from_history(history: List[dict], system_prompt: str) -> List
315
  msgs.extend(history)
316
  return msgs
317
 
 
318
  def strip_markdown(text: str) -> str:
 
 
 
319
  if not text:
320
  return ""
321
  text = re.sub(r'\*\*(.*?)\*\*', r'\1', text)
@@ -324,40 +204,47 @@ def strip_markdown(text: str) -> str:
324
  text = re.sub(r'^#+\s+', '', text, flags=re.MULTILINE)
325
  return text
326
 
 
327
  def run_pipeline(user_query: str, session_id: str = None):
 
 
 
 
328
  if session_id is None:
329
- session_id = str(uuid.uuid4())
330
 
 
331
  lang_label, prob = detect_language(user_query, top_k=1)[0]
332
  if lang_label not in SUPPORTED_LANGS:
333
  lang_label = "eng_Latn"
334
-
335
- print(f"Detected language: {SUPPORTED_LANGS.get(lang_label, 'Unknown')}")
336
 
337
  translated_query = (
338
  translate_text(user_query, src_lang=lang_label, tgt_lang="eng_Latn")
339
  if lang_label != "eng_Latn"
340
  else user_query
341
  )
342
-
343
- print(f"Translated to English: {translated_query[:100]}...")
344
 
345
  intent, extra = detect_intent(translated_query)
346
 
 
347
  history = memory_store.get_history(session_id) or []
348
  if len(history) > MAX_HISTORY_MESSAGES:
349
  history = history[-MAX_HISTORY_MESSAGES:]
350
 
 
351
  history.append({"role": "user", "content": translated_query})
352
 
 
353
  system_prompt = (
354
  "You are FarmLingua, an AI assistant for Nigerian farmers. "
355
  "Answer directly without repeating the question. "
356
  "Use clear farmer-friendly English with emojis . "
357
  "Avoid jargon and irrelevant details. "
358
  "If asked who built you, say: 'KawaFarm LTD developed me to help farmers.'"
 
359
  )
360
 
 
361
  if intent == "weather" and extra:
362
  weather_text = get_weather(extra)
363
  history.append({"role": "user", "content": f"Rewrite this weather update simply for farmers:\n{weather_text}"})
@@ -376,22 +263,21 @@ def run_pipeline(user_query: str, session_id: str = None):
376
  messages_for_qwen = build_messages_from_history(history, system_prompt)
377
  english_answer = run_qwen(messages_for_qwen, max_new_tokens=700)
378
 
 
379
  history.append({"role": "assistant", "content": english_answer})
380
  if len(history) > MAX_HISTORY_MESSAGES:
381
  history = history[-MAX_HISTORY_MESSAGES:]
382
  memory_store.save_history(session_id, history)
383
 
 
384
  final_answer = (
385
  translate_text(english_answer, src_lang="eng_Latn", tgt_lang=lang_label)
386
  if lang_label != "eng_Latn"
387
  else english_answer
388
  )
389
  final_answer = strip_markdown(final_answer)
390
-
391
- print(f"Final answer: {final_answer[:100]}...")
392
-
393
  return {
394
  "session_id": session_id,
395
  "detected_language": SUPPORTED_LANGS.get(lang_label, "Unknown"),
396
  "answer": final_answer
397
- }
 
1
+ # farmlingua/app/agents/crew_pipeline.py — memory section
2
  import os
3
  import sys
4
  import re
 
10
  import torch
11
  import fasttext
12
  from huggingface_hub import hf_hub_download
13
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
14
  from sentence_transformers import SentenceTransformer
15
  from app.utils import config
16
+ from app.utils.memory import memory_store # memory module
17
  from typing import List
18
 
19
+
20
  hf_cache = "/models/huggingface"
21
  os.environ["HF_HOME"] = hf_cache
22
  os.environ["TRANSFORMERS_CACHE"] = hf_cache
 
29
 
30
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
31
 
32
+
33
  try:
34
  classifier = joblib.load(config.CLASSIFIER_PATH)
35
  except Exception:
36
  classifier = None
37
 
38
+
39
  print(f"Loading expert model ({config.EXPERT_MODEL_NAME})...")
40
  tokenizer = AutoTokenizer.from_pretrained(config.EXPERT_MODEL_NAME, use_fast=False)
41
  model = AutoModelForCausalLM.from_pretrained(
 
44
  device_map="auto"
45
  )
46
 
47
+
48
  embedder = SentenceTransformer(config.EMBEDDING_MODEL)
49
 
50
+ # language detector
51
  print(f"Loading FastText language identifier ({config.LANG_ID_MODEL_REPO})...")
52
  lang_model_path = hf_hub_download(
53
  repo_id=config.LANG_ID_MODEL_REPO,
 
62
  labels, probs = lang_identifier.predict(clean_text, k=top_k)
63
  return [(l.replace("__label__", ""), float(p)) for l, p in zip(labels, probs)]
64
 
65
# Translation model: one shared Hugging Face pipeline reused for every
# translation request; the language pair is passed per call (see
# translate_text, which forwards src_lang/tgt_lang).
print(f"Loading translation model ({config.TRANSLATION_MODEL_NAME})...")
translation_pipeline = pipeline(
    "translation",
    model=config.TRANSLATION_MODEL_NAME,
    # transformers pipeline convention: CUDA device index, or -1 for CPU
    device=0 if DEVICE == "cuda" else -1,
    # cap generated length per chunk; inputs are pre-chunked by chunk_text()
    max_new_tokens=400,
)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  SUPPORTED_LANGS = {
75
  "eng_Latn": "English",
 
80
  "amh_Latn": "Amharic",
81
  }
82
 
83
+ # Text chunking
84
  _SENTENCE_SPLIT_RE = re.compile(r'(?<=[.!?])\s+')
85
 
86
  def chunk_text(text: str, max_len: int = 400) -> List[str]:
 
102
  return chunks
103
 
104
def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int = 400) -> str:
    """Translate *text* from ``src_lang`` to ``tgt_lang`` chunk by chunk.

    Args:
        text: Input text. Returned unchanged when blank or when source and
            target languages are the same.
        src_lang: Source language code (FLORES-style, e.g. ``"eng_Latn"`` —
            the keys of SUPPORTED_LANGS).
        tgt_lang: Target language code, same convention.
        max_chunk_len: Maximum character length per chunk handed to the
            translation pipeline.

    Returns:
        Translated text with chunk translations joined by single spaces.
    """
    # Short-circuit: nothing to translate, or same-language "translation"
    # (restores the guard the previous revision had — avoids a pointless
    # and lossy round trip through the model).
    if not text.strip() or src_lang == tgt_lang:
        return text

    chunks = chunk_text(text, max_len=max_chunk_len)
    translated_parts = []
    for chunk in chunks:
        try:
            res = translation_pipeline(chunk, src_lang=src_lang, tgt_lang=tgt_lang)
            translated_parts.append(res[0]["translation_text"])
        except Exception as e:
            # Best-effort: keep the original chunk rather than failing the
            # whole request on a single bad chunk (matches prior behavior).
            print(f"Translation chunk failed: {e}")
            translated_parts.append(chunk)
    return " ".join(translated_parts).strip()
113
 
114
+ # RAG retrieval
115
  def retrieve_docs(query: str, vs_path: str):
116
  if not vs_path or not os.path.exists(vs_path):
117
  return None
 
131
  return "\n\n".join(docs) if docs else None
132
  return None
133
 
134
+
135
  def get_weather(state_name: str) -> str:
136
  url = "http://api.weatherapi.com/v1/current.json"
137
  params = {"key": config.WEATHER_API_KEY, "q": f"{state_name}, Nigeria", "aqi": "no"}
 
147
  f"- Wind: {data['current']['wind_kph']} kph"
148
  )
149
 
150
+
151
  def detect_intent(query: str):
152
  q_lower = (query or "").lower()
153
  if any(word in q_lower for word in ["weather", "temperature", "rain", "forecast"]):
 
170
  pass
171
  return "normal", None
172
 
173
+ # expert runner
174
  def run_qwen(messages: List[dict], max_new_tokens: int = 1300) -> str:
175
  text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
176
  inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
183
  output_ids = generated_ids[0][len(inputs.input_ids[0]):].tolist()
184
  return tokenizer.decode(output_ids, skip_special_tokens=True).strip()
185
 
186
+ # Memory
187
  MAX_HISTORY_MESSAGES = getattr(config, "MAX_HISTORY_MESSAGES", 30)
188
 
189
  def build_messages_from_history(history: List[dict], system_prompt: str) -> List[dict]:
 
191
  msgs.extend(history)
192
  return msgs
193
 
194
+
195
  def strip_markdown(text: str) -> str:
196
+ """
197
+ Remove Markdown formatting like **bold**, *italic*, and `inline code`.
198
+ """
199
  if not text:
200
  return ""
201
  text = re.sub(r'\*\*(.*?)\*\*', r'\1', text)
 
204
  text = re.sub(r'^#+\s+', '', text, flags=re.MULTILINE)
205
  return text
206
 
207
+ # Main pipeline
208
  def run_pipeline(user_query: str, session_id: str = None):
209
+ """
210
+ Run FarmLingua pipeline with per-session memory.
211
+ Each session_id keeps its own history.
212
+ """
213
  if session_id is None:
214
+ session_id = str(uuid.uuid4()) # fallback unique session
215
 
216
+ # Language detection
217
  lang_label, prob = detect_language(user_query, top_k=1)[0]
218
  if lang_label not in SUPPORTED_LANGS:
219
  lang_label = "eng_Latn"
 
 
220
 
221
  translated_query = (
222
  translate_text(user_query, src_lang=lang_label, tgt_lang="eng_Latn")
223
  if lang_label != "eng_Latn"
224
  else user_query
225
  )
 
 
226
 
227
  intent, extra = detect_intent(translated_query)
228
 
229
+ # Load conversation history
230
  history = memory_store.get_history(session_id) or []
231
  if len(history) > MAX_HISTORY_MESSAGES:
232
  history = history[-MAX_HISTORY_MESSAGES:]
233
 
234
+
235
  history.append({"role": "user", "content": translated_query})
236
 
237
+
238
  system_prompt = (
239
  "You are FarmLingua, an AI assistant for Nigerian farmers. "
240
  "Answer directly without repeating the question. "
241
  "Use clear farmer-friendly English with emojis . "
242
  "Avoid jargon and irrelevant details. "
243
  "If asked who built you, say: 'KawaFarm LTD developed me to help farmers.'"
244
+
245
  )
246
 
247
+
248
  if intent == "weather" and extra:
249
  weather_text = get_weather(extra)
250
  history.append({"role": "user", "content": f"Rewrite this weather update simply for farmers:\n{weather_text}"})
 
263
  messages_for_qwen = build_messages_from_history(history, system_prompt)
264
  english_answer = run_qwen(messages_for_qwen, max_new_tokens=700)
265
 
266
+ # Save assistant reply
267
  history.append({"role": "assistant", "content": english_answer})
268
  if len(history) > MAX_HISTORY_MESSAGES:
269
  history = history[-MAX_HISTORY_MESSAGES:]
270
  memory_store.save_history(session_id, history)
271
 
272
+ # Translate back if needed
273
  final_answer = (
274
  translate_text(english_answer, src_lang="eng_Latn", tgt_lang=lang_label)
275
  if lang_label != "eng_Latn"
276
  else english_answer
277
  )
278
  final_answer = strip_markdown(final_answer)
 
 
 
279
  return {
280
  "session_id": session_id,
281
  "detected_language": SUPPORTED_LANGS.get(lang_label, "Unknown"),
282
  "answer": final_answer
283
+ }