# Author: Znilsson
# Commit c79bd1e: "Update app.py" (verified)
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import requests
from bs4 import BeautifulSoup
import json
from datetime import datetime
import sqlite3
import threading
import time
import re
import os
from urllib.parse import quote_plus
import random
# Populated by load_trained_model(); None means "not loaded yet".
model = tokenizer = None
# SQLite file that ContinuousLearningSystem reads and writes.
learning_database = "continuous_learning.db"
class WebSearchSystem:
    """Scrape search-engine result pages for survival-related text snippets.

    Each result is a dict with keys ``content`` (snippet text, at most 300
    characters), ``relevance`` (float in [0, 1]), ``source`` (always
    ``'web_search'``) and ``timestamp`` (local time, ISO-8601).
    """

    # A snippet must mention at least one of these to be kept.
    SNIPPET_KEYWORDS = (
        'survival', 'emergency', 'wilderness', 'bushcraft', 'first aid',
        'shelter', 'fire', 'water', 'food', 'rescue',
    )
    # Each of these found in a snippet raises its relevance score.
    RELEVANCE_KEYWORDS = ('survival', 'emergency', 'wilderness', 'safety', 'rescue')
    # Elements whose CSS class matches this pattern are treated as snippets.
    SNIPPET_CLASS_RE = re.compile(r'(snippet|description|summary|result)')

    def __init__(self):
        # Base URLs; the URL-encoded query is appended directly to each.
        self.search_engines = [
            "https://www.google.com/search?q=",
            "https://duckduckgo.com/html/?q=",
            "https://search.yahoo.com/search?p=",
        ]
        # Search operators restricting results to survival-focused sites.
        self.survival_sites = [
            "site:survivalblog.com",
            "site:offgridweb.com",
            "site:outdoorlife.com",
            "site:backwoodsman.com",
            "site:bushcraftusa.com",
            "site:prepared.com",
        ]

    def search_survival_knowledge(self, query):
        """Search the web for survival information related to *query*.

        Two survival-flavoured variants of the query are each sent to a
        randomly chosen engine and restricted to a randomly chosen survival
        site. A failure for one variant is printed and skipped.

        Returns:
            Up to three result dicts, most relevant first; ``[]`` on failure.
        """
        try:
            search_results = []
            survival_queries = [
                f"survival {query} techniques",
                f"wilderness {query} emergency",
                f"bushcraft {query} methods",
                f"prepper {query} guide",
            ]
            # Only the first two variants are sent, to limit request volume.
            for search_query in survival_queries[:2]:
                try:
                    search_url = random.choice(self.search_engines)
                    full_query = f"{search_query} {random.choice(self.survival_sites)}"
                    response = requests.get(
                        f"{search_url}{quote_plus(full_query)}",
                        headers={
                            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
                        },
                        timeout=10,
                    )
                    if response.status_code == 200:
                        search_results.extend(self.extract_survival_info(response.text, query))
                except Exception as e:
                    print(f"Search engine error: {e}")
            return self.filter_and_rank_results(search_results, query)
        except Exception as e:
            print(f"Web search error: {e}")
            return []

    def extract_survival_info(self, html_content, query):
        """Pull survival-relevant snippets that mention *query* out of a result page.

        Only the first five candidate elements are inspected. A candidate is
        kept when it is longer than 50 characters, mentions a survival keyword,
        and contains the whole query (case-insensitive).

        Returns:
            A list of result dicts; ``[]`` if parsing fails.
        """
        try:
            soup = BeautifulSoup(html_content, 'html.parser')
            query_lower = query.lower()
            results = []
            snippets = soup.find_all(['div', 'p', 'span'], class_=self.SNIPPET_CLASS_RE)
            for snippet in snippets[:5]:
                text = snippet.get_text(strip=True)
                text_lower = text.lower()
                if (len(text) > 50
                        and any(keyword in text_lower for keyword in self.SNIPPET_KEYWORDS)
                        and query_lower in text_lower):
                    results.append({
                        'content': text[:300],  # keep stored snippets short
                        'relevance': self.calculate_relevance(text, query),
                        'source': 'web_search',
                        'timestamp': datetime.now().isoformat(),
                    })
            return results
        except Exception as e:
            print(f"Content extraction error: {e}")
            return []

    def calculate_relevance(self, text, query):
        """Score how relevant *text* is to *query*, from 0.0 to 1.0.

        70% of the score is the fraction of query words found in the text.
        30% is the fraction of ``RELEVANCE_KEYWORDS`` found. Matching is a
        case-insensitive substring test. An empty query contributes 0 to the
        query part instead of raising ZeroDivisionError.
        """
        query_words = query.lower().split()
        text_lower = text.lower()
        if query_words:
            matches = sum(1 for word in query_words if word in text_lower)
            query_part = (matches / len(query_words)) * 0.7
        else:
            query_part = 0.0
        survival_matches = sum(1 for keyword in self.RELEVANCE_KEYWORDS if keyword in text_lower)
        survival_part = (survival_matches / len(self.RELEVANCE_KEYWORDS)) * 0.3
        return min(query_part + survival_part, 1.0)

    def filter_and_rank_results(self, results, query, min_relevance=0.3, limit=3):
        """Drop duplicates and weak results, then return the best ones.

        Two results whose first 100 characters are identical are duplicates;
        the first one seen is kept. Results with relevance <= *min_relevance*
        are discarded. *query* is unused but kept for interface compatibility.

        Returns:
            At most *limit* results sorted by descending relevance (stable).
        """
        unique_results = []
        seen_prefixes = set()
        for result in results:
            # Compare the prefix itself rather than its hash() to avoid collisions.
            prefix = result['content'][:100]
            if prefix in seen_prefixes or result['relevance'] <= min_relevance:
                continue
            seen_prefixes.add(prefix)
            unique_results.append(result)
        unique_results.sort(key=lambda r: r['relevance'], reverse=True)
        return unique_results[:limit]
