drrobot9 committed on
Commit
620a683
·
1 Parent(s): e4c5a04

Update app/agents/crew_pipeline.py

Browse files
Files changed (1) hide show
  1. app/agents/crew_pipeline.py +91 -14
app/agents/crew_pipeline.py CHANGED
@@ -1,4 +1,4 @@
1
- # farmlingua/app/agents/crew_pipeline.pymemorysection
2
  import os
3
  import sys
4
  import re
@@ -10,13 +10,12 @@ import numpy as np
10
  import torch
11
  import fasttext
12
  from huggingface_hub import hf_hub_download
13
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
14
  from sentence_transformers import SentenceTransformer
15
  from app.utils import config
16
- from app.utils.memory import memory_store # memory module
17
  from typing import List
18
 
19
-
20
  hf_cache = "/models/huggingface"
21
  os.environ["HF_HOME"] = hf_cache
22
  os.environ["TRANSFORMERS_CACHE"] = hf_cache
@@ -29,13 +28,11 @@ if BASE_DIR not in sys.path:
29
 
30
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
31
 
32
-
33
  try:
34
  classifier = joblib.load(config.CLASSIFIER_PATH)
35
  except Exception:
36
  classifier = None
37
 
38
-
39
  print(f"Loading expert model ({config.EXPERT_MODEL_NAME})...")
40
  tokenizer = AutoTokenizer.from_pretrained(config.EXPERT_MODEL_NAME, use_fast=False)
