drrobot9 committed on
Commit
70ac964
·
1 Parent(s): 620a683

Update app/agents/crew_pipeline.py

Browse files
Files changed (1) hide show
  1. app/agents/crew_pipeline.py +15 -92
app/agents/crew_pipeline.py CHANGED
@@ -1,4 +1,4 @@
1
- # farmlingua/app/agents/crew_pipeline.py
2
  import os
3
  import sys
4
  import re
@@ -10,12 +10,13 @@ import numpy as np
10
  import torch
11
  import fasttext
12
  from huggingface_hub import hf_hub_download
13
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModelForSeq2SeqLM
14
  from sentence_transformers import SentenceTransformer
15
  from app.utils import config
16
- from app.utils.memory import memory_store
17
  from typing import List
18
 
 
19
  hf_cache = "/models/huggingface"
20
  os.environ["HF_HOME"] = hf_cache
21
  os.environ["TRANSFORMERS_CACHE"] = hf_cache
@@ -28,11 +29,13 @@ if BASE_DIR not in sys.path:
28
 
29
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
30
 
 
31
  try:
32
  classifier = joblib.load(config.CLASSIFIER_PATH)
33
  except Exception:
34
  classifier = None
35
 
 
36
  print(f"Loading expert model ({config.EXPERT_MODEL_NAME})...")
37
  tokenizer = AutoTokenizer.from_pretrained(config.EXPERT_MODEL_NAME, use_fast=False)
