import json
import logging
import os
import time
import traceback
from typing import Any, Dict, Optional

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)


class AgriQAAssistant:
    """Agricultural question-answering wrapper around a causal language model.

    The model is loaded eagerly in ``__init__``: first directly from
    ``model_path``; if that fails, the base model named in the configuration
    is loaded and ``model_path`` is applied on top of it as a LoRA adapter;
    if the adapter cannot be loaded either, the plain base model is used.
    """

    # Single source of truth for the system prompt. It is used both by the
    # tokenizer chat-template path and by the hand-written fallback format,
    # so the two can never drift apart.
    _SYSTEM_PROMPT = (
        "You are AgriQA, an agricultural expert assistant. Your job is to "
        "answer farmers' questions with clear, practical, and accurate steps "
        "they can directly apply in the field.\n\n"
        "When answering:\n"
        "1. Start with a short, direct answer to the question.\n"
        "2. Provide a numbered step-by-step solution.\n"
        "3. Include specific details like measurements, quantities, time "
        "intervals, and names of products or tools.\n"
        "4. Mention any safety precautions if needed.\n"
        "5. End with an extra tip or follow-up advice.\n\n"
        "Format Example:\n"
        "Question: How to control aphid infestation in mustard crops?\n"
        "Answer:\n"
        "1. Inspect the crop daily to detect early signs of infestation.\n"
        "2. Spray Imidacloprid 17.8% SL at a rate of 0.3 ml per liter of water.\n"
        "3. Ensure thorough coverage, especially under the leaves.\n"
        "4. Remove surrounding weeds that may host aphids.\n"
        "5. Repeat spraying after 7 days if infestation continues.\n"
        "Note: Wear gloves and a mask during spraying.\n\n"
        "Always keep your language clear, concise, and easy to understand."
    )

    def __init__(self, model_path: str = "nada013/agriqa-assistant"):
        """Store settings and immediately load the tokenizer and model.

        Args:
            model_path: Hugging Face repository id (or path) of the
                fine-tuned model or LoRA adapter.

        Raises:
            Exception: Whatever ``load_model`` re-raises when loading fails.
        """
        self.model_path = model_path
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        self.config = None
        self.load_model()

    def load_model(self):
        """Build the configuration, then load the tokenizer and the model.

        Populates ``self.config``, ``self.tokenizer`` and ``self.model`` and
        switches the model to evaluation mode.

        Raises:
            Exception: Any error raised while loading the tokenizer or the
                base model is logged (with traceback) and re-raised.
        """
        logger.info("Loading model from Hugging Face: %s", self.model_path)

        try:
            # Configuration for the uploaded model
            self.config = {
                'base_model': 'Qwen/Qwen1.5-1.8B-Chat',
                'generation_config': {
                    'max_new_tokens': 512,  # Increased for complete responses
                    'do_sample': True,
                    'temperature': 0.3,  # Lower temperature for more consistent, structured responses
                    'top_p': 0.85,  # Slightly lower for more focused sampling
                    'top_k': 40,  # Lower for more focused responses
                    'repetition_penalty': 1.2,  # Higher penalty to avoid repetition
                    # NOTE(review): the transformers docs describe
                    # length_penalty as applying to beam-based generation;
                    # no num_beams is configured here, so this value
                    # presumably has no effect on sampling -- confirm.
                    'length_penalty': 1.1,
                    'no_repeat_ngram_size': 3  # Avoid repeating 3-grams
                }
            }

            # Load tokenizer from base model
            logger.info("Loading tokenizer from base model...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.config['base_model'],
                trust_remote_code=True
            )

            # Generation needs a pad token; reuse EOS when none is defined.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model = self._load_model_weights()

            # Set to evaluation mode
            self.model.eval()

            # Log model information
            logger.info("Model loaded successfully from Hugging Face")
            logger.info("Model type: %s", type(self.model).__name__)
            logger.info("Device: %s", self.device)

            # Check if it's a PeftModel
            if hasattr(self.model, 'peft_config'):
                logger.info("LoRA adapter configuration:")
                for adapter_name, adapter_cfg in self.model.peft_config.items():
                    logger.info(" - %s: %s", adapter_name, adapter_cfg.target_modules)

        except Exception as e:
            logger.error("Failed to load model: %s", e)
            logger.error("Model path: %s", self.model_path)
            logger.error("Base model: %s", self.config['base_model'])
            logger.error("Traceback: %s", traceback.format_exc())
            raise

    def _load_model_weights(self):
        """Return the model, trying direct load, then base + LoRA, then base only."""
        # Try to load the model directly from Hugging Face first
        try:
            logger.info("Attempting to load model directly from Hugging Face...")
            # NOTE(review): use_flash_attention_2 is passed alongside
            # attn_implementation exactly as before; whether the installed
            # transformers version still accepts it should be confirmed.
            model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
                attn_implementation="eager",
                use_flash_attention_2=False
            )
            logger.info("Model loaded directly from Hugging Face successfully")
            return model
        except Exception as direct_load_error:
            # Deliberate best-effort: any failure here triggers the fallback.
            logger.info("Direct loading failed: %s", direct_load_error)
            logger.info("Falling back to base model + LoRA adapter approach...")

        # Load base model first
        logger.info("Loading base model...")
        base_model = AutoModelForCausalLM.from_pretrained(
            self.config['base_model'],
            torch_dtype=torch.float16,
            device_map="auto"
        )

        # Try to load the LoRA adapter
        try:
            logger.info("Loading LoRA adapter from Hugging Face...")
            model = PeftModel.from_pretrained(
                base_model,
                self.model_path,
                torch_dtype=torch.float16,
                device_map="auto"
            )
            logger.info("LoRA adapter loaded successfully")
            return model
        except Exception as lora_error:
            # Deliberate best-effort: serve the plain base model instead.
            logger.warning("LoRA adapter loading failed: %s", lora_error)
            logger.info("Using base model without LoRA adapter...")
            return base_model

    def format_prompt(self, question: str) -> str:
        """Format the question for the model using proper format.

        Uses the tokenizer's chat template when the tokenizer provides
        ``apply_chat_template``; if that is missing or raises, falls back to
        a hand-written ``<|im_start|>`` / ``<|im_end|>`` prompt.

        Args:
            question: The user's question, inserted verbatim.

        Returns:
            The full prompt string, ending with the assistant turn opener.
        """
        # Use the tokenizer's chat template if available
        if hasattr(self.tokenizer, 'apply_chat_template'):
            try:
                messages = [
                    {"role": "system", "content": self._SYSTEM_PROMPT},
                    {"role": "user", "content": question}
                ]
                return self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
            except Exception as e:
                logger.warning(f"Failed to use chat template: {e}. Using fallback format.")

        # Fallback format for Qwen1.5-Chat
        system_prompt = self._SYSTEM_PROMPT
        return f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"

    def generate_response(self, question: str, max_length: Optional[int] = None) -> Dict[str, Any]:
        """Generate an answer for ``question``.

        Args:
            question: The user's question.
            max_length: When truthy, overrides the configured
                ``max_new_tokens`` for this call only.

        Returns:
            On success, a dict with ``answer``, ``response_time`` (seconds)
            and ``model_info``. On failure, a dict with an apology in
            ``answer``, ``confidence`` of ``0.0``, ``response_time``,
            ``error`` (the exception text) and ``model_info``. Exceptions
            are logged and never propagated.
        """
        start_time = time.time()

        try:
            # Format the prompt
            prompt = self.format_prompt(question)

            # Tokenize input
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=2048
            ).to(self.device)

            # Copy so a per-call override never leaks into the shared config.
            gen_config = self.config['generation_config'].copy()
            if max_length:
                gen_config['max_new_tokens'] = max_length

            # Generate response
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    **gen_config,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            # Decode only the tokens produced after the prompt.
            prompt_token_count = inputs['input_ids'].shape[1]
            response = self.tokenizer.decode(
                outputs[0][prompt_token_count:],
                skip_special_tokens=True
            ).strip()

            # Calculate response time
            response_time = time.time() - start_time

            return {
                'answer': response,
                'response_time': response_time,
                'model_info': self._model_identity()
            }

        except Exception as e:
            logger.error("Error generating response: %s", e)
            return {
                'answer': "I apologize, but I encountered an error while processing your question. Please try again.",
                'confidence': 0.0,
                'response_time': time.time() - start_time,
                'error': str(e),
                # Same key as the success payload so callers can read
                # model_info without checking for an error first.
                'model_info': self._model_identity()
            }

    def _model_identity(self) -> Dict[str, Any]:
        """Return the name/source/path/base-model fields shared by all payloads."""
        return {
            'model_name': 'agriqa-assistant',
            'model_source': 'Hugging Face',
            'model_path': self.model_path,
            'base_model': self.config['base_model']
        }

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model.

        Returns:
            The identity fields plus ``device`` and the live
            ``generation_config`` dict (not a copy).
        """
        info = self._model_identity()
        info['device'] = self.device
        info['generation_config'] = self.config['generation_config']
        return info