41
  model = AutoModelForCausalLM.from_pretrained(
@@ -44,10 +41,9 @@ model = AutoModelForCausalLM.from_pretrained(
44
  device_map="auto"
45
  )
46
 
47
-
48
  embedder = SentenceTransformer(config.EMBEDDING_MODEL)
49
 
50
- # language detector
51
  print(f"Loading FastText language identifier ({config.LANG_ID_MODEL_REPO})...")
52
  lang_model_path = hf_hub_download(
53
  repo_id=config.LANG_ID_MODEL_REPO,
@@ -62,11 +58,33 @@ def detect_language(text: str, top_k: int = 1):
62
  labels, probs = lang_identifier.predict(clean_text, k=top_k)
63
  return [(l.replace("__label__", ""), float(p)) for l, p in zip(labels, probs)]
64
 
65
- # Translation model
66
  print(f"Loading translation model ({config.TRANSLATION_MODEL_NAME})...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  translation_pipeline = pipeline(
68
  "translation",
69
- model=config.TRANSLATION_MODEL_NAME,
 
70
  device=0 if DEVICE == "cuda" else -1,
71
  max_new_tokens=400,
72
  )
@@ -80,7 +98,7 @@ SUPPORTED_LANGS = {
80
  "amh_Latn": "Amharic",
81
  }
82
 
83
- # Text chunking
84
  _SENTENCE_SPLIT_RE = re.compile(r'(?<=[.!?])\s+')
85
 
86
  def chunk_text(text: str, max_len: int = 400) -> List[str]:
@@ -102,16 +120,75 @@ def chunk_text(text: str, max_len: int = 400) -> List[str]:
102
  return chunks
103
 
104
  def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int = 400) -> str:
 
105
  if not text.strip():
106
  return text
 
 
 
 
 
 
 
 
 
107
  chunks = chunk_text(text, max_len=max_chunk_len)
108
  translated_parts = []
 
109
  for chunk in chunks:
110
- res = translation_pipeline(chunk, src_lang=src_lang, tgt_lang=tgt_lang)
111
- translated_parts.append(res[0]["translation_text"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  return " ".join(translated_parts).strip()
113
 
114
- # RAG retrieval
115
  def retrieve_docs(query: str, vs_path: str):
116
  if not vs_path or not os.path.exists(vs_path):
117
  return None
 
1
+ # farmlingua/app/agents/crew_pipeline.py
2
  import os
3
  import sys
4
  import re
 
10
  import torch
11
  import fasttext
12
  from huggingface_hub import hf_hub_download
13
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModelForSeq2SeqLM
14
  from sentence_transformers import SentenceTransformer
15
  from app.utils import config
16
+ from app.utils.memory import memory_store
17
  from typing import List
18
 
 
19
  hf_cache = "/models/huggingface"
20
  os.environ["HF_HOME"] = hf_cache
21
  os.environ["TRANSFORMERS_CACHE"] = hf_cache
 
28
 
29
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
30
 
 
31
  try:
32
  classifier = joblib.load(config.CLASSIFIER_PATH)
33
  except Exception:
34
  classifier = None
35
 
 
36
  print(f"Loading expert model ({config.EXPERT_MODEL_NAME})...")
37
  tokenizer = AutoTokenizer.from_pretrained(config.EXPERT_MODEL_NAME, use_fast=False)
38
  model = AutoModelForCausalLM.from_pretrained(
 
41
  device_map="auto"
42
  )
43
 
 
44
  embedder = SentenceTransformer(config.EMBEDDING_MODEL)
45
 
46
+ # Language detector
47
  print(f"Loading FastText language identifier ({config.LANG_ID_MODEL_REPO})...")
48
  lang_model_path = hf_hub_download(
49
  repo_id=config.LANG_ID_MODEL_REPO,
 
58
  labels, probs = lang_identifier.predict(clean_text, k=top_k)
59
  return [(l.replace("__label__", ""), float(p)) for l, p in zip(labels, probs)]
60
 
61
+
62
print(f"Loading translation model ({config.TRANSLATION_MODEL_NAME})...")


# Map the FLORES-200-style language codes used by this module to the
# two-letter ISO 639-1 codes the translation tokenizer expects.
# Unknown codes are handled by callers (translate_text falls back to "en").
LANG_CODE_MAP = {
    "eng_Latn": "en",  # English
    "ibo_Latn": "ig",  # Igbo
    "yor_Latn": "yo",  # Yoruba
    "hau_Latn": "ha",  # Hausa
    "swh_Latn": "sw",  # Swahili
    "amh_Latn": "am",  # Amharic
}
73
+
74
+
75
# Load the translation tokenizer/model once at import time; the pipeline
# below reuses both so no second download/allocation happens.
translation_tokenizer = AutoTokenizer.from_pretrained(
    config.TRANSLATION_MODEL_NAME
)
translation_model = AutoModelForSeq2SeqLM.from_pretrained(
    config.TRANSLATION_MODEL_NAME,
    # On GPU let accelerate place the weights; on CPU load normally.
    device_map="auto" if DEVICE == "cuda" else None,
)

# BUG FIX: when the model is loaded with ``device_map="auto"`` (accelerate),
# ``pipeline`` must NOT also receive an explicit ``device`` — transformers
# raises a ValueError for that combination. Only pin ``device=-1`` (CPU)
# when accelerate placement was not used.
translation_pipeline = pipeline(
    "translation",
    model=translation_model,
    tokenizer=translation_tokenizer,
    max_new_tokens=400,
    **({} if DEVICE == "cuda" else {"device": -1}),
)
 
98
  "amh_Latn": "Amharic",
99
  }
100
 
101
+
102
  _SENTENCE_SPLIT_RE = re.compile(r'(?<=[.!?])\s+')
103
 
104
  def chunk_text(text: str, max_len: int = 400) -> List[str]:
 
120
  return chunks
121
 
122
def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int = 400) -> str:
    """Translate ``text`` from ``src_lang`` to ``tgt_lang`` using the loaded model.

    Args:
        text: Input text. Returned unchanged when blank, or when source and
            target resolve to the same short language code.
        src_lang: FLORES-200-style source code (e.g. ``"yor_Latn"``).
            Codes missing from ``LANG_CODE_MAP`` fall back to ``"en"``.
        tgt_lang: FLORES-200-style target code; same fallback as above.
        max_chunk_len: Maximum chunk size handed to ``chunk_text`` so the
            model never sees over-long inputs.

    Returns:
        The translated text with chunks re-joined by single spaces. A chunk
        that fails to translate is kept verbatim (best-effort behaviour).
    """
    if not text.strip():
        return text

    src_code = LANG_CODE_MAP.get(src_lang, "en")
    tgt_code = LANG_CODE_MAP.get(tgt_lang, "en")

    # Nothing to do when both sides resolve to the same language.
    if src_code == tgt_code:
        return text

    chunks = chunk_text(text, max_len=max_chunk_len)
    translated_parts = []

    # Multilingual tokenizers (NLLB/M2M-style) expose ``lang_code_to_id``
    # and take a forced BOS token; pair-specific models (Marian-style) do
    # not, and go through a task-named pipeline instead.
    use_multilingual = hasattr(translation_tokenizer, "lang_code_to_id")

    # PERF FIX: build the pair-specific pipeline ONCE, outside the chunk
    # loop (the original recreated it for every chunk).
    # BUG FIX: the original used a bare ``except:`` here, which also
    # swallowed KeyboardInterrupt/SystemExit — narrowed to ``Exception``.
    specific_pipeline = None
    if not use_multilingual:
        try:
            specific_pipeline = pipeline(
                f"translation_{src_code}_to_{tgt_code}",
                model=translation_model,
                tokenizer=translation_tokenizer,
                device=0 if DEVICE == "cuda" else -1,
                max_new_tokens=400,
            )
        except Exception:
            # Fall back to the generic translation pipeline per chunk below.
            specific_pipeline = None

    for chunk in chunks:
        try:
            if use_multilingual:
                # Tell the tokenizer the source language, force the target
                # language as the first generated token.
                translation_tokenizer.src_lang = src_code
                forced_bos_token_id = translation_tokenizer.lang_code_to_id[tgt_code]

                inputs = translation_tokenizer(chunk, return_tensors="pt")
                if DEVICE == "cuda":
                    inputs = {k: v.to(translation_model.device) for k, v in inputs.items()}

                generated_tokens = translation_model.generate(
                    **inputs,
                    forced_bos_token_id=forced_bos_token_id,
                    max_new_tokens=400,
                )
                result = translation_tokenizer.batch_decode(
                    generated_tokens,
                    skip_special_tokens=True,
                )[0]
            elif specific_pipeline is not None:
                result = specific_pipeline(chunk)[0]["translation_text"]
            else:
                result = translation_pipeline(
                    chunk,
                    src_lang=src_code,
                    tgt_lang=tgt_code,
                )[0]["translation_text"]

            translated_parts.append(result)
        except Exception as e:
            # Best-effort: keep the untranslated chunk rather than failing
            # the whole request.
            print(f"Translation error ({src_code}->{tgt_code}): {e}")
            translated_parts.append(chunk)

    return " ".join(translated_parts).strip()
190
 
191
+
192
  def retrieve_docs(query: str, vs_path: str):
193
  if not vs_path or not os.path.exists(vs_path):
194
  return None