38
  model = AutoModelForCausalLM.from_pretrained(
@@ -41,9 +44,10 @@ model = AutoModelForCausalLM.from_pretrained(
41
  device_map="auto"
42
  )
43
 
 
44
  embedder = SentenceTransformer(config.EMBEDDING_MODEL)
45
 
46
- # Language detector
47
  print(f"Loading FastText language identifier ({config.LANG_ID_MODEL_REPO})...")
48
  lang_model_path = hf_hub_download(
49
  repo_id=config.LANG_ID_MODEL_REPO,
@@ -58,33 +62,11 @@ def detect_language(text: str, top_k: int = 1):
58
  labels, probs = lang_identifier.predict(clean_text, k=top_k)
59
  return [(l.replace("__label__", ""), float(p)) for l, p in zip(labels, probs)]
60
 
61
-
62
  print(f"Loading translation model ({config.TRANSLATION_MODEL_NAME})...")
63
-
64
-
65
- LANG_CODE_MAP = {
66
- "eng_Latn": "en", # English
67
- "ibo_Latn": "ig", # Igbo
68
- "yor_Latn": "yo", # Yoruba
69
- "hau_Latn": "ha", # Hausa
70
- "swh_Latn": "sw", # Swahili
71
- "amh_Latn": "am", # Amharic
72
- }
73
-
74
-
75
- translation_tokenizer = AutoTokenizer.from_pretrained(
76
- config.TRANSLATION_MODEL_NAME
77
- )
78
- translation_model = AutoModelForSeq2SeqLM.from_pretrained(
79
- config.TRANSLATION_MODEL_NAME,
80
- device_map="auto" if DEVICE == "cuda" else None
81
- )
82
-
83
-
84
  translation_pipeline = pipeline(
85
- "translation",
86
- model=translation_model,
87
- tokenizer=translation_tokenizer,
88
  device=0 if DEVICE == "cuda" else -1,
89
  max_new_tokens=400,
90
  )
@@ -98,7 +80,7 @@ SUPPORTED_LANGS = {
98
  "amh_Latn": "Amharic",
99
  }
100
 
101
-
102
  _SENTENCE_SPLIT_RE = re.compile(r'(?<=[.!?])\s+')
103
 
104
  def chunk_text(text: str, max_len: int = 400) -> List[str]:
@@ -120,75 +102,16 @@ def chunk_text(text: str, max_len: int = 400) -> List[str]:
120
  return chunks
121
 
122
  def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int = 400) -> str:
123
- """Translate text between languages using the model"""
124
  if not text.strip():
125
  return text
126
-
127
-
128
- src_code = LANG_CODE_MAP.get(src_lang, "en")
129
- tgt_code = LANG_CODE_MAP.get(tgt_lang, "en")
130
-
131
-
132
- if src_code == tgt_code:
133
- return text
134
-
135
  chunks = chunk_text(text, max_len=max_chunk_len)
136
  translated_parts = []
137
-
138
  for chunk in chunks:
139
- try:
140
-
141
- if hasattr(translation_tokenizer, 'lang_code_to_id'):
142
- # Set source and target language
143
- translation_tokenizer.src_lang = src_code
144
- forced_bos_token_id = translation_tokenizer.lang_code_to_id[tgt_code]
145
-
146
- # Tokenize
147
- inputs = translation_tokenizer(chunk, return_tensors="pt")
148
- if DEVICE == "cuda":
149
- inputs = {k: v.to(translation_model.device) for k, v in inputs.items()}
150
-
151
- # Generate translation
152
- generated_tokens = translation_model.generate(
153
- **inputs,
154
- forced_bos_token_id=forced_bos_token_id,
155
- max_new_tokens=400
156
- )
157
-
158
- # Decode
159
- result = translation_tokenizer.batch_decode(
160
- generated_tokens,
161
- skip_special_tokens=True
162
- )[0]
163
-
164
- else:
165
-
166
- task_name = f"translation_{src_code}_to_{tgt_code}"
167
- try:
168
- specific_pipeline = pipeline(
169
- task_name,
170
- model=translation_model,
171
- tokenizer=translation_tokenizer,
172
- device=0 if DEVICE == "cuda" else -1,
173
- max_new_tokens=400,
174
- )
175
- result = specific_pipeline(chunk)[0]["translation_text"]
176
- except:
177
-
178
- result = translation_pipeline(
179
- chunk,
180
- src_lang=src_code,
181
- tgt_lang=tgt_code
182
- )[0]["translation_text"]
183
-
184
- translated_parts.append(result)
185
- except Exception as e:
186
- print(f"Translation error ({src_code}->{tgt_code}): {e}")
187
- translated_parts.append(chunk)
188
-
189
  return " ".join(translated_parts).strip()
190
 
191
-
192
  def retrieve_docs(query: str, vs_path: str):
193
  if not vs_path or not os.path.exists(vs_path):
194
  return None
 
1
+ farmlingua/app/agents/crew_pipeline.pymemorysection
2
  import os
3
  import sys
4
  import re
 
10
  import torch
11
  import fasttext
12
  from huggingface_hub import hf_hub_download
13
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
14
  from sentence_transformers import SentenceTransformer
15
  from app.utils import config
16
+ from app.utils.memory import memory_store # memory module
17
  from typing import List
18
 
19
+
20
  hf_cache = "/models/huggingface"
21
  os.environ["HF_HOME"] = hf_cache
22
  os.environ["TRANSFORMERS_CACHE"] = hf_cache
 
29
 
30
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
31
 
32
+
33
  try:
34
  classifier = joblib.load(config.CLASSIFIER_PATH)
35
  except Exception:
36
  classifier = None
37
 
38
+
39
  print(f"Loading expert model ({config.EXPERT_MODEL_NAME})...")
40
  tokenizer = AutoTokenizer.from_pretrained(config.EXPERT_MODEL_NAME, use_fast=False)
41
  model = AutoModelForCausalLM.from_pretrained(
 
44
  device_map="auto"
45
  )
46
 
47
+
48
  embedder = SentenceTransformer(config.EMBEDDING_MODEL)
49
 
50
+ # language detector
51
  print(f"Loading FastText language identifier ({config.LANG_ID_MODEL_REPO})...")
52
  lang_model_path = hf_hub_download(
53
  repo_id=config.LANG_ID_MODEL_REPO,
 
62
  labels, probs = lang_identifier.predict(clean_text, k=top_k)
63
  return [(l.replace("__label__", ""), float(p)) for l, p in zip(labels, probs)]
64
 
65
+ # Translation model
66
  print(f"Loading translation model ({config.TRANSLATION_MODEL_NAME})...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  translation_pipeline = pipeline(
68
+ "translation_en_to_fr",
69
+ model=config.TRANSLATION_MODEL_NAME,
 
70
  device=0 if DEVICE == "cuda" else -1,
71
  max_new_tokens=400,
72
  )
 
80
  "amh_Latn": "Amharic",
81
  }
82
 
83
+ # Text chunking
84
  _SENTENCE_SPLIT_RE = re.compile(r'(?<=[.!?])\s+')
85
 
86
  def chunk_text(text: str, max_len: int = 400) -> List[str]:
 
102
  return chunks
103
 
104
  def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int = 400) -> str:
 
105
  if not text.strip():
106
  return text
 
 
 
 
 
 
 
 
 
107
  chunks = chunk_text(text, max_len=max_chunk_len)
108
  translated_parts = []
 
109
  for chunk in chunks:
110
+ res = translation_pipeline(chunk, src_lang=src_lang, tgt_lang=tgt_lang)
111
+ translated_parts.append(res[0]["translation_text"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  return " ".join(translated_parts).strip()
113
 
114
+ # RAG retrieval
115
  def retrieve_docs(query: str, vs_path: str):
116
  if not vs_path or not os.path.exists(vs_path):
117
  return None