Spaces:
Running on Zero
Running on Zero
| import gradio as gr | |
| import spaces | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import torch | |
| import requests | |
| from bs4 import BeautifulSoup | |
| import json | |
| from datetime import datetime | |
| import sqlite3 | |
| import threading | |
| import time | |
| import re | |
| import os | |
| from urllib.parse import quote_plus | |
| import random | |
# Global variables
# Causal language model and its tokenizer. Both stay None until
# load_trained_model() assigns them; the chat handler checks `model is None`.
model = None
tokenizer = None
# Path of the SQLite file that ContinuousLearningSystem connects to.
learning_database = "continuous_learning.db"
class WebSearchSystem:
    """Fetch search-engine result pages and pull survival-related snippets.

    The class issues plain HTTP GET requests against public search engines,
    parses the returned HTML and keeps short text snippets that mention both
    the user's query and at least one survival keyword.
    """

    # A snippet must mention at least one of these to be kept at all.
    _CONTENT_KEYWORDS = (
        'survival', 'emergency', 'wilderness', 'bushcraft', 'first aid',
        'shelter', 'fire', 'water', 'food', 'rescue',
    )
    # Keywords that raise the relevance score of a snippet.
    _BOOST_KEYWORDS = ('survival', 'emergency', 'wilderness', 'safety', 'rescue')
    # CSS class names that typically mark a result snippet; compiled once.
    _SNIPPET_CLASS = re.compile(r'(snippet|description|summary|result)')

    def __init__(self):
        self.search_engines = [
            "https://www.google.com/search?q=",
            "https://duckduckgo.com/html/?q=",
            "https://search.yahoo.com/search?p=",
        ]
        self.survival_sites = [
            "site:survivalblog.com",
            "site:offgridweb.com",
            "site:outdoorlife.com",
            "site:backwoodsman.com",
            "site:bushcraftusa.com",
            "site:prepared.com",
        ]

    def search_survival_knowledge(self, query):
        """Search for survival-specific information.

        Args:
            query: Free-text user question.

        Returns:
            Up to three result dicts (``content``, ``relevance``, ``source``,
            ``timestamp``) sorted by descending relevance. An empty list is
            returned for a blank query or when every request fails.
        """
        # A blank query cannot match anything; skip the network entirely.
        if not query or not query.strip():
            return []
        try:
            search_results = []
            # Create survival-focused search queries
            survival_queries = [
                f"survival {query} techniques",
                f"wilderness {query} emergency",
                f"bushcraft {query} methods",
                f"prepper {query} guide",
            ]
            for search_query in survival_queries[:2]:  # Limit to avoid rate limits
                try:
                    # Use different search engines randomly
                    search_url = random.choice(self.search_engines)
                    full_query = f"{search_query} {random.choice(self.survival_sites)}"
                    response = requests.get(
                        f"{search_url}{quote_plus(full_query)}",
                        headers={
                            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
                        },
                        timeout=10,
                    )
                except requests.RequestException as e:
                    # Best effort: one failing engine must not abort the search.
                    print(f"Search engine error: {e}")
                    continue
                if response.status_code == 200:
                    search_results.extend(
                        self.extract_survival_info(response.text, query)
                    )
            return self.filter_and_rank_results(search_results, query)
        except Exception as e:
            print(f"Web search error: {e}")
            return []

    def extract_survival_info(self, html_content, query):
        """Extract relevant survival information from search results.

        Args:
            html_content: Raw HTML of a search result page.
            query: The user's original query.

        Returns:
            A list of result dicts; empty when nothing matches or parsing fails.
        """
        try:
            soup = BeautifulSoup(html_content, 'html.parser')
            # Look for content snippets
            snippets = soup.find_all(['div', 'p', 'span'],
                                     class_=self._SNIPPET_CLASS)
            query_lower = query.lower()
            results = []
            for snippet in snippets[:5]:  # Limit results
                text = snippet.get_text(strip=True)
                if len(text) <= 50:
                    continue
                text_lower = text.lower()
                if not any(keyword in text_lower for keyword in self._CONTENT_KEYWORDS):
                    continue
                # NOTE(review): the whole query must appear verbatim in the
                # snippet, so long questions will rarely match - confirm this
                # strictness is intended.
                if query_lower not in text_lower:
                    continue
                results.append({
                    'content': text[:300],  # Limit length
                    'relevance': self.calculate_relevance(text, query),
                    'source': 'web_search',
                    'timestamp': datetime.now().isoformat(),
                })
            return results
        except Exception as e:
            print(f"Content extraction error: {e}")
            return []

    def calculate_relevance(self, text, query):
        """Calculate relevance score for search result.

        The score is 70% the fraction of query words found in ``text`` plus
        30% the fraction of boost keywords found, capped at 1.0. A query
        without any words contributes 0 for the first term instead of
        dividing by zero.
        """
        query_words = query.lower().split()
        text_lower = text.lower()
        # Count query word matches
        matches = sum(1 for word in query_words if word in text_lower)
        query_fraction = matches / len(query_words) if query_words else 0.0
        # Boost for survival keywords
        survival_matches = sum(
            1 for keyword in self._BOOST_KEYWORDS if keyword in text_lower
        )
        # Calculate score
        relevance = (query_fraction * 0.7
                     + (survival_matches / len(self._BOOST_KEYWORDS)) * 0.3)
        return min(relevance, 1.0)

    def filter_and_rank_results(self, results, query):
        """Filter and rank search results by relevance.

        Drops results scoring 0.3 or less and results whose first 100
        characters repeat an already accepted one, then returns the three
        best. ``query`` is accepted for interface compatibility but unused.
        """
        unique_results = []
        seen_prefixes = set()
        for result in results:
            # The prefix itself is the key; no need to go through hash().
            prefix = result['content'][:100]
            if prefix not in seen_prefixes and result['relevance'] > 0.3:
                seen_prefixes.add(prefix)
                unique_results.append(result)
        # Sort by relevance
        unique_results.sort(key=lambda item: item['relevance'], reverse=True)
        return unique_results[:3]  # Return top 3 results