class ContinuousLearningSystem:
    """Augment model answers with web snippets and log interactions to SQLite.

    Every chat turn is written to ``conversations``. Turns whose heuristic
    quality score exceeds ``QUEUE_THRESHOLD`` are also added to
    ``learning_queue``. This class creates and counts ``learned_knowledge``
    but never writes to it.
    """

    # Interactions scoring above this are queued for later learning.
    QUEUE_THRESHOLD = 0.6

    def __init__(self, db_path=None):
        """Create the web searcher and ensure the database schema exists.

        Args:
            db_path: SQLite file to use; defaults to the module-level
                ``learning_database``.
        """
        self.db_path = learning_database if db_path is None else db_path
        self.web_search = WebSearchSystem()
        self.init_database()

    def _connect(self):
        """Open a new SQLite connection to ``self.db_path`` (caller must close it)."""
        return sqlite3.connect(self.db_path)

    def init_database(self):
        """Create the learning tables if they do not already exist.

        Errors propagate to the caller.
        """
        conn = self._connect()
        try:
            with conn:  # commit on success, roll back on error
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS conversations (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        session_id TEXT,
                        user_query TEXT,
                        ai_response TEXT,
                        web_knowledge TEXT,
                        user_feedback INTEGER,
                        quality_score REAL,
                        timestamp TEXT,
                        learned_from BOOLEAN DEFAULT 0
                    )
                ''')
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS learned_knowledge (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        topic TEXT,
                        knowledge_content TEXT,
                        source_url TEXT,
                        confidence REAL,
                        usage_count INTEGER DEFAULT 0,
                        last_used TEXT,
                        timestamp TEXT
                    )
                ''')
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS learning_queue (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        conversation_id INTEGER,
                        priority_score REAL,
                        processed BOOLEAN DEFAULT 0,
                        timestamp TEXT
                    )
                ''')
        finally:
            conn.close()

    def enhance_response_with_web_knowledge(self, user_query, base_response):
        """Append up to two web snippets to *base_response*.

        Snippets with relevance above 0.7 are marked with a star.

        Returns:
            ``(response_text, web_results)``. ``base_response`` is returned
            unchanged with ``[]`` when nothing is found or an error occurs.
        """
        try:
            web_results = self.web_search.search_survival_knowledge(user_query)
            if not web_results:
                return base_response, []
            parts = [base_response, "\n\n🌐 **Additional Current Information:**\n"]
            for i, result in enumerate(web_results[:2], 1):
                star = " ⭐" if result['relevance'] > 0.7 else ""
                parts.append(f"\n{i}. {result['content'][:150]}...{star}")
            parts.append("\n\n*Based on current survival resources and community knowledge*")
            return "".join(parts), web_results
        except Exception as e:
            print(f"Enhancement error: {e}")
            return base_response, []

    def store_conversation(self, user_query, ai_response, web_knowledge):
        """Persist one interaction and queue it for learning if it scores well.

        The conversation row (including its quality score) and any queue row
        are written in a single transaction.

        Returns:
            The new conversation id, or ``None`` if storing failed.
        """
        try:
            now = datetime.now()
            # NOTE: the "session" id is per-second, not per user session.
            session_id = f"session_{now.strftime('%Y%m%d_%H%M%S')}"
            timestamp = now.isoformat()
            quality_score = self.assess_interaction_quality(user_query, ai_response, web_knowledge)
            conn = self._connect()
            try:
                with conn:
                    cursor = conn.execute('''
                        INSERT INTO conversations
                        (session_id, user_query, ai_response, web_knowledge, quality_score, timestamp)
                        VALUES (?, ?, ?, ?, ?, ?)
                    ''', (
                        session_id,
                        user_query,
                        ai_response,
                        json.dumps(web_knowledge) if web_knowledge else None,
                        quality_score,
                        timestamp,
                    ))
                    conversation_id = cursor.lastrowid
                    if quality_score > self.QUEUE_THRESHOLD:
                        conn.execute('''
                            INSERT INTO learning_queue (conversation_id, priority_score, timestamp)
                            VALUES (?, ?, ?)
                        ''', (conversation_id, quality_score, timestamp))
            finally:
                conn.close()
            return conversation_id
        except Exception as e:
            print(f"Storage error: {e}")
            return None

    def assess_interaction_quality(self, query, response, web_knowledge):
        """Heuristically score an interaction between 0.5 and 1.0.

        Starting from 0.5, the score adds:
        - 0.1 for queries longer than five words;
        - 0.1 for responses longer than 200 characters;
        - 0.2 when web knowledge was used;
        - 0.2 when the query mentions a survival keyword.
        The result is capped at 1.0.
        """
        score = 0.5
        if len(query.split()) > 5:
            score += 0.1
        if len(response) > 200:
            score += 0.1
        if web_knowledge and len(web_knowledge) > 0:
            score += 0.2
        survival_keywords = ['survival', 'emergency', 'wilderness', 'rescue', 'first aid']
        if any(keyword in query.lower() for keyword in survival_keywords):
            score += 0.2
        return min(score, 1.0)

    def get_learning_stats(self):
        """Return row counts summarising the learning database.

        Returns:
            A dict with ``total_conversations``, ``knowledge_items``,
            ``today_conversations`` (rows dated today in local time) and
            ``learning_queue`` (unprocessed rows). All counts are 0 if the
            database cannot be read.
        """
        try:
            conn = self._connect()
            try:
                def scalar(sql, params=()):
                    return conn.execute(sql, params).fetchone()[0]

                today = datetime.now().date().isoformat()
                return {
                    'total_conversations': scalar("SELECT COUNT(*) FROM conversations"),
                    'knowledge_items': scalar("SELECT COUNT(*) FROM learned_knowledge"),
                    'today_conversations': scalar(
                        "SELECT COUNT(*) FROM conversations WHERE DATE(timestamp) = ?", (today,)
                    ),
                    'learning_queue': scalar(
                        "SELECT COUNT(*) FROM learning_queue WHERE processed = 0"
                    ),
                }
            finally:
                conn.close()
        except Exception as e:
            print(f"Stats error: {e}")
            return {'total_conversations': 0, 'knowledge_items': 0, 'today_conversations': 0, 'learning_queue': 0}
# Shared search-and-logging service used by every chat request.
# Constructing it creates the SQLite schema as a side effect of importing this module.
learning_system = ContinuousLearningSystem()
@spaces.GPU(duration=60)
def load_trained_model():
    """Load the fine-tuned survival model and its tokenizer into module globals.

    Both objects are loaded into locals first. The globals are replaced only
    after both succeed, so a failed load never leaves a new tokenizer paired
    with a stale or missing model.

    Returns:
        A status string for the UI: a success message, or the error text.
    """
    global model, tokenizer
    model_repo = "Znilsson/survival-ai-v1"  # repo id passed to from_pretrained
    try:
        new_tokenizer = AutoTokenizer.from_pretrained(model_repo)
        new_model = AutoModelForCausalLM.from_pretrained(
            model_repo,
            torch_dtype=torch.float16,
            device_map="auto",
        )
    except Exception as e:
        return f"❌ Failed to load model: {e}"
    tokenizer, model = new_tokenizer, new_model
    return "✅ Trained model loaded successfully!"
def chat_with_continuous_learning(message, history):
    """Answer one chat turn: model draft, then web enrichment, then persistence.

    ``history`` is supplied by ``gr.ChatInterface`` and is not used here.
    Any error is reported back to the chat as ``"Error: ..."``.
    """
    if model is None:
        return "Please load the trained model first!"
    try:
        draft = generate_base_response(message)
        answer, snippets = learning_system.enhance_response_with_web_knowledge(message, draft)
        learning_system.store_conversation(message, answer, snippets)
    except Exception as exc:
        return f"Error: {str(exc)}"
    return answer
def generate_base_response(query, max_prompt_tokens=400, max_new_tokens=150):
    """Generate a survival-instructor answer to *query* with the loaded model.

    The query is wrapped in an instruction prompt ending in ``### Response:``.
    Only the newly generated tokens are decoded and returned.

    Args:
        query: The user's question.
        max_prompt_tokens: Token budget for the whole prompt. An overlong
            query is trimmed so the trailing ``### Response:`` cue survives.
        max_new_tokens: Maximum number of tokens to sample.

    Returns:
        The stripped answer text, or ``"Base response error: ..."`` if
        tokenization or generation raises.
    """
    header = (
        "### Instruction:\n"
        "You are an expert survival instructor providing life-saving advice.\n"
        "### Question:\n"
    )
    footer = "\n### Response:"
    try:
        # Right-side truncation of the full prompt would cut off the
        # "### Response:" cue, so shorten the query itself instead.
        overhead = len(tokenizer(header + footer)["input_ids"])
        query_ids = tokenizer(query, add_special_tokens=False)["input_ids"]
        budget = max(max_prompt_tokens - overhead, 0)
        if len(query_ids) > budget:
            query = tokenizer.decode(query_ids[:budget], skip_special_tokens=True)

        inputs = tokenizer(
            f"{header}{query}{footer}",
            return_tensors="pt",
            max_length=max_prompt_tokens,
            truncation=True,  # safety net: tokens can merge differently at the joins
        )
        if torch.cuda.is_available():
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=max_new_tokens,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,
            )
        # Skip the echoed prompt tokens; decode only the continuation.
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return response.strip()
    except Exception as e:
        return f"Base response error: {e}"
def rate_response(rating):
    """Validate a 1-5 user rating and return a feedback message for the UI.

    Accepts ints, integral floats (``gr.Number`` may return ``5.0``) or
    numeric strings. Missing, non-numeric, fractional or out-of-range values
    get the "Please provide a rating" message instead of raising.

    TODO: the rating is not persisted yet; ``conversations.user_feedback``
    exists in the database but nothing writes to it.
    """
    invalid = "Please provide a rating from 1-5"
    try:
        value = float(rating)
    except (TypeError, ValueError):
        return invalid
    # is_integer() also rejects NaN and infinities.
    if not value.is_integer() or not 1 <= value <= 5:
        return invalid
    return f"Thank you! Response rated: {int(value)}/5 ⭐"
def get_learning_dashboard(stats=None):
    """Render learning statistics as a Markdown summary.

    Args:
        stats: Optional dict with integer ``total_conversations``,
            ``knowledge_items``, ``today_conversations`` and
            ``learning_queue``. When omitted, fresh counts are read from
            ``learning_system``.

    Returns:
        Markdown text with one stat per paragraph, so each renders on its own line.
    """
    if stats is None:
        stats = learning_system.get_learning_stats()
    lines = [
        "📊 **Continuous Learning Dashboard**",
        f"🗣️ **Conversations:** {stats['total_conversations']:,} total",
        f"📚 **Knowledge Base:** {stats['knowledge_items']:,} learned items",
        f"📅 **Today's Activity:** {stats['today_conversations']:,} interactions",
        f"⏳ **Learning Queue:** {stats['learning_queue']:,} pending",
        "🌐 **Web Learning:** Active",
        "🧠 **Model Learning:** Continuous",
        "📈 **Improvement:** Real-time",
    ]
    return "\n\n".join(lines)
# Create the enhanced interface: model loader, chat, rating widget and stats sidebar.
with gr.Blocks(title="Survival AI - Continuous Learning") as demo:
    gr.Markdown("""
# 🎯 Survival AI - Web-Enhanced Continuous Learning

**🚀 NEW FEATURES:**

- 🌐 **Real-time web search** for latest survival information
- 🧠 **Continuous learning** from every conversation
- 📊 **Learning dashboard** showing improvement progress
- ⭐ **User feedback** to improve responses

*This AI searches current survival resources and learns from each interaction!*
""")
    with gr.Row():
        with gr.Column(scale=2):
            # Model loading
            with gr.Row():
                load_btn = gr.Button("🔄 Load Trained Model", variant="primary")
                model_status = gr.Textbox(label="Model Status", interactive=False)
            # Main chat interface
            chat_interface = gr.ChatInterface(
                chat_with_continuous_learning,
                type="messages",
                title="Enhanced Survival AI",
                description="Ask survival questions - I'll search the web and learn from our conversation!",
                examples=[
                    "What are the latest wilderness survival techniques?",
                    "How do I purify water in an emergency?",
                    "What should I do if lost in the mountains?",
                    "How can I start a fire in wet conditions?",
                    "What are signs of hypothermia and treatment?",
                ],
            )
            # Feedback system; precision=0 makes the widget return whole numbers.
            with gr.Row():
                rating_input = gr.Number(
                    label="Rate Last Response (1-5)",
                    value=5,
                    minimum=1,
                    maximum=5,
                    precision=0,
                )
                rate_btn = gr.Button("Submit Rating")
            rating_output = gr.Textbox(label="Rating Feedback", interactive=False)
        with gr.Column(scale=1):
            # Learning dashboard. The dashboard text is Markdown (bold, line
            # breaks), so it needs a Markdown component, not gr.HTML.
            gr.Markdown("## 📊 Learning Dashboard")
            dashboard_display = gr.Markdown(value=get_learning_dashboard())
            refresh_btn = gr.Button("🔄 Refresh Stats")
            # System info.
            # NOTE(review): several capabilities listed below (e.g. knowledge
            # extraction, user preference learning) do not appear to be
            # implemented in this file — confirm before advertising them.
            gr.Markdown("## ⚙️ System Features")
            gr.Markdown("""
**🌐 Web Search Sources:**

- Survival blogs and forums
- Emergency preparedness sites
- Bushcraft communities
- Military survival guides

**🧠 Learning Capabilities:**

- Conversation analysis
- Knowledge extraction
- Response optimization
- User preference learning

**📈 Continuous Improvement:**

- Real-time knowledge updates
- Community-driven learning
- Quality-based prioritization
""")
    # Connect button functions
    load_btn.click(load_trained_model, outputs=model_status)
    rate_btn.click(rate_response, inputs=rating_input, outputs=rating_output)
    refresh_btn.click(lambda: get_learning_dashboard(), outputs=dashboard_display)
    # The initial value above is computed once at build time; refresh it for each page load.
    demo.load(lambda: get_learning_dashboard(), outputs=dashboard_display)
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()