"""AgriQA assistant: loads a fine-tuned agricultural QA model from Hugging Face
(base Qwen1.5-1.8B-Chat, optionally with a LoRA adapter) and generates
step-by-step answers to farmers' questions."""
import json
import logging
import os
import time
from typing import Any, Dict, Optional

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
| logger = logging.getLogger(__name__) | |
class AgriQAAssistant:
    """Agricultural question-answering assistant.

    Loads a causal-LM from Hugging Face (either a full fine-tuned model, or
    the Qwen1.5-1.8B-Chat base model plus a LoRA adapter as a fallback) and
    generates structured, step-by-step answers to farming questions.
    """

    # Shared system prompt. Previously this exact text was duplicated verbatim
    # in both branches of format_prompt(); hoisted here so the two branches
    # cannot drift apart. Content is byte-identical to the original.
    SYSTEM_PROMPT = (
        "You are AgriQA, an agricultural expert assistant. Your job is to answer "
        "farmers' questions with clear, practical, and accurate steps they can "
        "directly apply in the field.\n\n"
        "When answering:\n"
        "1. Start with a short, direct answer to the question.\n"
        "2. Provide a numbered step-by-step solution.\n"
        "3. Include specific details like measurements, quantities, time intervals, "
        "and names of products or tools.\n"
        "4. Mention any safety precautions if needed.\n"
        "5. End with an extra tip or follow-up advice.\n\n"
        "Format Example:\n"
        "Question: How to control aphid infestation in mustard crops?\n"
        "Answer:\n"
        "1. Inspect the crop daily to detect early signs of infestation.\n"
        "2. Spray Imidacloprid 17.8% SL at a rate of 0.3 ml per liter of water.\n"
        "3. Ensure thorough coverage, especially under the leaves.\n"
        "4. Remove surrounding weeds that may host aphids.\n"
        "5. Repeat spraying after 7 days if infestation continues.\n"
        "Note: Wear gloves and a mask during spraying.\n\n"
        "Always keep your language clear, concise, and easy to understand."
    )

    def __init__(self, model_path: str = "nada013/agriqa-assistant"):
        """Initialize the assistant and eagerly load model + tokenizer.

        Args:
            model_path: Hugging Face repo id (or local path) of the fine-tuned
                model / LoRA adapter.

        Raises:
            Exception: re-raised from load_model() if neither the direct model
                nor the base model can be loaded.
        """
        self.model_path = model_path
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        self.config = None
        self.load_model()

    def load_model(self):
        """Load tokenizer and model from Hugging Face.

        Strategy: try loading model_path as a full model first; on failure,
        fall back to base model + LoRA adapter; on adapter failure, keep the
        bare base model. Raises if even the base model cannot be loaded.
        """
        logger.info("Loading model from Hugging Face: %s", self.model_path)
        try:
            # Static configuration for the uploaded model.
            self.config = {
                'base_model': 'Qwen/Qwen1.5-1.8B-Chat',
                'generation_config': {
                    'max_new_tokens': 512,        # Increased for complete responses
                    'do_sample': True,
                    'temperature': 0.3,           # Lower temperature for consistent, structured responses
                    'top_p': 0.85,                # Slightly lower for more focused sampling
                    'top_k': 40,                  # Lower for more focused responses
                    'repetition_penalty': 1.2,    # Higher penalty to avoid repetition
                    'length_penalty': 1.1,        # Encourage slightly longer, detailed responses
                    'no_repeat_ngram_size': 3,    # Avoid repeating 3-grams
                },
            }

            # The fine-tuned checkpoint ships no tokenizer, so take it from
            # the base model.
            logger.info("Loading tokenizer from base model...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.config['base_model'],
                trust_remote_code=True,
            )
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # First attempt: load model_path as a complete model.
            try:
                logger.info("Attempting to load model directly from Hugging Face...")
                # NOTE: the deprecated use_flash_attention_2=False kwarg was
                # dropped; attn_implementation="eager" already disables FA2.
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_path,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    trust_remote_code=True,
                    attn_implementation="eager",
                )
                logger.info("Model loaded directly from Hugging Face successfully")
            except Exception as direct_load_error:
                logger.info("Direct loading failed: %s", direct_load_error)
                logger.info("Falling back to base model + LoRA adapter approach...")

                # Fallback: base model, then try to attach the LoRA adapter.
                logger.info("Loading base model...")
                base_model = AutoModelForCausalLM.from_pretrained(
                    self.config['base_model'],
                    torch_dtype=torch.float16,
                    device_map="auto",
                )
                try:
                    logger.info("Loading LoRA adapter from Hugging Face...")
                    self.model = PeftModel.from_pretrained(
                        base_model,
                        self.model_path,
                        torch_dtype=torch.float16,
                        device_map="auto",
                    )
                    logger.info("LoRA adapter loaded successfully")
                except Exception as lora_error:
                    # Deliberate best-effort: degrade to the bare base model
                    # rather than failing the whole load.
                    logger.warning("LoRA adapter loading failed: %s", lora_error)
                    logger.info("Using base model without LoRA adapter...")
                    self.model = base_model

            # Inference only — disable dropout etc.
            self.model.eval()

            logger.info("Model loaded successfully from Hugging Face")
            logger.info("Model type: %s", type(self.model).__name__)
            logger.info("Device: %s", self.device)

            # If a LoRA adapter is attached, log which modules it targets.
            if hasattr(self.model, 'peft_config'):
                logger.info("LoRA adapter configuration:")
                for adapter_name, config in self.model.peft_config.items():
                    logger.info("  - %s: %s", adapter_name, config.target_modules)
        except Exception as e:
            # logger.exception logs the message plus the traceback, replacing
            # the manual `import traceback; format_exc()` dance.
            logger.exception("Failed to load model: %s", e)
            logger.error("Model path: %s", self.model_path)
            logger.error("Base model: %s", self.config['base_model'])
            raise

    def format_prompt(self, question: str) -> str:
        """Format the question for the model using proper format.

        Prefers the tokenizer's chat template; falls back to the hand-built
        Qwen1.5 ChatML layout if the template is missing or fails.
        """
        if hasattr(self.tokenizer, 'apply_chat_template'):
            try:
                messages = [
                    {"role": "system", "content": self.SYSTEM_PROMPT},
                    {"role": "user", "content": question},
                ]
                return self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
            except Exception as e:
                logger.warning("Failed to use chat template: %s. Using fallback format.", e)

        # Fallback: Qwen1.5-Chat ChatML format, ending with an open assistant
        # turn so generation continues as the answer.
        return (
            f"<|im_start|>system\n{self.SYSTEM_PROMPT}<|im_end|>\n"
            f"<|im_start|>user\n{question}<|im_end|>\n"
            f"<|im_start|>assistant\n"
        )

    def generate_response(self, question: str, max_length: Optional[int] = None) -> Dict[str, Any]:
        """Generate an answer for `question`.

        Args:
            question: The farmer's question, plain text.
            max_length: Optional override for max_new_tokens.

        Returns:
            Dict with 'answer', 'response_time', 'model_info'; on failure the
            dict additionally carries 'error' and 'confidence' == 0.0 with an
            apologetic fallback answer (never raises).
        """
        start_time = time.time()
        model_info = {
            'model_name': 'agriqa-assistant',
            'model_source': 'Hugging Face',
            'model_path': self.model_path,
            'base_model': self.config['base_model'],
        }
        try:
            prompt = self.format_prompt(question)

            # Move inputs to wherever device_map="auto" actually placed the
            # model; self.device may disagree with the model's placement.
            target_device = getattr(self.model, "device", self.device)
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=2048,
            ).to(target_device)

            # Copy so a per-call max_length override never mutates the shared config.
            gen_config = self.config['generation_config'].copy()
            if max_length:
                gen_config['max_new_tokens'] = max_length

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    **gen_config,
                    pad_token_id=self.tokenizer.eos_token_id,
                )

            # Slice off the prompt tokens so only the newly generated answer
            # is decoded.
            response = self.tokenizer.decode(
                outputs[0][inputs['input_ids'].shape[1]:],
                skip_special_tokens=True,
            ).strip()

            return {
                'answer': response,
                'response_time': time.time() - start_time,
                'model_info': model_info,
            }
        except Exception as e:
            logger.error("Error generating response: %s", e)
            # Keep the same schema as the success path ('model_info' was
            # missing here before), plus error details for callers.
            return {
                'answer': "I apologize, but I encountered an error while processing your question. Please try again.",
                'confidence': 0.0,
                'response_time': time.time() - start_time,
                'model_info': model_info,
                'error': str(e),
            }

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model."""
        return {
            'model_name': 'agriqa-assistant',
            'model_source': 'Hugging Face',
            'model_path': self.model_path,
            'base_model': self.config['base_model'],
            'device': self.device,
            'generation_config': self.config['generation_config'],
        }