class ContinuousLearningSystem:
    """Persist conversations in SQLite and enrich answers with web snippets.

    Every database method opens its own connection to the module-level
    ``learning_database`` path and closes it in a ``finally`` block, so a
    failing statement cannot leak the connection.
    """

    def __init__(self):
        self.web_search = WebSearchSystem()
        self.init_database()

    def init_database(self):
        """Initialize learning database (creates the three tables if missing)."""
        conn = sqlite3.connect(learning_database)
        try:
            cursor = conn.cursor()
            # Create tables for learning data
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS conversations (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id TEXT,
                    user_query TEXT,
                    ai_response TEXT,
                    web_knowledge TEXT,
                    user_feedback INTEGER,
                    quality_score REAL,
                    timestamp TEXT,
                    learned_from BOOLEAN DEFAULT 0
                )
            ''')
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS learned_knowledge (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    topic TEXT,
                    knowledge_content TEXT,
                    source_url TEXT,
                    confidence REAL,
                    usage_count INTEGER DEFAULT 0,
                    last_used TEXT,
                    timestamp TEXT
                )
            ''')
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS learning_queue (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    conversation_id INTEGER,
                    priority_score REAL,
                    processed BOOLEAN DEFAULT 0,
                    timestamp TEXT
                )
            ''')
            conn.commit()
        finally:
            conn.close()

    def enhance_response_with_web_knowledge(self, user_query, base_response):
        """Enhance AI response with web-searched knowledge.

        Args:
            user_query: The question to search for.
            base_response: The model's answer to extend.

        Returns:
            A ``(response, web_results)`` tuple. When the search yields
            nothing or fails, the tuple is ``(base_response, [])``.
        """
        try:
            # Search for additional survival knowledge
            web_results = self.web_search.search_survival_knowledge(user_query)
            if not web_results:
                return base_response, []
            # Add web-sourced insights (at most two snippets)
            parts = [base_response, "\n\nπ **Additional Current Information:**\n"]
            for i, result in enumerate(web_results[:2], 1):
                parts.append(f"\n{i}. {result['content'][:150]}...")
                if result['relevance'] > 0.7:
                    # Marker appended to highly relevant snippets.
                    parts.append(" β")
            parts.append("\n\n*Based on current survival resources and community knowledge*")
            return "".join(parts), web_results
        except Exception as e:
            print(f"Enhancement error: {e}")
            return base_response, []

    def store_conversation(self, user_query, ai_response, web_knowledge):
        """Store conversation for learning.

        Inserts one row into ``conversations`` (including its quality score)
        and, when the score is above 0.6, one row into ``learning_queue``.

        Returns:
            The new conversation row id, or ``None`` if anything failed.
        """
        try:
            now = datetime.now()
            timestamp = now.isoformat()
            session_id = f"session_{now.strftime('%Y%m%d_%H%M%S')}"
            quality_score = self.assess_interaction_quality(
                user_query, ai_response, web_knowledge
            )
            conn = sqlite3.connect(learning_database)
            try:
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT INTO conversations
                    (session_id, user_query, ai_response, web_knowledge,
                     quality_score, timestamp)
                    VALUES (?, ?, ?, ?, ?, ?)
                ''', (
                    session_id,
                    user_query,
                    ai_response,
                    json.dumps(web_knowledge) if web_knowledge else None,
                    quality_score,
                    timestamp,
                ))
                conversation_id = cursor.lastrowid
                # Queue high-quality interactions for learning
                if quality_score > 0.6:
                    cursor.execute('''
                        INSERT INTO learning_queue (conversation_id, priority_score, timestamp)
                        VALUES (?, ?, ?)
                    ''', (conversation_id, quality_score, timestamp))
                conn.commit()
            finally:
                # Closing without commit discards a half-written conversation.
                conn.close()
            return conversation_id
        except Exception as e:
            print(f"Storage error: {e}")
            return None

    def assess_interaction_quality(self, query, response, web_knowledge):
        """Assess quality of interaction for learning priority.

        Starts at 0.5 and adds 0.1 for a query longer than five words, 0.1
        for a response longer than 200 characters, 0.2 when web knowledge is
        present and 0.2 when the query contains a survival keyword. The
        result is capped at 1.0.
        """
        score = 0.5  # Base score
        # Query complexity
        if len(query.split()) > 5:
            score += 0.1
        # Response length (good responses are detailed)
        if len(response) > 200:
            score += 0.1
        # Web knowledge integration
        if web_knowledge:
            score += 0.2
        # Survival-specific content
        survival_keywords = ['survival', 'emergency', 'wilderness', 'rescue', 'first aid']
        query_lower = query.lower()
        if any(keyword in query_lower for keyword in survival_keywords):
            score += 0.2
        return min(score, 1.0)

    def get_learning_stats(self):
        """Get current learning statistics.

        Returns:
            A dict with ``total_conversations``, ``knowledge_items``,
            ``today_conversations`` and ``learning_queue`` counts. All four
            are 0 when the database cannot be read.
        """
        try:
            conn = sqlite3.connect(learning_database)
            try:
                cursor = conn.cursor()
                # Get conversation count
                cursor.execute("SELECT COUNT(*) FROM conversations")
                total_conversations = cursor.fetchone()[0]
                # Get learned knowledge count
                cursor.execute("SELECT COUNT(*) FROM learned_knowledge")
                knowledge_items = cursor.fetchone()[0]
                # Get today's interactions
                today = datetime.now().date().isoformat()
                cursor.execute(
                    "SELECT COUNT(*) FROM conversations WHERE DATE(timestamp) = ?",
                    (today,),
                )
                today_conversations = cursor.fetchone()[0]
                # Get learning queue size
                cursor.execute("SELECT COUNT(*) FROM learning_queue WHERE processed = 0")
                queue_size = cursor.fetchone()[0]
            finally:
                conn.close()
            return {
                'total_conversations': total_conversations,
                'knowledge_items': knowledge_items,
                'today_conversations': today_conversations,
                'learning_queue': queue_size,
            }
        except Exception as e:
            print(f"Stats error: {e}")
            return {'total_conversations': 0, 'knowledge_items': 0,
                    'today_conversations': 0, 'learning_queue': 0}
