| import os |
| import sys |
| import re |
| import uuid |
| import requests |
| import joblib |
| import faiss |
| import numpy as np |
| import torch |
| import fasttext |
| from huggingface_hub import hf_hub_download |
| from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, NllbTokenizer |
| from sentence_transformers import SentenceTransformer |
| from app.utils import config |
| from app.utils.memory import memory_store |
| from app.utils.weather_api import forecast_summary_for_state |
| from typing import List |
|
|
|
|
# Make the project root (two directories above this file) importable so the
# "app.*" absolute imports below resolve even when this module is executed
# directly instead of through the installed package.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if BASE_DIR not in sys.path:
    sys.path.insert(0, BASE_DIR)
|
|
| from app.utils.model_manager import ( |
| load_expert_model, |
| load_translation_model, |
| load_embedder, |
| load_lang_identifier, |
| load_classifier, |
| get_device |
| ) |
|
|
# Compute/accelerator device resolved once at import time.
DEVICE = get_device()
# Lazily-initialized model singletons; each is populated on first use by the
# matching get_*() helper below so that heavy weights load only when needed.
_tokenizer = None  # expert-model tokenizer
_model = None  # expert causal LM
_embedder = None  # sentence-transformer embedder
_lang_identifier = None  # fastText language-ID model
_translation_tokenizer = None  # NLLB tokenizer
_translation_model = None  # NLLB seq2seq translation model
_classifier = None  # intent classifier (sklearn-style; see detect_intent)
|
|
|
|
def get_expert_model():
    """Return the cached (tokenizer, model) pair for the expert LLM, loading lazily on first use."""
    global _tokenizer, _model
    if _tokenizer is not None and _model is not None:
        return _tokenizer, _model
    # First call: load (quantized) weights and cache them in the module globals.
    _tokenizer, _model = load_expert_model(config.EXPERT_MODEL_NAME, use_quantization=True)
    return _tokenizer, _model
|
|
|
|
def get_embedder():
    """Return the cached sentence-embedding model, loading it on first use."""
    global _embedder
    if _embedder is not None:
        return _embedder
    _embedder = load_embedder(config.EMBEDDING_MODEL)
    return _embedder
|
|
|
|
def get_lang_identifier():
    """Return the cached fastText language-identification model, loading it on first use."""
    global _lang_identifier
    if _lang_identifier is not None:
        return _lang_identifier
    # Model filename is configurable; default matches the repo's artifact name.
    model_file = getattr(config, "LANG_ID_MODEL_FILE", "model.bin")
    _lang_identifier = load_lang_identifier(config.LANG_ID_MODEL_REPO, model_file)
    return _lang_identifier
|
|
|
|
def get_translation_model():
    """Return the cached NLLB (tokenizer, model) pair, loading it on first use."""
    global _translation_tokenizer, _translation_model
    if _translation_tokenizer is not None and _translation_model is not None:
        return _translation_tokenizer, _translation_model
    _translation_tokenizer, _translation_model = load_translation_model(
        config.TRANSLATION_MODEL_NAME
    )
    return _translation_tokenizer, _translation_model
|
|
|
|
def get_classifier():
    """Return the cached intent classifier, loading it from disk on first use."""
    global _classifier
    if _classifier is not None:
        return _classifier
    _classifier = load_classifier(config.CLASSIFIER_PATH)
    return _classifier
|
|
def detect_language(text: str, top_k: int = 1):
    """Identify the language of *text* with fastText.

    Returns a list of (lang_code, probability) pairs, highest first. Empty or
    whitespace-only input defaults to English with full confidence.
    """
    if not text or not text.strip():
        return [("eng_Latn", 1.0)]
    # fastText cannot handle embedded newlines; flatten before predicting.
    normalized = text.replace("\n", " ").strip()
    labels, scores = get_lang_identifier().predict(normalized, k=top_k)
    # Labels arrive as "__label__<code>"; strip the prefix.
    return [(label.replace("__label__", ""), float(score)) for label, score in zip(labels, scores)]
|
|
# NLLB/FLORES-200 language codes ("<iso639-3>_<Script>") supported by the
# pipeline, mapped to human-readable names for display.
SUPPORTED_LANGS = {
    "eng_Latn": "English",
    "ibo_Latn": "Igbo",
    "yor_Latn": "Yoruba",
    "hau_Latn": "Hausa",
    "swh_Latn": "Swahili",
    # Amharic is written in Ethiopic script; the FLORES-200 code is
    # "amh_Ethi". The previous "amh_Latn" is not a valid code, so it could
    # never be produced by the language identifier nor accepted as an NLLB
    # target token.
    "amh_Ethi": "Amharic",
}
|
|
# Split after sentence-ending punctuation followed by whitespace.
_SENTENCE_SPLIT_RE = re.compile(r'(?<=[.!?])\s+')


