# othdu's picture
# Upload 5 files
# 95d2da1 verified
import os
import json
import torch
import logging
from typing import Dict, Any, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import time
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class AgriQAAssistant:
    """Agricultural question-answering wrapper around a fine-tuned chat model.

    On construction the tokenizer of the base model and the model weights
    referenced by ``model_path`` are loaded; afterwards ``generate_response``
    answers a single question per call.

    Attributes:
        model_path: Hugging Face repository id (or path) of the fine-tuned model.
        device: ``"cuda"`` when CUDA is available, otherwise ``"cpu"``.
        model: The loaded model, or ``None`` before loading.
        tokenizer: The loaded tokenizer, or ``None`` before loading.
        config: Dict with ``'base_model'`` and ``'generation_config'`` keys,
            or ``None`` before ``load_model`` has run.
    """

    # Single source of truth for the system prompt; used both by the
    # chat-template path and by the hand-built fallback prompt.
    SYSTEM_PROMPT = (
        "You are AgriQA, an agricultural expert assistant. Your job is to "
        "answer farmers' questions with clear, practical, and accurate steps "
        "they can directly apply in the field.\n\n"
        "When answering:\n"
        "1. Start with a short, direct answer to the question.\n"
        "2. Provide a numbered step-by-step solution.\n"
        "3. Include specific details like measurements, quantities, time "
        "intervals, and names of products or tools.\n"
        "4. Mention any safety precautions if needed.\n"
        "5. End with an extra tip or follow-up advice.\n\n"
        "Format Example:\n"
        "Question: How to control aphid infestation in mustard crops?\n"
        "Answer:\n"
        "1. Inspect the crop daily to detect early signs of infestation.\n"
        "2. Spray Imidacloprid 17.8% SL at a rate of 0.3 ml per liter of water.\n"
        "3. Ensure thorough coverage, especially under the leaves.\n"
        "4. Remove surrounding weeds that may host aphids.\n"
        "5. Repeat spraying after 7 days if infestation continues.\n"
        "Note: Wear gloves and a mask during spraying.\n\n"
        "Always keep your language clear, concise, and easy to understand."
    )

    def __init__(self, model_path: str = "nada013/agriqa-assistant"):
        """Store settings and immediately load tokenizer and model.

        Args:
            model_path: Hugging Face repository id (or path) of the model.

        Raises:
            Exception: Whatever ``load_model`` re-raises when loading fails.
        """
        self.model_path = model_path
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        self.config = None
        self.load_model()

    def load_model(self) -> None:
        """Load the tokenizer and model and switch the model to eval mode.

        Loading strategy (first success wins):
          1. load ``model_path`` directly as a causal LM;
          2. load the base model and attach ``model_path`` as a LoRA adapter;
          3. use the bare base model if the adapter cannot be attached.

        Raises:
            Exception: Any error from tokenizer loading, base-model loading or
                post-load setup is logged and re-raised unchanged.
        """
        logger.info("Loading model from Hugging Face: %s", self.model_path)

        # Configuration for the uploaded model.
        self.config = {
            'base_model': 'Qwen/Qwen1.5-1.8B-Chat',
            'generation_config': {
                'max_new_tokens': 512,  # Room for complete responses
                'do_sample': True,
                'temperature': 0.3,  # Low temperature for consistent, structured answers
                'top_p': 0.85,  # Slightly narrowed nucleus sampling
                'top_k': 40,  # Narrowed candidate pool
                'repetition_penalty': 1.2,  # Discourage repetition
                # NOTE(review): length_penalty is documented as a beam-search
                # setting; with single-beam sampling it is presumably inert.
                # Kept so the reported configuration stays unchanged.
                'length_penalty': 1.1,
                'no_repeat_ngram_size': 3,  # Avoid repeating 3-grams
            },
        }

        try:
            self._load_tokenizer()
            self.model = self._load_weights()

            # Set to evaluation mode.
            self.model.eval()
            self._log_model_summary()
        except Exception as e:
            # logger.exception records the traceback along with the message.
            logger.exception("Failed to load model: %s", e)
            logger.error("Model path: %s", self.model_path)
            logger.error("Base model: %s", self.config['base_model'])
            raise

    def _load_tokenizer(self) -> None:
        """Load the base model's tokenizer and make sure a pad token exists."""
        logger.info("Loading tokenizer from base model...")
        # NOTE: trust_remote_code=True allows the Hub repository to execute its
        # own Python code at load time; only use with repositories you trust.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config['base_model'],
            trust_remote_code=True,
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _load_weights(self):
        """Return a loaded model, trying direct load, then LoRA, then base only."""
        # Half precision is only requested on CUDA: float16 on CPU is slow and
        # some ops are not implemented for it in older PyTorch builds.
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        placement = {'torch_dtype': dtype, 'device_map': "auto"}

        # Try to load the model directly from Hugging Face first.
        try:
            logger.info("Attempting to load model directly from Hugging Face...")
            model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                trust_remote_code=True,
                attn_implementation="eager",
                **placement,
            )
        except Exception as direct_load_error:
            logger.info("Direct loading failed: %s", direct_load_error)
            logger.info("Falling back to base model + LoRA adapter approach...")
        else:
            logger.info("Model loaded directly from Hugging Face successfully")
            return model

        # Load base model first.
        logger.info("Loading base model...")
        base_model = AutoModelForCausalLM.from_pretrained(
            self.config['base_model'],
            trust_remote_code=True,
            attn_implementation="eager",
            **placement,
        )

        # Try to attach the LoRA adapter; deliberately best-effort.
        try:
            logger.info("Loading LoRA adapter from Hugging Face...")
            model = PeftModel.from_pretrained(
                base_model,
                self.model_path,
                **placement,
            )
        except Exception as lora_error:
            logger.warning("LoRA adapter loading failed: %s", lora_error)
            logger.info("Using base model without LoRA adapter...")
            return base_model

        logger.info("LoRA adapter loaded successfully")
        return model

    def _log_model_summary(self) -> None:
        """Log the loaded model's type, the device and any adapter settings."""
        logger.info("Model loaded successfully from Hugging Face")
        logger.info("Model type: %s", type(self.model).__name__)
        logger.info("Device: %s", self.device)

        # Only models carrying adapter configuration expose ``peft_config``.
        if hasattr(self.model, 'peft_config'):
            logger.info("LoRA adapter configuration:")
            for adapter_name, adapter_config in self.model.peft_config.items():
                logger.info("  - %s: %s", adapter_name, adapter_config.target_modules)

    def format_prompt(self, question: str) -> str:
        """Format the question for the model using proper format.

        The tokenizer's chat template is used when the tokenizer offers one;
        if it is missing or raises, a hand-built ``<|im_start|>`` prompt for
        Qwen1.5-Chat is returned instead.

        Args:
            question: The user's question, inserted verbatim.

        Returns:
            The full prompt string ending with the assistant turn opener.
        """
        # Use the tokenizer's chat template if available.
        if hasattr(self.tokenizer, 'apply_chat_template'):
            messages = [
                {"role": "system", "content": self.SYSTEM_PROMPT},
                {"role": "user", "content": question},
            ]
            try:
                return self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
            except Exception as e:
                logger.warning(
                    "Failed to use chat template: %s. Using fallback format.", e
                )

        # Fallback format for Qwen1.5-Chat.
        return (
            f"<|im_start|>system\n{self.SYSTEM_PROMPT}<|im_end|>\n"
            f"<|im_start|>user\n{question}<|im_end|>\n"
            f"<|im_start|>assistant\n"
        )

    def generate_response(self, question: str, max_length: Optional[int] = None) -> Dict[str, Any]:
        """Generate an answer for ``question``.

        Args:
            question: The user's question.
            max_length: Optional override for ``max_new_tokens``. ``None`` and
                non-positive values leave the configured default in place.

        Returns:
            On success a dict with ``'answer'``, ``'response_time'`` (seconds)
            and ``'model_info'``. On failure a dict with an apology in
            ``'answer'``, ``'confidence'`` of 0.0, ``'response_time'`` and the
            error text in ``'error'``; no exception is propagated.
        """
        start_time = time.time()

        try:
            # Format the prompt.
            prompt = self.format_prompt(question)

            # Tokenize input.
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=2048,
            ).to(self.device)

            # Work on a copy so a per-call override never alters the defaults.
            gen_config = dict(self.config['generation_config'])
            if max_length is not None and max_length > 0:
                gen_config['max_new_tokens'] = max_length

            # Generate response.
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    **gen_config,
                    pad_token_id=self.tokenizer.eos_token_id,
                )

            # Decode only the tokens produced after the prompt.
            prompt_token_count = inputs['input_ids'].shape[1]
            response = self.tokenizer.decode(
                outputs[0][prompt_token_count:],
                skip_special_tokens=True,
            ).strip()

            return {
                'answer': response,
                'response_time': time.time() - start_time,
                'model_info': {
                    'model_name': 'agriqa-assistant',
                    'model_source': 'Hugging Face',
                    'model_path': self.model_path,
                    'base_model': self.config['base_model'],
                },
            }
        except Exception as e:
            logger.exception("Error generating response: %s", e)
            return {
                'answer': "I apologize, but I encountered an error while processing your question. Please try again.",
                'confidence': 0.0,
                'response_time': time.time() - start_time,
                'error': str(e),
            }

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model."""
        return {
            'model_name': 'agriqa-assistant',
            'model_source': 'Hugging Face',
            'model_path': self.model_path,
            'base_model': self.config['base_model'],
            'device': self.device,
            'generation_config': self.config['generation_config'],
        }