# Initialize the continuous learning system
# NOTE: this runs at import time; the constructor calls init_database(),
# which connects to the SQLite file named by `learning_database`.
learning_system = ContinuousLearningSystem()
def load_trained_model():
    """Load the trained model and tokenizer into the module globals.

    The call is idempotent: once both globals are set, further calls return
    the success message without loading a second copy. The globals are only
    assigned after both objects loaded, so a failed attempt never leaves a
    new tokenizer paired with an old or missing model.

    Returns:
        A status string for display in the UI.
    """
    global model, tokenizer
    # Repeated clicks on the load button must not reload the weights.
    if model is not None and tokenizer is not None:
        return "β Trained model loaded successfully!"
    try:
        model_repo = "Znilsson/survival-ai-v1"  # Your existing model
        loaded_tokenizer = AutoTokenizer.from_pretrained(model_repo)
        loaded_model = AutoModelForCausalLM.from_pretrained(
            model_repo,
            torch_dtype=torch.float16,
            device_map="auto",
        )
    except Exception as e:
        return f"β Failed to load model: {e}"
    # Publish the tokenizer first: the chat handler only checks `model`, so
    # the tokenizer must already be in place when `model` becomes non-None.
    tokenizer = loaded_tokenizer
    model = loaded_model
    return "β Trained model loaded successfully!"
