drrobot9 commited on
Commit
4d564d9
·
1 Parent(s): edfc902

Update app/agents/crew_pipeline.py

Browse files
Files changed (1) hide show
  1. app/agents/crew_pipeline.py +134 -45
app/agents/crew_pipeline.py CHANGED
@@ -1,4 +1,4 @@
1
- # farmlingua/app/agents/crew_pipeline.pymemorysection
2
  import os
3
  import sys
4
  import re
@@ -10,13 +10,12 @@ import numpy as np
10
  import torch
11
  import fasttext
12
  from huggingface_hub import hf_hub_download
13
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
14
  from sentence_transformers import SentenceTransformer
15
  from app.utils import config
16
- from app.utils.memory import memory_store # memory module
17
  from typing import List
18
 
19
-
20
  hf_cache = "/models/huggingface"
21
  os.environ["HF_HOME"] = hf_cache
22
  os.environ["TRANSFORMERS_CACHE"] = hf_cache
@@ -29,13 +28,11 @@ if BASE_DIR not in sys.path:
29
 
30
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
31
 
32
-
33
  try:
34
  classifier = joblib.load(config.CLASSIFIER_PATH)
35
  except Exception:
36
  classifier = None
37
 
38
-
39
  print(f"Loading expert model ({config.EXPERT_MODEL_NAME})...")
40
  tokenizer = AutoTokenizer.from_pretrained(config.EXPERT_MODEL_NAME, use_fast=False)
