drrobot9 commited on
Commit
edfc902
·
1 Parent(s): a3aa01e

Update app/agents/crew_pipeline.py

Browse files
Files changed (1) hide show
  1. app/agents/crew_pipeline.py +43 -100
app/agents/crew_pipeline.py CHANGED
@@ -1,4 +1,4 @@
1
- # farmlingua/app/agents/crew_pipeline.py
2
  import os
3
  import sys
4
  import re
@@ -10,12 +10,13 @@ import numpy as np
10
  import torch
11
  import fasttext
12
  from huggingface_hub import hf_hub_download
13
- from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
14
  from sentence_transformers import SentenceTransformer
15
  from app.utils import config
16
- from app.utils.memory import memory_store
17
  from typing import List
18
 
 
19
  hf_cache = "/models/huggingface"
20
  os.environ["HF_HOME"] = hf_cache
21
  os.environ["TRANSFORMERS_CACHE"] = hf_cache
@@ -28,11 +29,13 @@ if BASE_DIR not in sys.path:
28
 
29
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
30
 
 
31
  try:
32
  classifier = joblib.load(config.CLASSIFIER_PATH)
33
  except Exception:
34
  classifier = None
35
 
 
36
  print(f"Loading expert model ({config.EXPERT_MODEL_NAME})...")
37
  tokenizer = AutoTokenizer.from_pretrained(config.EXPERT_MODEL_NAME, use_fast=False)