def chat_with_continuous_learning(message, history):
    """Main chat function with web search and learning.

    Args:
        message: The user's latest message.
        history: Previous chat turns supplied by the chat UI (unused here).

    Returns:
        The text to show as the assistant's reply.
    """
    if model is None:
        return "Please load the trained model first!"
    # Nothing to answer or search for; skip the model and the network.
    if not message or (isinstance(message, str) and not message.strip()):
        return "Please enter a survival question."
    try:
        # Generate base response from your trained model
        base_response = generate_base_response(message)
        # generate_base_response reports failures as a string with this
        # prefix; such a reply must not be enriched or stored for learning.
        if base_response.startswith("Base response error:"):
            return base_response
        # Enhance with web knowledge
        enhanced_response, web_knowledge = (
            learning_system.enhance_response_with_web_knowledge(
                message, base_response
            )
        )
        # Store for learning
        learning_system.store_conversation(message, enhanced_response, web_knowledge)
        return enhanced_response
    except Exception as e:
        return f"Error: {str(e)}"
def generate_base_response(query):
    """Generate response using your trained model.

    Builds an instruction-style prompt around ``query``, samples up to 150
    new tokens from the global ``model`` and decodes only the newly
    generated part.

    Args:
        query: The user's question.

    Returns:
        The stripped generated text, or a string starting with
        ``"Base response error:"`` if anything raised.
    """
    # NOTE(review): the file imports `spaces` but no function in view uses
    # `@spaces.GPU`; if this runs on a ZeroGPU Space, confirm whether this
    # function needs that decorator.
    try:
        prompt = f"""### Instruction:
You are an expert survival instructor providing life-saving advice.
### Question:
{query}
### Response:"""
        encoded = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=400,
            truncation=True,
        )
        # Always follow the model's device. Checking only for CUDA leaves the
        # inputs on the CPU when the model sits on another accelerator.
        inputs = {k: v.to(model.device) for k, v in encoded.items()}
        prompt_length = inputs["input_ids"].shape[1]
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=150,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,
            )
        # Drop the echoed prompt tokens; keep only what the model added.
        response = tokenizer.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True,
        )
        return response.strip()
    except Exception as e:
        return f"Base response error: {e}"