def chunk_text(text: str, max_len: int = 400) -> List[str]:
    """Split *text* into sentence-aligned chunks of roughly max_len characters.

    Sentences are greedily packed into the current chunk; a single sentence
    longer than max_len becomes a chunk of its own (it is never split).
    """
    if not text:
        return []
    chunks: List[str] = []
    buffer = ""
    for sentence in _SENTENCE_SPLIT_RE.split(text):
        if not sentence:
            continue
        # +1 accounts for the joining space between buffer and sentence.
        if len(buffer) + len(sentence) + 1 <= max_len:
            buffer = (buffer + " " + sentence).strip()
            continue
        if buffer:
            chunks.append(buffer.strip())
        buffer = sentence.strip()
    if buffer:
        chunks.append(buffer.strip())
    return chunks
|
|
def translate_text(text: str, src_lang: str, tgt_lang: str, max_chunk_len: int = 400) -> str:
    """Translate *text* from src_lang to tgt_lang with the NLLB model.

    Long inputs are split into sentence-aligned chunks (at most
    max_chunk_len characters each), translated independently, and re-joined
    with single spaces. Returns the input unchanged when it is blank or the
    languages already match.
    """
    if not text.strip():
        return text
    if src_lang == tgt_lang:
        return text

    tokenizer, model = get_translation_model()
    # Both of these are invariant across chunks, so configure them once.
    tokenizer.src_lang = src_lang
    target_token_id = tokenizer.convert_tokens_to_ids(tgt_lang)

    outputs = []
    for piece in chunk_text(text, max_len=max_chunk_len):
        encoded = tokenizer(
            piece,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(model.device)

        # Force the decoder to start in the target language.
        generated = model.generate(
            **encoded,
            forced_bos_token_id=target_token_id,
            max_new_tokens=512,
            num_beams=5,
            early_stopping=True
        )

        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
        outputs.append(decoded)

    return " ".join(outputs).strip()
|
|
|
|
| |
def retrieve_docs(query: str, vs_path: str):
    """Retrieve up to 3 related document texts from a FAISS vector store.

    Args:
        query: Natural-language query to embed and search with.
        vs_path: Path to a serialized FAISS index; "<vs_path>_meta.npy" is
            expected to hold a {str(row_index): document_text} dict.

    Returns:
        The matching documents joined by blank lines, or None when the store
        is missing/unreadable or nothing usable is found.
    """
    if not vs_path or not os.path.exists(vs_path):
        return None
    try:
        index = faiss.read_index(str(vs_path))
    except Exception:
        # Corrupt or unreadable index: treat as "no context available".
        return None
    embedder = get_embedder()
    query_vec = np.array([embedder.encode(query)], dtype=np.float32)
    # D = distances/scores, I = row indices of the top-3 neighbours.
    D, I = index.search(query_vec, k=3)
    # NOTE(review): bails out when the top score is exactly 0 — presumably a
    # guard against degenerate/empty search results. Confirm intent: for an
    # L2 index a 0 distance would actually mean an exact match.
    if D[0][0] == 0:
        return None
    meta_path = str(vs_path) + "_meta.npy"
    if os.path.exists(meta_path):
        # allow_pickle is required to recover the dict payload; the metadata
        # file is project-generated, so this is trusted input.
        metadata = np.load(meta_path, allow_pickle=True).item()
        docs = [metadata.get(str(idx), "") for idx in I[0] if str(idx) in metadata]
        docs = [d for d in docs if d]
        return "\n\n".join(docs) if docs else None
    return None
|
|
|
|
def get_weather(state_name: str) -> str:
    """Return a forecast summary for *state_name*; never raises.

    Any upstream failure (network, unknown state, API error) degrades to a
    human-readable fallback message so the pipeline keeps running.
    """
    fallback = f"Unable to retrieve weather for {state_name}."
    try:
        summary = forecast_summary_for_state(state_name)
    except Exception:
        return fallback
    return summary
|
|
|
|
def detect_intent(query: str):
    """Classify *query* into an (intent, extra) pair.

    Priority order: keyword-based weather detection (extra = matched state
    name, if any), then live-update keywords, then the trained classifier —
    falling back to "low_confidence" below the configured threshold and to
    "normal" when the classifier is unavailable or errors out.
    """
    lowered = (query or "").lower()

    weather_words = ("weather", "temperature", "rain", "forecast")
    if any(word in lowered for word in weather_words):
        # Try to pin the query to a specific state for a targeted forecast.
        for state in getattr(config, "STATES", []):
            if state.lower() in lowered:
                return "weather", state
        return "weather", None

    live_words = ("latest", "update", "breaking", "news", "current", "predict")
    if any(word in lowered for word in live_words):
        return "live_update", None

    classifier = get_classifier()
    if classifier and hasattr(classifier, "predict") and hasattr(classifier, "predict_proba"):
        try:
            predicted = classifier.predict([query])[0]
            confidence = max(classifier.predict_proba([query])[0])
        except Exception:
            # Classifier failure: fall through to the default intent.
            return "normal", None
        if confidence < getattr(config, "CLASSIFIER_CONFIDENCE_THRESHOLD", 0.6):
            return "low_confidence", None
        return predicted, None
    return "normal", None
|
|
| |
def run_qwen(messages: List[dict], max_new_tokens: int = 1300) -> str:
    """Generate a chat response from the expert model.

    Applies the tokenizer's chat template, samples a completion, then runs
    heuristic clean-up passes that strip hallucinated example dialogues and
    off-topic (non-agricultural) content from the output.

    Fixes over the previous version: removed the `stop_sequences` /
    `stop_token_ids` computation, which was never passed to `generate()`
    (dead code), and simplified a redundant "Human:" containment check.

    Args:
        messages: Chat history as [{"role": ..., "content": ...}, ...].
        max_new_tokens: Generation budget for the completion.

    Returns:
        The cleaned assistant response text.
    """
    tokenizer, model = get_expert_model()
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([text], return_tensors="pt").to(model.device)

    generated_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=0.4,
        repetition_penalty=1.1,
        do_sample=True,
        top_p=0.9,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id
    )
    # Keep only the newly generated tokens (drop the echoed prompt).
    output_ids = generated_ids[0][len(inputs.input_ids[0]):].tolist()
    response = tokenizer.decode(output_ids, skip_special_tokens=True).strip()

    # Pass 1: truncate at the first hallucinated "Human:" turn, if any.
    # ("\nHuman:" is subsumed by the plain "Human:" check.)
    if "Human:" in response:
        response = re.split(r'\n?\n?Human:', response, maxsplit=1)[0].strip()

    # Pass 2: drop paragraphs about known off-topic subjects unless they
    # also mention an agricultural keyword.
    if '\n\n' in response:
        unrelated_keywords = ["London", "get around", "parks", "neighborhoods", "festivals",
                              "Wimbledon", "Notting Hill", "Covent Garden", "travel", "tourism"]
        ag_keywords = ["farm", "crop", "livestock", "agriculture", "soil", "weather"]
        cleaned_parts = []
        for part in response.split('\n\n'):
            part_lower = part.lower()
            if any(keyword.lower() in part_lower for keyword in unrelated_keywords):
                if not any(ag_keyword in part_lower for ag_keyword in ag_keywords):
                    continue
            cleaned_parts.append(part)
        response = '\n\n'.join(cleaned_parts).strip()

    # Pass 3: cut at the first line that looks like the start of an example
    # conversation or an off-topic numbered list.
    cleaned_lines = []
    found_example_marker = False
    for line in response.split('\n'):
        if line.strip().startswith(("Human:", "Assistant:", "User:", "Bot:")):
            found_example_marker = True
            break
        if re.match(r'^\d+\.\s+(London|get around|parks|neighborhoods)', line, re.IGNORECASE):
            found_example_marker = True
            break
        cleaned_lines.append(line)
    cleaned_response = '\n'.join(cleaned_lines).strip()

    # Pass 4: if an example marker was found in a long response, keep only
    # the first paragraph (or the first 200 chars when there is none).
    if found_example_marker and len(cleaned_response) > 200:
        first_para = cleaned_response.split('\n\n')[0] if '\n\n' in cleaned_response else cleaned_response[:200]
        cleaned_response = first_para.strip()

    return cleaned_response
