# image-forgery-assistant / agent_core.py
# Last commit 59ac395 ("update", verified) by Denisijcu.
"""
Agent Core Logic - Hugging Face Free Models
Uses free Inference API - no API keys needed!
"""
from huggingface_hub import InferenceClient
class HuggingFaceLLM:
    """
    Text-generation LLM backed by the Hugging Face Inference API.

    ``InferenceClient()`` is constructed without an explicit token.
    NOTE(review): huggingface_hub may still pick up a locally configured
    token; confirm whether anonymous access actually works for the models
    listed below.

    Failures never propagate out of :meth:`call`; they are reported as a
    returned string starting with ``"❌ Error:"`` instead.
    """

    def __init__(self, model_name: str = None):
        """Create the inference client and choose the model to try first.

        Args:
            model_name: Hub model id to try first. Defaults to the first
                entry of ``available_models``.
        """
        # Candidate models, in fallback order.
        self.available_models = [
            "meta-llama/Meta-Llama-3-8B-Instruct",
            "mistralai/Mistral-7B-Instruct-v0.2",
            "HuggingFaceH4/zephyr-7b-beta",
            "microsoft/Phi-3-mini-4k-instruct",
        ]
        self.model = model_name or self.available_models[0]
        self.client = InferenceClient()
        self.call_count = 0
        self.total_time = 0  # seconds accumulated over successful calls only
        print(f"✅ LLM initialized: {self.model}")

    def _generate(self, model: str, prompt: str, max_tokens: int, temperature: float) -> str:
        """Send one ``text_generation`` request for *prompt* to *model*."""
        return self.client.text_generation(
            prompt=prompt,
            model=model,
            max_new_tokens=max_tokens,
            temperature=temperature,
            return_full_text=False,
        )

    def call(self, prompt: str, max_tokens: int = 2000, temperature: float = 0.7) -> str:
        """Call the Hugging Face Inference API, falling back across models.

        The current model (``self.model``) is tried first. If it fails, every
        other entry of ``available_models`` is tried in list order, never
        re-trying the model that just failed. The first model that answers
        becomes the new ``self.model``, so later calls start from it.

        Args:
            prompt: Full text prompt to complete.
            max_tokens: Passed as ``max_new_tokens``.
            temperature: Sampling temperature.

        Returns:
            The generated text. If every model failed, a message of the form
            ``"❌ Error: All models unavailable. ...\\n\\nPlease try again in a
            moment."`` that includes the last underlying error.
        """
        import time

        self.call_count += 1
        start_time = time.time()
        # Primary first, then every other known model. The previous version
        # always used available_models[1:], which re-tried a failed primary
        # and never reached available_models[0] after a fallback.
        candidates = [self.model] + [m for m in self.available_models if m != self.model]
        last_error = None
        for index, model in enumerate(candidates):
            if index == 1:
                print("Primary model failed, trying fallback...")
            try:
                response = self._generate(model, prompt, max_tokens, temperature)
            except Exception as err:  # HTTP/availability errors: try the next model
                last_error = err
                continue
            self.model = model
            self.total_time += time.time() - start_time
            return response
        message = f"All models unavailable. Please try again later. (last error: {last_error})"
        return f"❌ Error: {message}\n\nPlease try again in a moment."

    def get_stats(self) -> dict:
        """Return call count, total/average latency (seconds) and current model."""
        avg_time = self.total_time / self.call_count if self.call_count > 0 else 0
        return {
            "total_calls": self.call_count,
            "total_time": round(self.total_time, 2),
            "avg_time": round(avg_time, 2),
            "current_model": self.model,
        }
class ImageForgeryAgent:
    """Main agent for Image Forgery Detection.

    Wraps an LLM backend and prepends the competition context to every
    user question. The backend must provide ``call(prompt, max_tokens=...,
    temperature=...) -> str`` and ``get_stats() -> dict``.
    """

    # HuggingFaceLLM.call reports failures by returning a string with this
    # prefix instead of raising, so such replies must not count as successes.
    _LLM_ERROR_PREFIX = "❌ Error"

    def __init__(self, llm_instance, config: "dict | None" = None):
        """Store the backend and the competition settings.

        Args:
            llm_instance: Backend used to generate answers.
            config: Optional overrides merged over the default settings
                (e.g. ``{"current_score": 0.31}``). Omit for the defaults.
        """
        self.llm = llm_instance
        self.config = {
            "user_name": "Denis",
            "competition": "Recod.ai/LUC Scientific Image Forgery Detection",
            "current_model": "EfficientNet-B4 UNet++",
            "current_score": 0.303,
            "target_score": 0.350,
            "version": "1.0.0",
        }
        if config:
            self.config.update(config)
        self.query_count = 0
        self.successful_queries = 0

    def _build_prompt(self, query: str) -> str:
        """Assemble the full text prompt (context + tools + question) for *query*."""
        cfg = self.config
        lines = [
            f"You are an expert AI assistant helping {cfg['user_name']} with the {cfg['competition']}.",
            "Current Setup:",
            f"- Model: {cfg['current_model']}",
            f"- Score: {cfg['current_score']}",
            f"- Target: {cfg['target_score']}",
            "Your Tools & Expertise:",
            "1. suggest_improvements - Provide specific model optimization tips",
            "2. compare_architectures - Compare different neural network architectures",
            "3. debug_rle_masks - Help debug RLE encoding issues",
            "4. create_strategy_plan - Create actionable competition strategies",
            "5. analyze_discussions - Provide insights on competition trends",
            "Instructions:",
            "- Be specific and actionable",
            "- Focus on practical Kaggle strategies",
            "- Provide code examples when relevant",
            "- Be concise but comprehensive",
            "- Prioritize high-impact suggestions",
            f"User Question: {query}",
            "Your Response:",
        ]
        return "\n".join(lines)

    def run(self, query: str) -> str:
        """Process *query* and return the (stripped) LLM response.

        Every call increments ``query_count``. ``successful_queries`` is
        incremented only when the backend produced a real answer, not an
        error reply. Replies shorter than 50 characters get a short
        prefix. Exceptions are turned into a
        ``"❌ Error processing query: ..."`` string.
        """
        self.query_count += 1
        try:
            prompt = self._build_prompt(query)
            response = self.llm.call(prompt, max_tokens=1500, temperature=0.7)
            response = response.strip()
        except Exception as e:
            return f"❌ Error processing query: {e}"
        # Backend-reported failure: pass it through, but don't count it.
        if response.startswith(self._LLM_ERROR_PREFIX):
            return response
        # If response is too short, add context
        if len(response) < 50:
            response = "I'm processing your query. " + response
        self.successful_queries += 1
        return response

    def get_stats(self) -> dict:
        """Return query counts, success rate (percent) and backend stats."""
        success_rate = (self.successful_queries / self.query_count * 100) if self.query_count > 0 else 0
        return {
            "total_queries": self.query_count,
            "successful_queries": self.successful_queries,
            "success_rate": round(success_rate, 1),
            "llm_stats": self.llm.get_stats(),
        }
def get_llm():
    """Build and return the default LLM backend.

    Constructs a ``HuggingFaceLLM`` with its default model; no token is
    passed explicitly.
    """
    backend = HuggingFaceLLM()
    return backend