def rate_response(rating):
    """Allow users to rate responses for learning.

    Args:
        rating: A number (or numeric string) expected to lie in 1-5.

    Returns:
        A thank-you message when the rating is within 1-5 inclusive,
        otherwise a prompt asking for a valid rating. Missing or non-numeric
        input yields the prompt instead of raising.
    """
    try:
        value = float(rating)
    except (TypeError, ValueError):
        return "Please provide a rating from 1-5"
    # Compare the real value: truncating with int() would let 5.9 through.
    if 1 <= value <= 5:
        # NOTE(review): the rating is only acknowledged here, not persisted.
        return f"Thank you! Response rated: {rating}/5 β"
    return "Please provide a rating from 1-5"
def get_learning_dashboard():
    """Render the learning statistics as a multi-line dashboard string."""
    counts = learning_system.get_learning_stats()
    # One entry per dashboard row; rows are joined with newlines below.
    dashboard_rows = [
        "π **Continuous Learning Dashboard**",
        f"π£οΈ **Conversations:** {counts['total_conversations']:,} total",
        f"π **Knowledge Base:** {counts['knowledge_items']:,} learned items",
        f"π **Today's Activity:** {counts['today_conversations']:,} interactions",
        f"β³ **Learning Queue:** {counts['learning_queue']:,} pending",
        "π **Web Learning:** Active",
        "π§ **Model Learning:** Continuous",
        "π **Improvement:** Real-time",
    ]
    return "\n".join(dashboard_rows)
# Create the enhanced interface
# Layout: a wide left column (model loading, chat, rating) and a narrow right
# column (learning dashboard and a static feature list).
with gr.Blocks(title="Survival AI - Continuous Learning") as demo:
    gr.Markdown("""
# π― Survival AI - Web-Enhanced Continuous Learning
**π NEW FEATURES:**
- π **Real-time web search** for latest survival information
- π§ **Continuous learning** from every conversation
- π **Learning dashboard** showing improvement progress
- β **User feedback** to improve responses
*This AI searches current survival resources and learns from each interaction!*
""")
    with gr.Row():
        with gr.Column(scale=2):
            # Model loading
            with gr.Row():
                load_btn = gr.Button("π Load Trained Model", variant="primary")
                model_status = gr.Textbox(label="Model Status", interactive=False)
            # Main chat interface
            chat_interface = gr.ChatInterface(
                chat_with_continuous_learning,
                type="messages",
                title="Enhanced Survival AI",
                description="Ask survival questions - I'll search the web and learn from our conversation!",
                examples=[
                    "What are the latest wilderness survival techniques?",
                    "How do I purify water in an emergency?",
                    "What should I do if lost in the mountains?",
                    "How can I start a fire in wet conditions?",
                    "What are signs of hypothermia and treatment?"
                ]
            )
            # Feedback system
            with gr.Row():
                rating_input = gr.Number(
                    label="Rate Last Response (1-5)",
                    value=5,
                    minimum=1,
                    maximum=5
                )
                rate_btn = gr.Button("Submit Rating")
                rating_output = gr.Textbox(label="Rating Feedback", interactive=False)
        with gr.Column(scale=1):
            # Learning dashboard
            gr.Markdown("## π Learning Dashboard")
            # The dashboard text is Markdown (bold markers, one item per
            # line), so it is rendered with a Markdown component rather than
            # raw HTML; line_breaks keeps each statistic on its own line.
            dashboard_display = gr.Markdown(
                value=get_learning_dashboard(),
                line_breaks=True
            )
            refresh_btn = gr.Button("π Refresh Stats")
            # System info
            gr.Markdown("## βοΈ System Features")
            gr.Markdown("""
**π Web Search Sources:**
β’ Survival blogs and forums
β’ Emergency preparedness sites
β’ Bushcraft communities
β’ Military survival guides
**π§ Learning Capabilities:**
β’ Conversation analysis
β’ Knowledge extraction
β’ Response optimization
β’ User preference learning
**π Continuous Improvement:**
β’ Real-time knowledge updates
β’ Community-driven learning
β’ Quality-based prioritization
""")
    # Connect button functions
    load_btn.click(load_trained_model, outputs=model_status)
    rate_btn.click(rate_response, inputs=rating_input, outputs=rating_output)
    refresh_btn.click(get_learning_dashboard, outputs=dashboard_display)
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()