drrobot9 committed on
Commit
2405881
·
verified ·
1 Parent(s): 5c1f411

Update app/agents/crew_pipeline.py

Browse files
Files changed (1) hide show
  1. app/agents/crew_pipeline.py +42 -31
app/agents/crew_pipeline.py CHANGED
@@ -1,4 +1,4 @@
1
- # farmlingua/app/agents/crew_pipeline.pymemorysection
2
  import os
3
  import sys
4
  import re
@@ -10,7 +10,7 @@ import numpy as np
10
  import torch
11
  import fasttext
12
  from huggingface_hub import hf_hub_download
13
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
14
  from sentence_transformers import SentenceTransformer
15
  from app.utils import config
16
  from app.utils.memory import memory_store # memory module
@@ -64,26 +64,24 @@ def detect_language(text: str, top_k: int = 1):
64
 
65
  # Translation model
66
  print(f"Loading translation model ({config.TRANSLATION_MODEL_NAME})...")
67
- from transformers import AutoModelForSeq2SeqLM
68
 
69
-
70
-
71
- translation_tokenizer = AutoTokenizer.from_pretrained(
72
  config.TRANSLATION_MODEL_NAME,
73
- use_fast = True
74
  )
75
 
76
  translation_model = AutoModelForSeq2SeqLM.from_pretrained(
77
- config.TRANSLATION_MODEL_NAME,
78
- torch_dtype = 'auto',
79
- device_map = 'auto' if DEVICE == 'cuda' else None
80
-
81
  )
82
- if DEVICE == 'cpu':
83
- translation_model = translation_model.to('cpu')
84
 
 
 
 
 
85
 
86
-
87
 
88
  SUPPORTED_LANGS = {
89
  "eng_Latn": "English",
@@ -116,37 +114,50 @@ def chunk_text(text: str, max_len: int = 400) -> List[str]:
116
  return chunks
117
 
118
  def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int = 400) -> str:
 
119
  if not text.strip():
120
  return text
 
 
 
 
121
  chunks = chunk_text(text, max_len=max_chunk_len)
122
  translated_parts = []
 
123
  for chunk in chunks:
 
 
 
 
124
  inputs = translation_tokenizer(
125
  chunk,
126
- return_tensors = 'pt',
127
- padding = True,
128
- truncation = True,
129
- max_length = 400
130
  ).to(translation_model.device)
131
-
132
- #setting the target language token
133
  forced_bos_token_id = translation_tokenizer.convert_tokens_to_ids(tgt_lang)
134
-
 
135
  generated_tokens = translation_model.generate(
136
  **inputs,
137
- forced_bos_token_id = forced_bos_token_id,
138
- max_new_tokens = 400,
139
- num_beams = 5,
140
- early_stopping = True
141
  )
142
-
 
143
  translated_text = translation_tokenizer.batch_decode(
144
  generated_tokens,
145
- skip_special_tokens = True
146
  )[0]
 
147
  translated_parts.append(translated_text)
148
-
149
- return "".join(translated_parts).strip()
150
 
151
 
152
  # RAG retrieval
@@ -242,7 +253,7 @@ def strip_markdown(text: str) -> str:
242
  text = re.sub(r'^#+\s+', '', text, flags=re.MULTILINE)
243
  return text
244
 
245
- # Main pipeline
246
  def run_pipeline(user_query: str, session_id: str = None):
247
  """
248
  Run FarmLingua pipeline with per-session memory.
@@ -273,7 +284,7 @@ def run_pipeline(user_query: str, session_id: str = None):
273
  system_prompt = (
274
  "You are FarmLingua, an AI assistant for Nigerian farmers. "
275
  "Answer questions directly and accurately with helpful farming advice. "
276
- "Use clear, simple language with occasional emojis 🌾. "
277
  "Be concise and focus on practical, actionable information. "
278
  "If asked who built you, say: 'KawaFarm LTD developed me to help farmers.'"
279
  )
 
1
+ # farmlingua/app/agents/crew_pipeline.py
2
  import os
3
  import sys
4
  import re
 
10
  import torch
11
  import fasttext
12
  from huggingface_hub import hf_hub_download
13
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, NllbTokenizer
14
  from sentence_transformers import SentenceTransformer
15
  from app.utils import config
16
  from app.utils.memory import memory_store # memory module
 
64
 
65
# Translation model
print(f"Loading translation model ({config.TRANSLATION_MODEL_NAME})...")

# NLLB tokenizer; cache_dir pins downloads to the app-controlled HF cache.
translation_tokenizer = NllbTokenizer.from_pretrained(
    config.TRANSLATION_MODEL_NAME,
    cache_dir=hf_cache,
)

# fp16 on GPU halves memory/bandwidth; fp32 on CPU, where fp16 kernels are
# slow or unsupported.
translation_model = AutoModelForSeq2SeqLM.from_pretrained(
    config.TRANSLATION_MODEL_NAME,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    cache_dir=hf_cache,
)

# DEVICE is either "cuda" or "cpu" (see the DEVICE checks above), so a
# single .to(DEVICE) replaces the redundant if/else branches.
translation_model = translation_model.to(DEVICE)

print(f"Translation model loaded on {DEVICE}")
85
 
86
  SUPPORTED_LANGS = {
87
  "eng_Latn": "English",
 
114
  return chunks
115
 
116
def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int = 400) -> str:
    """Translate text using the NLLB model, chunk by chunk.

    The input is split into pieces of at most ``max_chunk_len`` characters
    via ``chunk_text``, each piece is translated independently, and the
    results are re-joined with single spaces.

    Args:
        text: Source text. Returned unchanged when blank or when the
            source and target languages are identical.
        src_lang: NLLB language code of the input (e.g. ``"eng_Latn"``).
        tgt_lang: NLLB language code of the desired output.
        max_chunk_len: Maximum chunk length forwarded to ``chunk_text``.

    Returns:
        The translated text (space-joined chunk translations, stripped).
    """
    if not text.strip():
        return text

    # Nothing to translate when source and target languages match.
    if src_lang == tgt_lang:
        return text

    # Both of these are identical for every chunk, so set/compute them
    # once here instead of on every loop iteration.
    translation_tokenizer.src_lang = src_lang
    forced_bos_token_id = translation_tokenizer.convert_tokens_to_ids(tgt_lang)

    translated_parts = []
    for chunk in chunk_text(text, max_len=max_chunk_len):
        inputs = translation_tokenizer(
            chunk,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(translation_model.device)

        generated_tokens = translation_model.generate(
            **inputs,
            # Forces the first generated token to be the target-language
            # token — this is how NLLB selects the output language.
            forced_bos_token_id=forced_bos_token_id,
            max_new_tokens=512,
            num_beams=5,
            early_stopping=True,
        )

        translated_parts.append(
            translation_tokenizer.batch_decode(
                generated_tokens,
                skip_special_tokens=True,
            )[0]
        )

    return " ".join(translated_parts).strip()
161
 
162
 
163
  # RAG retrieval
 
253
  text = re.sub(r'^#+\s+', '', text, flags=re.MULTILINE)
254
  return text
255
 
256
+
257
  def run_pipeline(user_query: str, session_id: str = None):
258
  """
259
  Run FarmLingua pipeline with per-session memory.
 
284
  system_prompt = (
285
  "You are FarmLingua, an AI assistant for Nigerian farmers. "
286
  "Answer questions directly and accurately with helpful farming advice. "
287
+ "Use clear, simple language with occasional emojis . "
288
  "Be concise and focus on practical, actionable information. "
289
  "If asked who built you, say: 'KawaFarm LTD developed me to help farmers.'"
290
  )