38
  model = AutoModelForCausalLM.from_pretrained(
@@ -41,8 +44,10 @@ model = AutoModelForCausalLM.from_pretrained(
41
  device_map="auto"
42
  )
43
 
 
44
  embedder = SentenceTransformer(config.EMBEDDING_MODEL)
45
 
 
46
  print(f"Loading FastText language identifier ({config.LANG_ID_MODEL_REPO})...")
47
  lang_model_path = hf_hub_download(
48
  repo_id=config.LANG_ID_MODEL_REPO,
@@ -57,25 +62,15 @@ def detect_language(text: str, top_k: int = 1):
57
  labels, probs = lang_identifier.predict(clean_text, k=top_k)
58
  return [(l.replace("__label__", ""), float(p)) for l, p in zip(labels, probs)]
59
 
 
60
  print(f"Loading translation model ({config.TRANSLATION_MODEL_NAME})...")
61
- translation_tokenizer = AutoTokenizer.from_pretrained(config.TRANSLATION_MODEL_NAME)
62
- translation_model = AutoModelForSeq2SeqLM.from_pretrained(
63
- config.TRANSLATION_MODEL_NAME,
64
- torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
65
- device_map="auto" if DEVICE == "cuda" else None
66
  )
67
 
68
- print("Translation model loaded successfully")
69
-
70
- LANG_CODE_MAP = {
71
- "eng_Latn": "eng_Latn",
72
- "ibo_Latn": "ibo_Latn",
73
- "yor_Latn": "yor_Latn",
74
- "hau_Latn": "hau_Latn",
75
- "swh_Latn": "swa_Latn",
76
- "amh_Latn": "amh_Ethi",
77
- }
78
-
79
  SUPPORTED_LANGS = {
80
  "eng_Latn": "English",
81
  "ibo_Latn": "Igbo",
@@ -85,6 +80,7 @@ SUPPORTED_LANGS = {
85
  "amh_Latn": "Amharic",
86
  }
87
 
 
88
  _SENTENCE_SPLIT_RE = re.compile(r'(?<=[.!?])\s+')
89
 
90
  def chunk_text(text: str, max_len: int = 400) -> List[str]:
@@ -106,83 +102,16 @@ def chunk_text(text: str, max_len: int = 400) -> List[str]:
106
  return chunks
107
 
108
  def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int = 400) -> str:
109
- print(f"\n[TRANSLATION] {src_lang} → {tgt_lang}")
110
- print(f" Input: {text[:100]}...")
111
-
112
- if not text.strip() or src_lang == tgt_lang:
113
- print(" No translation needed (same language)")
114
  return text
115
-
116
- src_code = LANG_CODE_MAP.get(src_lang, "eng_Latn")
117
- tgt_code = LANG_CODE_MAP.get(tgt_lang, "eng_Latn")
118
-
119
- print(f" Using codes: {src_code} → {tgt_code}")
120
-
121
- if src_code != "eng_Latn" and tgt_code != "eng_Latn":
122
- print(f" WARNING: Model wasn't trained on {src_code}→{tgt_code}")
123
- print(f" Will translate {src_code}→eng_Latn→{tgt_code}")
124
- to_english = translate_text_single(text, src_code, "eng_Latn", max_chunk_len)
125
- return translate_text_single(to_english, "eng_Latn", tgt_code, max_chunk_len)
126
-
127
- return translate_text_single(text, src_code, tgt_code, max_chunk_len)
128
-
129
- def translate_text_single(text: str, src_code: str, tgt_code: str, max_chunk_len: int = 400) -> str:
130
- supported_pairs = [
131
- ("eng_Latn", "ibo_Latn"), ("ibo_Latn", "eng_Latn"),
132
- ("eng_Latn", "yor_Latn"), ("yor_Latn", "eng_Latn"),
133
- ("eng_Latn", "hau_Latn"), ("hau_Latn", "eng_Latn"),
134
- ("eng_Latn", "swa_Latn"), ("swa_Latn", "eng_Latn"),
135
- ("eng_Latn", "amh_Ethi"), ("amh_Ethi", "eng_Latn"),
136
- ]
137
-
138
- if (src_code, tgt_code) not in supported_pairs:
139
- print(f" WARNING: Pair {src_code}→{tgt_code} may not work well")
140
-
141
  chunks = chunk_text(text, max_len=max_chunk_len)
142
  translated_parts = []
143
-
144
- for i, chunk in enumerate(chunks):
145
- print(f" Chunk {i+1}/{len(chunks)}: '{chunk[:50]}...'")
146
-
147
- try:
148
- input_text = f"{src_code} {chunk}"
149
-
150
- inputs = translation_tokenizer(
151
- input_text,
152
- return_tensors="pt",
153
- truncation=True,
154
- max_length=512
155
- )
156
-
157
- if DEVICE == "cuda":
158
- inputs = {k: v.to(translation_model.device) for k, v in inputs.items()}
159
-
160
- generated_tokens = translation_model.generate(
161
- **inputs,
162
- max_new_tokens=400,
163
- num_beams=4,
164
- early_stopping=True
165
- )
166
-
167
- result = translation_tokenizer.batch_decode(
168
- generated_tokens,
169
- skip_special_tokens=True
170
- )[0]
171
-
172
- if result.startswith(tgt_code + " "):
173
- result = result[len(tgt_code) + 1:]
174
-
175
- print(f" → '{result[:50]}...'")
176
- translated_parts.append(result.strip())
177
-
178
- except Exception as e:
179
- print(f" ERROR: {e}")
180
- translated_parts.append(chunk)
181
-
182
- final_result = " ".join(translated_parts).strip()
183
- print(f" Final: '{final_result[:100]}...'")
184
- return final_result
185
 
 
186
  def retrieve_docs(query: str, vs_path: str):
187
  if not vs_path or not os.path.exists(vs_path):
188
  return None
@@ -202,6 +131,7 @@ def retrieve_docs(query: str, vs_path: str):
202
  return "\n\n".join(docs) if docs else None
203
  return None
204
 
 
205
  def get_weather(state_name: str) -> str:
206
  url = "http://api.weatherapi.com/v1/current.json"
207
  params = {"key": config.WEATHER_API_KEY, "q": f"{state_name}, Nigeria", "aqi": "no"}
@@ -217,6 +147,7 @@ def get_weather(state_name: str) -> str:
217
  f"- Wind: {data['current']['wind_kph']} kph"
218
  )
219
 
 
220
  def detect_intent(query: str):
221
  q_lower = (query or "").lower()
222
  if any(word in q_lower for word in ["weather", "temperature", "rain", "forecast"]):
@@ -239,6 +170,7 @@ def detect_intent(query: str):
239
  pass
240
  return "normal", None
241
 
 
242
  def run_qwen(messages: List[dict], max_new_tokens: int = 1300) -> str:
243
  text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
244
  inputs = tokenizer([text], return_tensors="pt").to(model.device)
@@ -251,6 +183,7 @@ def run_qwen(messages: List[dict], max_new_tokens: int = 1300) -> str:
251
  output_ids = generated_ids[0][len(inputs.input_ids[0]):].tolist()
252
  return tokenizer.decode(output_ids, skip_special_tokens=True).strip()
253
 
 
254
  MAX_HISTORY_MESSAGES = getattr(config, "MAX_HISTORY_MESSAGES", 30)
255
 
256
  def build_messages_from_history(history: List[dict], system_prompt: str) -> List[dict]:
@@ -258,7 +191,11 @@ def build_messages_from_history(history: List[dict], system_prompt: str) -> List
258
  msgs.extend(history)
259
  return msgs
260
 
 
261
  def strip_markdown(text: str) -> str:
 
 
 
262
  if not text:
263
  return ""
264
  text = re.sub(r'\*\*(.*?)\*\*', r'\1', text)
@@ -267,40 +204,47 @@ def strip_markdown(text: str) -> str:
267
  text = re.sub(r'^#+\s+', '', text, flags=re.MULTILINE)
268
  return text
269
 
 
270
  def run_pipeline(user_query: str, session_id: str = None):
 
 
 
 
271
  if session_id is None:
272
- session_id = str(uuid.uuid4())
273
 
 
274
  lang_label, prob = detect_language(user_query, top_k=1)[0]
275
  if lang_label not in SUPPORTED_LANGS:
276
  lang_label = "eng_Latn"
277
-
278
- print(f"Detected language: {SUPPORTED_LANGS.get(lang_label, 'Unknown')}")
279
 
280
  translated_query = (
281
  translate_text(user_query, src_lang=lang_label, tgt_lang="eng_Latn")
282
  if lang_label != "eng_Latn"
283
  else user_query
284
  )
285
-
286
- print(f"Translated to English: {translated_query[:100]}...")
287
 
288
  intent, extra = detect_intent(translated_query)
289
 
 
290
  history = memory_store.get_history(session_id) or []
291
  if len(history) > MAX_HISTORY_MESSAGES:
292
  history = history[-MAX_HISTORY_MESSAGES:]
293
 
 
294
  history.append({"role": "user", "content": translated_query})
295
 
 
296
  system_prompt = (
297
  "You are FarmLingua, an AI assistant for Nigerian farmers. "
298
  "Answer directly without repeating the question. "
299
  "Use clear farmer-friendly English with emojis . "
300
  "Avoid jargon and irrelevant details. "
301
  "If asked who built you, say: 'KawaFarm LTD developed me to help farmers.'"
 
302
  )
303
 
 
304
  if intent == "weather" and extra:
305
  weather_text = get_weather(extra)
306
  history.append({"role": "user", "content": f"Rewrite this weather update simply for farmers:\n{weather_text}"})
@@ -319,20 +263,19 @@ def run_pipeline(user_query: str, session_id: str = None):
319
  messages_for_qwen = build_messages_from_history(history, system_prompt)
320
  english_answer = run_qwen(messages_for_qwen, max_new_tokens=700)
321
 
 
322
  history.append({"role": "assistant", "content": english_answer})
323
  if len(history) > MAX_HISTORY_MESSAGES:
324
  history = history[-MAX_HISTORY_MESSAGES:]
325
  memory_store.save_history(session_id, history)
326
 
 
327
  final_answer = (
328
  translate_text(english_answer, src_lang="eng_Latn", tgt_lang=lang_label)
329
  if lang_label != "eng_Latn"
330
  else english_answer
331
  )
332
  final_answer = strip_markdown(final_answer)
333
-
334
- print(f"Final answer: {final_answer[:100]}...")
335
-
336
  return {
337
  "session_id": session_id,
338
  "detected_language": SUPPORTED_LANGS.get(lang_label, "Unknown"),
 
1
+ # farmlingua/app/agents/crew_pipeline.pymemorysection
2
  import os
3
  import sys
4
  import re
 
10
  import torch
11
  import fasttext
12
  from huggingface_hub import hf_hub_download
13
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
14
  from sentence_transformers import SentenceTransformer
15
  from app.utils import config
16
+ from app.utils.memory import memory_store # memory module
17
  from typing import List
18
 
19
+
20
  hf_cache = "/models/huggingface"
21
  os.environ["HF_HOME"] = hf_cache
22
  os.environ["TRANSFORMERS_CACHE"] = hf_cache
 
29
 
30
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
31
 
32
+
33
  try:
34
  classifier = joblib.load(config.CLASSIFIER_PATH)
35
  except Exception:
36
  classifier = None
37
 
38
+
39
  print(f"Loading expert model ({config.EXPERT_MODEL_NAME})...")
40
  tokenizer = AutoTokenizer.from_pretrained(config.EXPERT_MODEL_NAME, use_fast=False)
41
  model = AutoModelForCausalLM.from_pretrained(
 
44
  device_map="auto"
45
  )
46
 
47
+
48
  embedder = SentenceTransformer(config.EMBEDDING_MODEL)
49
 
50
+ # language detector
51
  print(f"Loading FastText language identifier ({config.LANG_ID_MODEL_REPO})...")
52
  lang_model_path = hf_hub_download(
53
  repo_id=config.LANG_ID_MODEL_REPO,
 
62
  labels, probs = lang_identifier.predict(clean_text, k=top_k)
63
  return [(l.replace("__label__", ""), float(p)) for l, p in zip(labels, probs)]
64
 
65
+ # Translation model
66
  print(f"Loading translation model ({config.TRANSLATION_MODEL_NAME})...")
67
+ translation_pipeline = pipeline(
68
+ "translation",
69
+ model=config.TRANSLATION_MODEL_NAME,
70
+ device=0 if DEVICE == "cuda" else -1,
71
+ max_new_tokens=400,
72
  )
73
 
 
 
 
 
 
 
 
 
 
 
 
74
  SUPPORTED_LANGS = {
75
  "eng_Latn": "English",
76
  "ibo_Latn": "Igbo",
 
80
  "amh_Latn": "Amharic",
81
  }
82
 
83
+ # Text chunking
84
  _SENTENCE_SPLIT_RE = re.compile(r'(?<=[.!?])\s+')
85
 
86
  def chunk_text(text: str, max_len: int = 400) -> List[str]:
 
102
  return chunks
103
 
104
  def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int = 400) -> str:
105
+ if not text.strip():
 
 
 
 
106
  return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  chunks = chunk_text(text, max_len=max_chunk_len)
108
  translated_parts = []
109
+ for chunk in chunks:
110
+ res = translation_pipeline(chunk, src_lang=src_lang, tgt_lang=tgt_lang)
111
+ translated_parts.append(res[0]["translation_text"])
112
+ return " ".join(translated_parts).strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
+ # RAG retrieval
115
  def retrieve_docs(query: str, vs_path: str):
116
  if not vs_path or not os.path.exists(vs_path):
117
  return None
 
131
  return "\n\n".join(docs) if docs else None
132
  return None
133
 
134
+
135
  def get_weather(state_name: str) -> str:
136
  url = "http://api.weatherapi.com/v1/current.json"
137
  params = {"key": config.WEATHER_API_KEY, "q": f"{state_name}, Nigeria", "aqi": "no"}
 
147
  f"- Wind: {data['current']['wind_kph']} kph"
148
  )
149
 
150
+
151
  def detect_intent(query: str):
152
  q_lower = (query or "").lower()
153
  if any(word in q_lower for word in ["weather", "temperature", "rain", "forecast"]):
 
170
  pass
171
  return "normal", None
172
 
173
+ # expert runner
174
  def run_qwen(messages: List[dict], max_new_tokens: int = 1300) -> str:
175
  text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
176
  inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
183
  output_ids = generated_ids[0][len(inputs.input_ids[0]):].tolist()
184
  return tokenizer.decode(output_ids, skip_special_tokens=True).strip()
185
 
186
+ # Memory
187
  MAX_HISTORY_MESSAGES = getattr(config, "MAX_HISTORY_MESSAGES", 30)
188
 
189
  def build_messages_from_history(history: List[dict], system_prompt: str) -> List[dict]:
 
191
  msgs.extend(history)
192
  return msgs
193
 
194
+
195
  def strip_markdown(text: str) -> str:
196
+ """
197
+ Remove Markdown formatting like **bold**, *italic*, and `inline code`.
198
+ """
199
  if not text:
200
  return ""
201
  text = re.sub(r'\*\*(.*?)\*\*', r'\1', text)
 
204
  text = re.sub(r'^#+\s+', '', text, flags=re.MULTILINE)
205
  return text
206
 
207
+ # Main pipeline
208
  def run_pipeline(user_query: str, session_id: str = None):
209
+ """
210
+ Run FarmLingua pipeline with per-session memory.
211
+ Each session_id keeps its own history.
212
+ """
213
  if session_id is None:
214
+ session_id = str(uuid.uuid4()) # fallback unique session
215
 
216
+ # Language detection
217
  lang_label, prob = detect_language(user_query, top_k=1)[0]
218
  if lang_label not in SUPPORTED_LANGS:
219
  lang_label = "eng_Latn"
 
 
220
 
221
  translated_query = (
222
  translate_text(user_query, src_lang=lang_label, tgt_lang="eng_Latn")
223
  if lang_label != "eng_Latn"
224
  else user_query
225
  )
 
 
226
 
227
  intent, extra = detect_intent(translated_query)
228
 
229
+ # Load conversation history
230
  history = memory_store.get_history(session_id) or []
231
  if len(history) > MAX_HISTORY_MESSAGES:
232
  history = history[-MAX_HISTORY_MESSAGES:]
233
 
234
+
235
  history.append({"role": "user", "content": translated_query})
236
 
237
+
238
  system_prompt = (
239
  "You are FarmLingua, an AI assistant for Nigerian farmers. "
240
  "Answer directly without repeating the question. "
241
  "Use clear farmer-friendly English with emojis . "
242
  "Avoid jargon and irrelevant details. "
243
  "If asked who built you, say: 'KawaFarm LTD developed me to help farmers.'"
244
+
245
  )
246
 
247
+
248
  if intent == "weather" and extra:
249
  weather_text = get_weather(extra)
250
  history.append({"role": "user", "content": f"Rewrite this weather update simply for farmers:\n{weather_text}"})
 
263
  messages_for_qwen = build_messages_from_history(history, system_prompt)
264
  english_answer = run_qwen(messages_for_qwen, max_new_tokens=700)
265
 
266
+ # Save assistant reply
267
  history.append({"role": "assistant", "content": english_answer})
268
  if len(history) > MAX_HISTORY_MESSAGES:
269
  history = history[-MAX_HISTORY_MESSAGES:]
270
  memory_store.save_history(session_id, history)
271
 
272
+ # Translate back if needed
273
  final_answer = (
274
  translate_text(english_answer, src_lang="eng_Latn", tgt_lang=lang_label)
275
  if lang_label != "eng_Latn"
276
  else english_answer
277
  )
278
  final_answer = strip_markdown(final_answer)
 
 
 
279
  return {
280
  "session_id": session_id,
281
  "detected_language": SUPPORTED_LANGS.get(lang_label, "Unknown"),