""" |
|
|
FarmEyes N-ATLaS Model Integration (Transformers Version) |
|
|
=========================================================== |
|
|
Uses the FULL N-ATLaS model via HuggingFace Transformers. |
|
|
|
|
|
Model: NCAIR1/N-ATLaS |
|
|
Size: ~16GB |
|
|
Base: Llama-3 8B |
|
|
|
|
|
Languages: English, Hausa, Yoruba, Igbo |
|
|
|
|
|
Powered by Awarri Technologies and the Federal Ministry of |
|
|
Communications, Innovation and Digital Economy. |
|
|
""" |
|
|
|
|
|
import os
import logging
from typing import Optional, Dict, List

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Runtime environment detection: HuggingFace Spaces sets SPACE_ID.
IS_HF_SPACES = os.environ.get("SPACE_ID") is not None

HAS_GPU = False
GPU_NAME = "None"
try:
    import torch

    HAS_GPU = torch.cuda.is_available()
    if HAS_GPU:
        GPU_NAME = torch.cuda.get_device_name(0)
        logger.info(f"🎮 GPU detected: {GPU_NAME}")
    else:
        logger.info("🖥️ No GPU detected - using CPU")
except ImportError:
    logger.warning("PyTorch not installed")

if IS_HF_SPACES:
    logger.info("🤗 Running on HuggingFace Spaces")
else:
    logger.info("🖥️ Running locally")

LANGUAGE_NAMES = {
    "en": "English",
    "ha": "Hausa",
    "yo": "Yoruba",
    "ig": "Igbo"
}

NATIVE_LANGUAGE_NAMES = {
    "en": "English",
    "ha": "Yaren Hausa",
    "yo": "Èdè Yorùbá",
    "ig": "Asụsụ Igbo"
}

# Inactive backends kept as stubs so existing imports keep working.

class HuggingFaceAPIClient:
    """Compatibility stub."""

    def __init__(self, api_token: Optional[str] = None):
        self.api_token = api_token
        self._is_available = False

    def is_available(self) -> bool:
        return False

    def generate(self, prompt: str, **kwargs) -> Optional[str]:
        return None


class LocalGGUFModel:
    """Compatibility stub."""

    def __init__(self, **kwargs):
        self._is_loaded = False

    def is_loaded(self) -> bool:
        return False

    def load_model(self) -> bool:
        return False

    def generate(self, prompt: str, **kwargs) -> Optional[str]:
        return None


class NATLaSTransformersModel:
    """
    N-ATLaS model using HuggingFace Transformers.

    Model: NCAIR1/N-ATLaS
    Base: Llama-3 8B
    Size: ~16GB
    """

    MODEL_ID = "NCAIR1/N-ATLaS"

    def __init__(
        self,
        model_id: str = MODEL_ID,
        load_on_init: bool = True
    ):
        self.model_id = model_id

        self._model = None
        self._tokenizer = None
        self._is_loaded = False

        logger.info(f"NATLaS Config: model={model_id}")

        if load_on_init:
            self.load_model()

    def load_model(self) -> bool:
        """Load the N-ATLaS model via transformers."""
        if self._is_loaded:
            return True

        try:
            import torch
            from transformers import AutoTokenizer, AutoModelForCausalLM

            logger.info("=" * 60)
            logger.info("📥 LOADING N-ATLaS MODEL")
            logger.info("=" * 60)
            logger.info(f" Model: {self.model_id}")
            logger.info(" Size: ~16GB")
            logger.info(" This may take a few minutes on first load...")
            logger.info("=" * 60)

            # Half precision on GPU halves memory use; CPU needs float32.
            if HAS_GPU:
                dtype = torch.float16
            else:
                dtype = torch.float32

            logger.info("Loading tokenizer...")
            self._tokenizer = AutoTokenizer.from_pretrained(
                self.model_id,
                trust_remote_code=True
            )

            logger.info("Loading model weights...")
            self._model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                torch_dtype=dtype,
                device_map="auto" if HAS_GPU else None,
                trust_remote_code=True
            )

            self._is_loaded = True

            logger.info("=" * 60)
            logger.info("✅ N-ATLaS MODEL LOADED SUCCESSFULLY!")
            if HAS_GPU:
                logger.info(f" Running on GPU: {GPU_NAME}")
            else:
                logger.info(" Running on CPU")
            logger.info("=" * 60)

            return True

        except Exception as e:
            logger.error(f"❌ Failed to load N-ATLaS model: {e}")
            import traceback
            traceback.print_exc()
            return False
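
    # Memory note: the fp16 load above needs roughly 16GB of GPU memory.
    # On smaller GPUs a 4-bit quantized load is a common alternative
    # (sketch only; assumes the optional bitsandbytes package is installed):
    #
    #     from transformers import BitsAndBytesConfig
    #     quant = BitsAndBytesConfig(
    #         load_in_4bit=True,
    #         bnb_4bit_compute_dtype=torch.float16,
    #     )
    #     model = AutoModelForCausalLM.from_pretrained(
    #         self.model_id,
    #         quantization_config=quant,
    #         device_map="auto",
    #         trust_remote_code=True,
    #     )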

    def unload_model(self):
        """Unload model to free memory."""
        if self._model is not None:
            del self._model
            self._model = None
        if self._tokenizer is not None:
            del self._tokenizer
            self._tokenizer = None
        self._is_loaded = False

        # Release cached CUDA blocks so the memory becomes visible to others.
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            pass

        logger.info("Model unloaded")

    @property
    def is_loaded(self) -> bool:
        return self._is_loaded

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        repetition_penalty: float = 1.12
    ) -> Optional[str]:
        """Generate text using the N-ATLaS model."""
        if not self._is_loaded:
            if not self.load_model():
                return None

        try:
            import torch

            if system_prompt is None:
                system_prompt = (
                    "You are a helpful AI assistant for African farmers. "
                    "You help with crop disease diagnosis, treatment advice, and agricultural questions. "
                    "Respond in the same language the user writes in."
                )

            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt}
            ]

            formatted_prompt = self._tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )

            inputs = self._tokenizer(
                formatted_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=4096
            )

            if HAS_GPU:
                inputs = {k: v.cuda() for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self._model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    repetition_penalty=repetition_penalty,
                    do_sample=True,
                    pad_token_id=self._tokenizer.eos_token_id
                )

            # Decode only the newly generated tokens, not the echoed prompt.
            input_length = inputs["input_ids"].shape[1]
            generated_tokens = outputs[0][input_length:]
            result = self._tokenizer.decode(generated_tokens, skip_special_tokens=True)

            if result:
                result = result.strip()
                logger.info(f"✅ Generation successful: {len(result)} chars")
                return result

            logger.warning("⚠️ Empty response generated")
            return None

        except Exception as e:
            logger.error(f"❌ Generation error: {e}")
            import traceback
            traceback.print_exc()
            return None
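
    # For reference: apply_chat_template() above should emit the standard
    # Llama-3 instruct layout (assuming N-ATLaS keeps the stock Llama-3
    # chat template, which is unverified here), roughly:
    #
    #     <|begin_of_text|><|start_header_id|>system<|end_header_id|>
    #     {system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>
    #     {prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>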

    def translate(self, text: str, target_language: str) -> Optional[str]:
        """Translate text to the target language."""
        if target_language == "en" or not text:
            return text

        lang_name = LANGUAGE_NAMES.get(target_language, target_language)

        prompt = f"Translate the following text to {lang_name}. Only provide the translation, nothing else.\n\nText: {text}"

        system_prompt = f"You are a professional translator. Translate text accurately to {lang_name}. Only output the translation."

        result = self.generate(
            prompt=prompt,
            system_prompt=system_prompt,
            max_new_tokens=len(text) * 4,  # rough budget proportional to input length
            temperature=0.3,
            repetition_penalty=1.1
        )

        if result:
            result = result.strip()

            # Strip boilerplate the model sometimes prepends despite the
            # "translation only" instruction.
            prefixes_to_remove = [
                f"{lang_name}:",
                f"{lang_name} translation:",
                "Translation:",
                "Here is the translation:",
                "The translation is:",
            ]
            for prefix in prefixes_to_remove:
                if result.lower().startswith(prefix.lower()):
                    result = result[len(prefix):].strip()
            return result

        return None

    def translate_batch(self, texts: List[str], target_language: str) -> List[str]:
        """Translate multiple texts by translating each one individually."""
        if target_language == "en" or not texts:
            return texts

        results = []
        for text in texts:
            if text and text.strip():
                translated = self.translate(text, target_language)
                results.append(translated if translated else text)
            else:
                results.append(text)

        return results

    def chat_response(self, message: str, context: Dict, language: str = "en") -> Optional[str]:
        """Generate a chat response grounded in the diagnosis context."""
        crop = context.get("crop_type", "crop").capitalize()
        disease = context.get("disease_name", "unknown disease")
        severity = context.get("severity_level", "unknown")
        confidence = context.get("confidence", 0)
        # Accept confidence as either a 0-1 fraction or a 0-100 percentage.
        if confidence <= 1:
            confidence = int(confidence * 100)

        lang_instructions = {
            "en": "Respond in English.",
            "ha": "Respond in Hausa language (Yaren Hausa).",
            "yo": "Respond in Yoruba language (Èdè Yorùbá).",
            "ig": "Respond in Igbo language (Asụsụ Igbo)."
        }
        lang_instruction = lang_instructions.get(language, "Respond in English.")

        system_prompt = (
            "You are FarmEyes, an AI assistant helping African farmers with crop diseases. "
            "You provide practical, helpful advice about crop diseases and farming. "
            f"{lang_instruction}"
        )

        prompt = (
            f"Current diagnosis information:\n"
            f"- Crop: {crop}\n"
            f"- Disease: {disease}\n"
            f"- Severity: {severity}\n"
            f"- Confidence: {confidence}%\n\n"
            f"Farmer's question: {message}\n\n"
            f"Provide a helpful, practical response about this disease or related farming advice. "
            f"Keep your response concise (2-3 paragraphs maximum)."
        )

        return self.generate(
            prompt=prompt,
            system_prompt=system_prompt,
            max_new_tokens=600,
            temperature=0.7
        )
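

# Example usage (sketch; loading needs roughly 16GB of free GPU or system
# memory, so this is shown as a comment rather than run at import time):
#
#     model = NATLaSTransformersModel(load_on_init=False)
#     if model.load_model():
#         reply = model.generate("How do I treat tomato blight?")
#         hausa = model.translate("Your plant is healthy", "ha")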


class NATLaSModel:
    """
    Main N-ATLaS model interface.

    Uses the full N-ATLaS model via transformers.
    """

    def __init__(self, auto_load: bool = False):
        """Initialize the N-ATLaS model wrapper."""
        self.hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")

        if self.hf_token:
            logger.info("✅ HuggingFace token found")
            try:
                from huggingface_hub import login
                login(token=self.hf_token, add_to_git_credential=False)
            except Exception as e:
                logger.warning(f"Could not set HF token: {e}")
        else:
            logger.warning("⚠️ No HF_TOKEN found - model access may fail")

        self.model = NATLaSTransformersModel(load_on_init=auto_load)

        # In-memory translation cache: "{lang}:{hash(text)}" -> translation.
        self._cache: Dict[str, str] = {}

        logger.info("=" * 60)
        logger.info("✅ NATLaSModel initialized (Full model via Transformers)")
        logger.info(" Model: NCAIR1/N-ATLaS (~16GB)")
        logger.info(f" Model loaded: {'Yes' if self.model.is_loaded else 'No'}")
        logger.info(f" GPU available: {'Yes - ' + GPU_NAME if HAS_GPU else 'No'}")
        logger.info(f" HF Token: {'Yes' if self.hf_token else 'No'}")
        logger.info(f" Running on: {'HuggingFace Spaces' if IS_HF_SPACES else 'Local'}")
        logger.info("=" * 60)

    @property
    def is_loaded(self) -> bool:
        return self.model.is_loaded

    @property
    def is_model_loaded(self) -> bool:
        """Alias for is_loaded, kept for compatibility."""
        return self.model.is_loaded

    def load_model(self) -> bool:
        return self.model.load_model()

    def ensure_model_loaded(self) -> bool:
        """Load the model if it is not loaded already."""
        if not self.is_loaded:
            return self.load_model()
        return True

    def translate(self, text: str, target_language: str, use_cache: bool = True) -> str:
        """Translate text to the target language, caching results."""
        if target_language == "en" or not text or not text.strip():
            return text

        cache_key = f"{target_language}:{hash(text)}"
        if use_cache and cache_key in self._cache:
            logger.info("📦 Using cached translation")
            return self._cache[cache_key]

        logger.info(f"🌍 Translating to {LANGUAGE_NAMES.get(target_language, target_language)}...")
        result = self.model.translate(text, target_language)

        if result and result != text:
            if use_cache:
                self._cache[cache_key] = result
                # Simple FIFO eviction: once the cache tops 500 entries, drop
                # the 100 oldest (dicts preserve insertion order).
                if len(self._cache) > 500:
                    keys = list(self._cache.keys())[:100]
                    for k in keys:
                        del self._cache[k]
            logger.info("✅ Translation successful")
            return result

        logger.warning("⚠️ Translation failed - returning original")
        return text

    def translate_batch(self, texts: List[str], target_language: str, use_cache: bool = True) -> List[str]:
        """Translate multiple texts individually, with caching."""
        if target_language == "en" or not texts:
            return texts

        results = []
        for text in texts:
            if not text or not text.strip():
                results.append(text)
            else:
                translated = self.translate(text, target_language, use_cache)
                results.append(translated)

        return results

    def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7, **kwargs) -> str:
        """Generate text."""
        result = self.model.generate(
            prompt=prompt,
            max_new_tokens=max_tokens,
            temperature=temperature
        )
        return result if result else ""

    def chat_response(self, message: str, context: Dict, language: str = "en") -> str:
        """Generate a chat response with diagnosis context."""
        result = self.model.chat_response(message, context, language)
        if result:
            return result
        return "I'm sorry, I couldn't generate a response. Please try again."

    def load_local_model(self) -> bool:
        """Compatibility alias for load_model()."""
        return self.load_model()

    def unload_local_model(self):
        """Unload the underlying model."""
        self.model.unload_model()

    def get_status(self) -> Dict:
        return {
            "model_loaded": self.model.is_loaded,
            "model_id": self.model.model_id,
            "model_type": "Full (Transformers)",
            "model_size": "~16GB",
            "gpu_available": HAS_GPU,
            "gpu_name": GPU_NAME if HAS_GPU else None,
            "hf_token_set": bool(self.hf_token),
            "cache_size": len(self._cache),
            "running_on": "HuggingFace Spaces" if IS_HF_SPACES else "Local"
        }

    def get_cache_stats(self) -> Dict:
        """Get translation-cache statistics."""
        return {
            "size": len(self._cache),
            "max_size": 500
        }

    def clear_cache(self):
        self._cache.clear()
        logger.info("Translation cache cleared")

_model_instance: Optional[NATLaSModel] = None


def get_natlas_model(auto_load: bool = False) -> NATLaSModel:
    """Get the singleton NATLaSModel instance, creating it on first call."""
    global _model_instance

    if _model_instance is None:
        _model_instance = NATLaSModel(auto_load=auto_load)

    return _model_instance


def unload_natlas_model():
    """Unload the model and drop the singleton."""
    global _model_instance
    if _model_instance is not None:
        _model_instance.unload_local_model()
        _model_instance = None
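

# Singleton usage sketch: callers share one instance (and one copy of the
# ~16GB weights). A typical request handler might look like:
#
#     model = get_natlas_model()      # cheap after the first call
#     model.ensure_model_loaded()     # lazy-loads the weights on demand
#     answer = model.chat_response(question, diagnosis_context, "yo")
#
# where question and diagnosis_context are supplied by the caller.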


def translate_text(text: str, target_language: str) -> str:
    """Module-level convenience wrapper around NATLaSModel.translate()."""
    return get_natlas_model().translate(text, target_language)


def translate_batch(texts: List[str], target_language: str) -> List[str]:
    """Module-level convenience wrapper around NATLaSModel.translate_batch()."""
    return get_natlas_model().translate_batch(texts, target_language)


def generate_text(prompt: str, max_tokens: int = 512) -> str:
    """Module-level convenience wrapper around NATLaSModel.generate()."""
    return get_natlas_model().generate(prompt, max_tokens=max_tokens)
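

# Example (sketch; "natlas_model" is an assumed import path - adjust to
# wherever this module lives in the FarmEyes package):
#
#     from natlas_model import translate_text, generate_text
#     print(translate_text("Your plant is healthy", "ig"))
#     print(generate_text("List three common cassava diseases."))
#
# The first call pays the full model-load cost, since the singleton and its
# weights are created lazily.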


if __name__ == "__main__":
    print("=" * 60)
    print("N-ATLaS Model Test (Transformers)")
    print("=" * 60)

    model = get_natlas_model(auto_load=True)

    print("\nStatus:")
    for key, value in model.get_status().items():
        print(f" {key}: {value}")

    if model.is_loaded:
        print("\n--- Testing Translation ---")
        test_text = "Your plant is healthy"
        result = model.translate(test_text, "ha")
        print(f"English: {test_text}")
        print(f"Hausa: {result}")

    print("\n" + "=" * 60)