|
|
| |
# Cap on how many chat messages (user + assistant turns) are kept per session.
MAX_HISTORY_MESSAGES = getattr(config, "MAX_HISTORY_MESSAGES", 30)
|
|
def build_messages_from_history(history: List[dict], system_prompt: str) -> List[dict]:
    """Return a new message list: the system prompt followed by the chat history."""
    return [{"role": "system", "content": system_prompt}, *history]
|
|
|
|
def strip_markdown(text: str) -> str:
    """
    Remove Markdown formatting like **bold**, *italic*, and `inline code`.

    Heading markers ("#", "##", ...) at the start of any line are also
    stripped. Falsy input yields the empty string.
    """
    if not text:
        return ""
    inline_rules = (
        (r'\*\*(.*?)\*\*', r'\1'),   # **bold**
        (r'(\*|_)(.*?)\1', r'\2'),   # *italic* / _italic_
        (r'`(.*?)`', r'\1'),         # `inline code`
    )
    for pattern, replacement in inline_rules:
        text = re.sub(pattern, replacement, text)
    # Drop leading heading markers on every line.
    return re.sub(r'^#+\s+', '', text, flags=re.MULTILINE)
|
|
|
|
def run_pipeline(user_query: str, session_id: str = None) -> dict:
    """
    Run Aglimate pipeline with per-session memory.
    Each session_id keeps its own history.

    Steps: detect language -> translate query to English -> detect intent ->
    gather optional context (weather or RAG) -> generate the answer ->
    persist history -> translate the answer back to the user's language.

    Args:
        user_query: Raw user message, possibly in a supported African language.
        session_id: Conversation key for the memory store; a fresh UUID is
            generated when omitted.

    Returns:
        Dict with "session_id", "detected_language" (human-readable name),
        and "answer" (markdown-stripped, in the user's language).
    """
    if session_id is None:
        session_id = str(uuid.uuid4())

    # Detect the query language; unsupported languages fall back to English.
    lang_label, prob = detect_language(user_query, top_k=1)[0]
    if lang_label not in SUPPORTED_LANGS:
        lang_label = "eng_Latn"

    # Normalize to English so intent detection and the expert model operate
    # on a single language.
    translated_query = (
        translate_text(user_query, src_lang=lang_label, tgt_lang="eng_Latn")
        if lang_label != "eng_Latn"
        else user_query
    )

    intent, extra = detect_intent(translated_query)

    # Load the per-session chat history, keeping only the most recent turns.
    history = memory_store.get_history(session_id) or []
    if len(history) > MAX_HISTORY_MESSAGES:
        history = history[-MAX_HISTORY_MESSAGES:]

    system_prompt = (
        "You are Aglimate, an AI assistant for Nigerian farmers developed by Ifeanyi Amogu Shalom. "
        "Your role is to provide helpful farming advice, agricultural information, and support for Nigerian farmers. "
        "\n\nIMPORTANT RULES:"
        "\n1. ONLY answer questions related to agriculture, farming, crops, livestock, weather, soil, and farming in Nigeria/Africa."
        "\n2. If asked who you are, say: 'I am Aglimate, an AI assistant developed by Ifeanyi Amogu Shalom to help Nigerian farmers with agricultural advice.'"
        "\n3. Do NOT provide information about unrelated topics (like travel, cities, non-agricultural topics)."
        "\n4. If a question is not related to farming/agriculture, politely redirect: 'I specialize in agricultural advice for Nigerian farmers. How can I help with your farming questions?'"
        "\n5. Use clear, simple language with occasional emojis."
        "\n6. Be concise and focus on practical, actionable information."
        "\n7. Do NOT include example conversations or unrelated content in your responses."
        "\n8. Answer ONLY the current question asked - do not add extra examples or unrelated information."
    )

    # Optional grounding context, chosen by intent: live weather for a named
    # state, or RAG snippets from the live/static vector stores.
    context_info = ""
    if intent == "weather" and extra:
        weather_text = get_weather(extra)
        context_info = f"\n\nCurrent weather information:\n{weather_text}"
    elif intent == "live_update":
        rag_context = retrieve_docs(translated_query, config.LIVE_VS_PATH)
        if rag_context:
            context_info = f"\n\nLatest agricultural updates:\n{rag_context}"
    elif intent == "low_confidence":
        rag_context = retrieve_docs(translated_query, config.STATIC_VS_PATH)
        if rag_context:
            context_info = f"\n\nRelevant information:\n{rag_context}"

    # Context is appended to the user turn (not the system prompt) so it is
    # stored in history alongside the question it grounded.
    user_message = translated_query + context_info
    history.append({"role": "user", "content": user_message})

    messages_for_qwen = build_messages_from_history(history, system_prompt)

    # Weather answers get a shorter generation budget than general advice.
    max_tokens = 256 if intent == "weather" else 400
    english_answer = run_qwen(messages_for_qwen, max_new_tokens=max_tokens)

    # Record the answer and persist the (re-trimmed) history for the session.
    history.append({"role": "assistant", "content": english_answer})
    if len(history) > MAX_HISTORY_MESSAGES:
        history = history[-MAX_HISTORY_MESSAGES:]
    memory_store.save_history(session_id, history)

    # Translate back into the user's language, then strip markdown so the
    # answer is plain text for delivery.
    final_answer = (
        translate_text(english_answer, src_lang="eng_Latn", tgt_lang=lang_label)
        if lang_label != "eng_Latn"
        else english_answer
    )
    final_answer = strip_markdown(final_answer)

    return {
        "session_id": session_id,
        "detected_language": SUPPORTED_LANGS.get(lang_label, "Unknown"),
        "answer": final_answer
    }