41
  model = AutoModelForCausalLM.from_pretrained(
@@ -44,10 +41,8 @@ model = AutoModelForCausalLM.from_pretrained(
44
  device_map="auto"
45
  )
46
 
47
-
48
  embedder = SentenceTransformer(config.EMBEDDING_MODEL)
49
 
50
- # language detector
51
  print(f"Loading FastText language identifier ({config.LANG_ID_MODEL_REPO})...")
52
  lang_model_path = hf_hub_download(
53
  repo_id=config.LANG_ID_MODEL_REPO,
@@ -62,14 +57,16 @@ def detect_language(text: str, top_k: int = 1):
62
  labels, probs = lang_identifier.predict(clean_text, k=top_k)
63
  return [(l.replace("__label__", ""), float(p)) for l, p in zip(labels, probs)]
64
 
65
- # Translation model
66
  print(f"Loading translation model ({config.TRANSLATION_MODEL_NAME})...")
67
- translation_pipeline = pipeline(
68
- "translation",
69
- model=config.TRANSLATION_MODEL_NAME,
70
- device=0 if DEVICE == "cuda" else -1,
71
- max_new_tokens=400,
72
- )
 
 
 
73
 
74
  SUPPORTED_LANGS = {
75
  "eng_Latn": "English",
@@ -80,7 +77,6 @@ SUPPORTED_LANGS = {
80
  "amh_Latn": "Amharic",
81
  }
82
 
83
- # Text chunking
84
  _SENTENCE_SPLIT_RE = re.compile(r'(?<=[.!?])\s+')
85
 
86
  def chunk_text(text: str, max_len: int = 400) -> List[str]:
@@ -101,17 +97,124 @@ def chunk_text(text: str, max_len: int = 400) -> List[str]:
101
  chunks.append(current.strip())
102
  return chunks
103
 
104
- def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int = 400) -> str:
105
- if not text.strip():
106
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  chunks = chunk_text(text, max_len=max_chunk_len)
108
  translated_parts = []
109
- for chunk in chunks:
110
- res = translation_pipeline(chunk, src_lang=src_lang, tgt_lang=tgt_lang)
111
- translated_parts.append(res[0]["translation_text"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  return " ".join(translated_parts).strip()
113
 
114
- # RAG retrieval
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  def retrieve_docs(query: str, vs_path: str):
116
  if not vs_path or not os.path.exists(vs_path):
117
  return None
@@ -131,7 +234,6 @@ def retrieve_docs(query: str, vs_path: str):
131
  return "\n\n".join(docs) if docs else None
132
  return None
133
 
134
-
135
  def get_weather(state_name: str) -> str:
136
  url = "http://api.weatherapi.com/v1/current.json"
137
  params = {"key": config.WEATHER_API_KEY, "q": f"{state_name}, Nigeria", "aqi": "no"}
@@ -147,7 +249,6 @@ def get_weather(state_name: str) -> str:
147
  f"- Wind: {data['current']['wind_kph']} kph"
148
  )
149
 
150
-
151
  def detect_intent(query: str):
152
  q_lower = (query or "").lower()
153
  if any(word in q_lower for word in ["weather", "temperature", "rain", "forecast"]):
@@ -170,7 +271,6 @@ def detect_intent(query: str):
170
  pass
171
  return "normal", None
172
 
173
- # expert runner
174
  def run_qwen(messages: List[dict], max_new_tokens: int = 1300) -> str:
175
  text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
176
  inputs = tokenizer([text], return_tensors="pt").to(model.device)
@@ -183,7 +283,6 @@ def run_qwen(messages: List[dict], max_new_tokens: int = 1300) -> str:
183
  output_ids = generated_ids[0][len(inputs.input_ids[0]):].tolist()
184
  return tokenizer.decode(output_ids, skip_special_tokens=True).strip()
185
 
186
- # Memory
187
  MAX_HISTORY_MESSAGES = getattr(config, "MAX_HISTORY_MESSAGES", 30)
188
 
189
  def build_messages_from_history(history: List[dict], system_prompt: str) -> List[dict]:
@@ -191,11 +290,7 @@ def build_messages_from_history(history: List[dict], system_prompt: str) -> List
191
  msgs.extend(history)
192
  return msgs
193
 
194
-
195
  def strip_markdown(text: str) -> str:
196
- """
197
- Remove Markdown formatting like **bold**, *italic*, and `inline code`.
198
- """
199
  if not text:
200
  return ""
201
  text = re.sub(r'\*\*(.*?)\*\*', r'\1', text)
@@ -204,47 +299,40 @@ def strip_markdown(text: str) -> str:
204
  text = re.sub(r'^#+\s+', '', text, flags=re.MULTILINE)
205
  return text
206
 
207
- # Main pipeline
208
  def run_pipeline(user_query: str, session_id: str = None):
209
- """
210
- Run FarmLingua pipeline with per-session memory.
211
- Each session_id keeps its own history.
212
- """
213
  if session_id is None:
214
- session_id = str(uuid.uuid4()) # fallback unique session
215
 
216
- # Language detection
217
  lang_label, prob = detect_language(user_query, top_k=1)[0]
218
  if lang_label not in SUPPORTED_LANGS:
219
  lang_label = "eng_Latn"
 
 
220
 
221
  translated_query = (
222
  translate_text(user_query, src_lang=lang_label, tgt_lang="eng_Latn")
223
  if lang_label != "eng_Latn"
224
  else user_query
225
  )
 
 
226
 
227
  intent, extra = detect_intent(translated_query)
228
 
229
- # Load conversation history
230
  history = memory_store.get_history(session_id) or []
231
  if len(history) > MAX_HISTORY_MESSAGES:
232
  history = history[-MAX_HISTORY_MESSAGES:]
233
 
234
-
235
  history.append({"role": "user", "content": translated_query})
236
 
237
-
238
  system_prompt = (
239
  "You are FarmLingua, an AI assistant for Nigerian farmers. "
240
  "Answer directly without repeating the question. "
241
  "Use clear farmer-friendly English with emojis . "
242
  "Avoid jargon and irrelevant details. "
243
  "If asked who built you, say: 'KawaFarm LTD developed me to help farmers.'"
244
-
245
  )
246
 
247
-
248
  if intent == "weather" and extra:
249
  weather_text = get_weather(extra)
250
  history.append({"role": "user", "content": f"Rewrite this weather update simply for farmers:\n{weather_text}"})
@@ -263,19 +351,20 @@ def run_pipeline(user_query: str, session_id: str = None):
263
  messages_for_qwen = build_messages_from_history(history, system_prompt)
264
  english_answer = run_qwen(messages_for_qwen, max_new_tokens=700)
265
 
266
- # Save assistant reply
267
  history.append({"role": "assistant", "content": english_answer})
268
  if len(history) > MAX_HISTORY_MESSAGES:
269
  history = history[-MAX_HISTORY_MESSAGES:]
270
  memory_store.save_history(session_id, history)
271
 
272
- # Translate back if needed
273
  final_answer = (
274
  translate_text(english_answer, src_lang="eng_Latn", tgt_lang=lang_label)
275
  if lang_label != "eng_Latn"
276
  else english_answer
277
  )
278
  final_answer = strip_markdown(final_answer)
 
 
 
279
  return {
280
  "session_id": session_id,
281
  "detected_language": SUPPORTED_LANGS.get(lang_label, "Unknown"),
 
1
+ # farmlingua/app/agents/crew_pipeline.py
2
  import os
3
  import sys
4
  import re
 
10
  import torch
11
  import fasttext
12
  from huggingface_hub import hf_hub_download
13
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
14
  from sentence_transformers import SentenceTransformer
15
  from app.utils import config
16
+ from app.utils.memory import memory_store
17
  from typing import List
18
 
 
19
  hf_cache = "/models/huggingface"
20
  os.environ["HF_HOME"] = hf_cache
21
  os.environ["TRANSFORMERS_CACHE"] = hf_cache
 
28
 
29
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
30
 
 
31
  try:
32
  classifier = joblib.load(config.CLASSIFIER_PATH)
33
  except Exception:
34
  classifier = None
35
 
 
36
  print(f"Loading expert model ({config.EXPERT_MODEL_NAME})...")
37
  tokenizer = AutoTokenizer.from_pretrained(config.EXPERT_MODEL_NAME, use_fast=False)
38
  model = AutoModelForCausalLM.from_pretrained(
 
41
  device_map="auto"
42
  )
43
 
 
44
  embedder = SentenceTransformer(config.EMBEDDING_MODEL)
45
 
 
46
  print(f"Loading FastText language identifier ({config.LANG_ID_MODEL_REPO})...")
47
  lang_model_path = hf_hub_download(
48
  repo_id=config.LANG_ID_MODEL_REPO,
 
57
  labels, probs = lang_identifier.predict(clean_text, k=top_k)
58
  return [(l.replace("__label__", ""), float(p)) for l, p in zip(labels, probs)]
59
 
 
60
  print(f"Loading translation model ({config.TRANSLATION_MODEL_NAME})...")
61
+
62
+ LANG_CODE_MAP = {
63
+ "eng_Latn": "eng_Latn",
64
+ "ibo_Latn": "ibo_Latn",
65
+ "yor_Latn": "yor_Latn",
66
+ "hau_Latn": "hau_Latn",
67
+ "swh_Latn": "swa_Latn",
68
+ "amh_Latn": "amh_Ethi",
69
+ }
70
 
71
  SUPPORTED_LANGS = {
72
  "eng_Latn": "English",
 
77
  "amh_Latn": "Amharic",
78
  }
79
 
 
80
  _SENTENCE_SPLIT_RE = re.compile(r'(?<=[.!?])\s+')
81
 
82
  def chunk_text(text: str, max_len: int = 400) -> List[str]:
 
97
  chunks.append(current.strip())
98
  return chunks
99
 
100
+ def load_translation_model():
101
+ """Load translation model with proper configuration"""
102
+ try:
103
+ tokenizer = AutoTokenizer.from_pretrained(config.TRANSLATION_MODEL_NAME)
104
+ model = AutoModelForSeq2SeqLM.from_pretrained(
105
+ config.TRANSLATION_MODEL_NAME,
106
+ torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
107
+ device_map="auto" if DEVICE == "cuda" else None
108
+ )
109
+ print("✓ Custom translation model loaded")
110
+ return tokenizer, model
111
+ except Exception as e:
112
+ print(f"✗ Error loading custom model: {e}")
113
+ print("Loading standard NLLB model as fallback...")
114
+ tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
115
+ model = AutoModelForSeq2SeqLM.from_pretrained(
116
+ "facebook/nllb-200-distilled-600M",
117
+ torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
118
+ device_map="auto" if DEVICE == "cuda" else None
119
+ )
120
+ print("✓ Standard NLLB model loaded as fallback")
121
+ return tokenizer, model
122
+
123
+ # Load the model
124
+ translation_tokenizer, translation_model = load_translation_model()
125
+
126
+ def translate_with_nllb(text: str, src_code: str, tgt_code: str, max_chunk_len: int = 400) -> str:
127
+ """Translate using NLLB model with forced_bos_token_id"""
128
  chunks = chunk_text(text, max_len=max_chunk_len)
129
  translated_parts = []
130
+
131
+ # Check if tokenizer has lang_code_to_id
132
+ if hasattr(translation_tokenizer, 'lang_code_to_id'):
133
+ try:
134
+ # Set source language
135
+ translation_tokenizer.src_lang = src_code
136
+ # Get forced bos token ID
137
+ forced_bos_token_id = translation_tokenizer.lang_code_to_id[tgt_code]
138
+
139
+ for i, chunk in enumerate(chunks):
140
+ try:
141
+ inputs = translation_tokenizer(chunk, return_tensors="pt", truncation=True, max_length=512)
142
+
143
+ if DEVICE == "cuda":
144
+ inputs = {k: v.to(translation_model.device) for k, v in inputs.items()}
145
+
146
+ generated_tokens = translation_model.generate(
147
+ **inputs,
148
+ forced_bos_token_id=forced_bos_token_id,
149
+ max_new_tokens=400,
150
+ num_beams=4,
151
+ early_stopping=True
152
+ )
153
+
154
+ result = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
155
+ translated_parts.append(result.strip())
156
+
157
+ except Exception as e:
158
+ print(f" Chunk {i+1} error: {e}")
159
+ translated_parts.append(chunk)
160
+ except Exception as e:
161
+ print(f" Language code error: {e}")
162
+ # Fallback to simple translation
163
+ return translate_simple(text, max_chunk_len)
164
+ else:
165
+ # If no lang_code_to_id, try simple translation
166
+ return translate_simple(text, max_chunk_len)
167
+
168
  return " ".join(translated_parts).strip()
169
 
170
+ def translate_simple(text: str, max_chunk_len: int = 400) -> str:
171
+ """Simple translation without language codes"""
172
+ chunks = chunk_text(text, max_len=max_chunk_len)
173
+ translated_parts = []
174
+
175
+ for i, chunk in enumerate(chunks):
176
+ try:
177
+ inputs = translation_tokenizer(chunk, return_tensors="pt", truncation=True, max_length=512)
178
+
179
+ if DEVICE == "cuda":
180
+ inputs = {k: v.to(translation_model.device) for k, v in inputs.items()}
181
+
182
+ generated_tokens = translation_model.generate(
183
+ **inputs,
184
+ max_new_tokens=400,
185
+ num_beams=4,
186
+ early_stopping=True
187
+ )
188
+
189
+ result = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
190
+ translated_parts.append(result.strip())
191
+
192
+ except Exception as e:
193
+ print(f" Chunk {i+1} error: {e}")
194
+ translated_parts.append(chunk)
195
+
196
+ return " ".join(translated_parts).strip()
197
+
198
+ def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int = 400) -> str:
199
+ print(f"\n[TRANSLATION] {src_lang} → {tgt_lang}")
200
+ print(f" Input: {text[:100]}...")
201
+
202
+ if not text.strip() or src_lang == tgt_lang:
203
+ print(" No translation needed (same language)")
204
+ return text
205
+
206
+ src_code = LANG_CODE_MAP.get(src_lang, "eng_Latn")
207
+ tgt_code = LANG_CODE_MAP.get(tgt_lang, "eng_Latn")
208
+
209
+ print(f" Using codes: {src_code} → {tgt_code}")
210
+
211
+ if src_code != "eng_Latn" and tgt_code != "eng_Latn":
212
+ print(f" Two-step translation: {src_code}→eng_Latn→{tgt_code}")
213
+ to_english = translate_with_nllb(text, src_code, "eng_Latn", max_chunk_len)
214
+ return translate_with_nllb(to_english, "eng_Latn", tgt_code, max_chunk_len)
215
+
216
+ return translate_with_nllb(text, src_code, tgt_code, max_chunk_len)
217
+
218
  def retrieve_docs(query: str, vs_path: str):
219
  if not vs_path or not os.path.exists(vs_path):
220
  return None
 
234
  return "\n\n".join(docs) if docs else None
235
  return None
236
 
 
237
  def get_weather(state_name: str) -> str:
238
  url = "http://api.weatherapi.com/v1/current.json"
239
  params = {"key": config.WEATHER_API_KEY, "q": f"{state_name}, Nigeria", "aqi": "no"}
 
249
  f"- Wind: {data['current']['wind_kph']} kph"
250
  )
251
 
 
252
  def detect_intent(query: str):
253
  q_lower = (query or "").lower()
254
  if any(word in q_lower for word in ["weather", "temperature", "rain", "forecast"]):
 
271
  pass
272
  return "normal", None
273
 
 
274
  def run_qwen(messages: List[dict], max_new_tokens: int = 1300) -> str:
275
  text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
276
  inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
283
  output_ids = generated_ids[0][len(inputs.input_ids[0]):].tolist()
284
  return tokenizer.decode(output_ids, skip_special_tokens=True).strip()
285
 
 
286
  MAX_HISTORY_MESSAGES = getattr(config, "MAX_HISTORY_MESSAGES", 30)
287
 
288
  def build_messages_from_history(history: List[dict], system_prompt: str) -> List[dict]:
 
290
  msgs.extend(history)
291
  return msgs
292
 
 
293
  def strip_markdown(text: str) -> str:
 
 
 
294
  if not text:
295
  return ""
296
  text = re.sub(r'\*\*(.*?)\*\*', r'\1', text)
 
299
  text = re.sub(r'^#+\s+', '', text, flags=re.MULTILINE)
300
  return text
301
 
 
302
  def run_pipeline(user_query: str, session_id: str = None):
 
 
 
 
303
  if session_id is None:
304
+ session_id = str(uuid.uuid4())
305
 
 
306
  lang_label, prob = detect_language(user_query, top_k=1)[0]
307
  if lang_label not in SUPPORTED_LANGS:
308
  lang_label = "eng_Latn"
309
+
310
+ print(f"Detected language: {SUPPORTED_LANGS.get(lang_label, 'Unknown')}")
311
 
312
  translated_query = (
313
  translate_text(user_query, src_lang=lang_label, tgt_lang="eng_Latn")
314
  if lang_label != "eng_Latn"
315
  else user_query
316
  )
317
+
318
+ print(f"Translated to English: {translated_query[:100]}...")
319
 
320
  intent, extra = detect_intent(translated_query)
321
 
 
322
  history = memory_store.get_history(session_id) or []
323
  if len(history) > MAX_HISTORY_MESSAGES:
324
  history = history[-MAX_HISTORY_MESSAGES:]
325
 
 
326
  history.append({"role": "user", "content": translated_query})
327
 
 
328
  system_prompt = (
329
  "You are FarmLingua, an AI assistant for Nigerian farmers. "
330
  "Answer directly without repeating the question. "
331
  "Use clear farmer-friendly English with emojis . "
332
  "Avoid jargon and irrelevant details. "
333
  "If asked who built you, say: 'KawaFarm LTD developed me to help farmers.'"
 
334
  )
335
 
 
336
  if intent == "weather" and extra:
337
  weather_text = get_weather(extra)
338
  history.append({"role": "user", "content": f"Rewrite this weather update simply for farmers:\n{weather_text}"})
 
351
  messages_for_qwen = build_messages_from_history(history, system_prompt)
352
  english_answer = run_qwen(messages_for_qwen, max_new_tokens=700)
353
 
 
354
  history.append({"role": "assistant", "content": english_answer})
355
  if len(history) > MAX_HISTORY_MESSAGES:
356
  history = history[-MAX_HISTORY_MESSAGES:]
357
  memory_store.save_history(session_id, history)
358
 
 
359
  final_answer = (
360
  translate_text(english_answer, src_lang="eng_Latn", tgt_lang=lang_label)
361
  if lang_label != "eng_Latn"
362
  else english_answer
363
  )
364
  final_answer = strip_markdown(final_answer)
365
+
366
+ print(f"Final answer: {final_answer[:100]}...")
367
+
368
  return {
369
  "session_id": session_id,
370
  "detected_language": SUPPORTED_LANGS.get(lang_label, "Unknown"),