|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from datasets import load_dataset |
|
|
from transformers import AutoTokenizer |
|
|
from tqdm import tqdm |
|
|
import math |
|
|
import speech_recognition as sr |
|
|
import pyttsx3 |
|
|
from googlesearch import search |
|
|
import warnings |
|
|
from typing import List, Dict, Union |
|
|
|
|
|
|
|
|
# Suppress warnings process-wide: no category or module filter is passed, so
# the "ignore" action applies to every warning raised after this point.
warnings.filterwarnings("ignore")
|
|
|
|
|
class WebSearchWrapper:
    """Wrapper for web search with caching.

    Results are memoised per lower-cased query in ``self.cache``. The cache
    is a plain dict whose insertion order doubles as the recency order: the
    first key is the least recently used entry and is the one evicted once
    ``cache_size`` entries are stored.
    """

    def __init__(self, cache_size: int = 100):
        """Create an empty cache holding at most ``cache_size`` queries.

        A ``cache_size`` of zero or less disables caching entirely.
        """
        self.cache: Dict[str, List[str]] = {}
        self.cache_size = cache_size

    def search(self, query: str, num_results: int = 3) -> List[str]:
        """Perform web search with caching.

        Args:
            query: Search terms; matched case-insensitively against the cache.
            num_results: Maximum number of results requested from the backend.

        Returns:
            The list of results, or an empty list when the backend fails
            (the error is printed, never raised, and nothing is cached).
        """
        key = query.lower()
        if key in self.cache:
            # Re-insert the entry so it becomes the most recently used one;
            # without this the eviction order would be FIFO, not LRU.
            cached = self.cache.pop(key)
            self.cache[key] = cached
            return cached

        try:
            search_results = self._run_search(query, num_results)
        except Exception as e:
            print(f"Web search error: {e}")
            return []

        self._add_to_cache(query, search_results)
        return search_results

    def _run_search(self, query: str, num_results: int) -> List[str]:
        """Call the imported ``search`` backend with a signature it accepts.

        Two different distributions install a ``googlesearch`` module: the
        newer one takes ``num_results``; the legacy one takes
        ``num``/``stop``/``pause``. Neither accepts a mix of both keyword
        sets, so the newer form is tried first and the legacy form is used
        when argument binding fails with ``TypeError``.
        """
        try:
            hits = search(query, num_results=num_results)
        except TypeError:
            hits = search(query, num=num_results, stop=num_results, pause=2)
        # Materialise the iterator here so backend errors surface to the caller's handler.
        return list(hits)[:num_results]

    def _add_to_cache(self, query: str, results: List[str]):
        """Add results to cache with LRU eviction policy."""
        if self.cache_size <= 0:
            # Nothing can be stored; also avoids popping from an empty dict.
            return

        key = query.lower()
        # Drop any stale copy first so an overwrite never evicts another entry.
        self.cache.pop(key, None)
        while len(self.cache) >= self.cache_size:
            oldest = next(iter(self.cache))
            self.cache.pop(oldest)
        self.cache[key] = results
|
|
|
|
|
class FullChatDataset(Dataset):
    """Concatenation of several dialogue corpora, tokenized on access.

    Each item is split into a context (all turns but the last, joined with
    `` [SEP] ``) and a response (the last turn), encoded as a sentence pair
    and padded/truncated to ``max_length`` tokens.
    """

    # Corpora loaded when the caller does not pass ``dataset_names``.
    DEFAULT_DATASET_NAMES = ("blended_skill_talk", "conv_ai_2", "social_i_qa")

    def __init__(self, dataset_names=None, max_length=256):
        """Load the ``train`` split of every named corpus that is available.

        Args:
            dataset_names: Iterable of names passed to ``load_dataset``;
                ``None`` selects ``DEFAULT_DATASET_NAMES``.
            max_length: Token length every example is padded/truncated to.

        Corpora that fail to load are reported on stdout and skipped, so the
        resulting dataset may be empty.
        """
        if dataset_names is None:
            dataset_names = self.DEFAULT_DATASET_NAMES

        self.datasets = []
        for name in dataset_names:
            try:
                dataset = load_dataset(name, split="train")
            except Exception as e:
                print(f"Failed to load dataset {name}: {e}")
            else:
                self.datasets.append(dataset)

        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.max_length = max_length

    def __len__(self):
        """Return the total number of examples across all loaded corpora."""
        return sum(len(d) for d in self.datasets)

    def __getitem__(self, idx):
        """Return the tokenized example at global position ``idx``.

        Negative indices count from the end.

        Returns:
            Dict with 1-D tensors ``input_ids``, ``attention_mask`` and
            ``labels`` (an independent copy of ``input_ids``).

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError(f"index out of range for dataset of size {total}")

        # Walk the corpora, turning the global index into a per-corpus one.
        for dataset in self.datasets:
            if idx < len(dataset):
                item = dataset[idx]
                break
            idx -= len(dataset)

        if 'dialog' in item:
            dialog = item['dialog']
        elif 'messages' in item:
            dialog = [msg['text'] for msg in item['messages']]
        else:
            # Unknown schema: fall back to every string field, in field order.
            dialog = [v for v in item.values() if isinstance(v, str)]

        context = " [SEP] ".join(dialog[:-1])
        # A record without any turn yields an empty pair instead of an
        # IndexError, which iteration code could mistake for end-of-data.
        response = dialog[-1] if dialog else ""

        inputs = self.tokenizer(
            context,
            text_pair=response,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt",
        )

        input_ids = inputs['input_ids'].flatten()
        return {
            'input_ids': input_ids,
            'attention_mask': inputs['attention_mask'].flatten(),
            # Cloned so in-place edits of the labels cannot alter the inputs.
            'labels': input_ids.clone(),
        }
|
|
|
|
|
class SimpleTransformerModel(nn.Module):
    """Transformer encoder that maps token ids to per-position vocabulary logits.

    Pipeline: embedding -> positional encoding -> ``nn.TransformerEncoder``
    -> linear projection to ``vocab_size``. All tensors are batch-first.
    """

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3):
        """Build the network.

        Args:
            vocab_size: Number of rows in the embedding and of output logits.
            d_model: Embedding / hidden width.
            nhead: Attention heads per encoder layer.
            num_layers: Number of stacked encoder layers.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        # batch_first=True: the positional encoding slices by x.size(1) and
        # training feeds (batch, seq) ids, so dim 1 must be the sequence axis.
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        # NOTE(review): the nested-tensor fast path is disabled so padded
        # positions are always returned and the output keeps the input's
        # sequence length — confirm against the installed torch version.
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers, enable_nested_tensor=False
        )
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, mask=None):
        """Compute logits for every position.

        Args:
            x: Integer token ids; the sequence axis is dim 1.
            mask: Optional attention mask with the same leading shape as
                ``x`` where non-zero marks a real token and zero marks
                padding (the tokenizer's ``attention_mask`` convention).

        Returns:
            Tensor of logits with a trailing dimension of ``vocab_size``.
        """
        x = self.embedding(x)
        x = self.pos_encoder(x)

        padding_mask = None
        if mask is not None:
            # The encoder expects True where a key must be ignored, which is
            # the inverse of the 1-for-token attention mask passed in.
            padding_mask = mask == 0

        x = self.transformer(x, src_key_padding_mask=padding_mask)
        return self.fc(x)
|
|
|
|
|
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to batch-first embeddings.

    A ``(max_len, d_model)`` table is precomputed and stored as the
    non-trainable buffer ``pe``: even columns hold sines, odd columns hold
    cosines of the position scaled by geometrically spaced frequencies.
    """

    def __init__(self, d_model, max_len=500):
        """Precompute the table.

        Args:
            d_model: Width of the embeddings (even or odd).
            max_len: Number of positions in the table.
        """
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns,
        # so the cosine source is trimmed to match (a no-op when even).
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        The table slice is broadcast over dim 0, so dim 1 of ``x`` must be
        the sequence axis and must not be longer than the table.
        """
        return x + self.pe[:x.size(1)]
|
|
|
|
|
class VoiceInterface:
    """Speech front end: microphone capture and recognition in, spoken text out."""

    def __init__(self):
        """Create the speech recognizer and the text-to-speech engine."""
        self.recognizer = sr.Recognizer()
        self.engine = pyttsx3.init()

    def listen(self) -> Union[str, None]:
        """Record one utterance from the microphone and transcribe it.

        Returns:
            The recognized text, or ``None`` when recognition fails (the
            error is printed rather than raised).
        """
        with sr.Microphone() as mic:
            print("Listening...")
            captured = self.recognizer.listen(mic)

            transcript = None
            try:
                transcript = self.recognizer.recognize_google(captured)
                print(f"You said: {transcript}")
            except Exception as err:
                print(f"Error recognizing speech: {err}")
                transcript = None
            return transcript

    def speak(self, text: str):
        """Print ``text`` as the bot's line, then speak it and wait until done."""
        print(f"Bot: {text}")
        tts = self.engine
        tts.say(text)
        tts.runAndWait()
|
|
|
|
|
class ChatBot:
    """Chat agent tying together the dataset, model, voice I/O and web search."""

    # Whole words that mark a prompt as a question worth a web lookup.
    _QUESTION_WORDS = frozenset({'what', 'when', 'where', 'who', 'why', 'how', 'which'})

    def __init__(self):
        """Build the dataset, a model sized to its tokenizer, and the helpers."""
        self.dataset = FullChatDataset()
        self.model = SimpleTransformerModel(len(self.dataset.tokenizer))
        self.voice_interface = VoiceInterface()
        self.web_searcher = WebSearchWrapper()

    def train(self, epochs=3, lr=3e-4):
        """Train the model on the dataset with Adam and cross-entropy loss.

        Args:
            epochs: Number of passes over the dataset.
            lr: Adam learning rate.

        Raises:
            ValueError: If the dataset holds no examples.
        """
        if len(self.dataset) == 0:
            # Fail with an actionable message instead of an opaque sampler error.
            raise ValueError("Cannot train: the dataset is empty (no corpus could be loaded)")

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(device)

        # Ignore whatever id the tokenizer really uses for padding; 0 is only
        # the fallback when the tokenizer reports none.
        pad_id = self.dataset.tokenizer.pad_token_id
        criterion = nn.CrossEntropyLoss(ignore_index=0 if pad_id is None else pad_id)
        optimizer = optim.Adam(self.model.parameters(), lr=lr)

        dataloader = DataLoader(self.dataset, batch_size=8, shuffle=True)

        for epoch in range(epochs):
            self.model.train()
            total_loss = 0
            pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

            for batch in pbar:
                inputs = batch['input_ids'].to(device)
                masks = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                optimizer.zero_grad()
                outputs = self.model(inputs, masks)
                # NOTE(review): the labels produced by the dataset are the
                # unshifted input ids, so this loss rewards reproducing each
                # token at its own position. A next-token objective would need
                # shifted targets and a causal mask — confirm the intent.
                loss = criterion(outputs.reshape(-1, outputs.size(-1)), labels.reshape(-1))
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                total_loss += batch_loss
                pbar.set_postfix({'loss': batch_loss})

            print(f"Epoch {epoch+1} - Avg loss: {total_loss/len(dataloader):.4f}")

    def generate_response(self, prompt: str, max_length: int = 100, use_web: bool = True) -> str:
        """Sample a reply to ``prompt`` token by token.

        The model is a plain ``nn.Module`` without a ``generate`` method, so
        decoding is done here: the logits of the last position are sampled
        with temperature 0.7, top-k 50 and top-p 0.95, and the chosen token
        is appended until a separator/padding token is drawn or a limit hits.

        Args:
            prompt: User text.
            max_length: Maximum number of tokens to generate for the reply.
            use_web: When true and the prompt looks like a question, web
                search results are prepended to the prompt as context.

        Returns:
            The decoded generated tokens (the prompt itself is not echoed).
        """
        device = next(self.model.parameters()).device
        self.model.eval()
        tokenizer = self.dataset.tokenizer

        if use_web and self._needs_web_search(prompt):
            web_results = self.web_searcher.search(prompt)
            if web_results:
                prompt = f"Web context: {', '.join(web_results[:3])}. User question: {prompt}"

        # No padding: a single sequence needs none, and padding tokens at the
        # end would become the context the next token is sampled from.
        encoded = tokenizer(prompt, return_tensors="pt", max_length=256, truncation=True)
        ids = encoded['input_ids'].to(device)
        prompt_len = ids.size(1)

        stop_ids = {
            tok
            for tok in (getattr(tokenizer, 'sep_token_id', None), tokenizer.pad_token_id)
            if tok is not None
        }
        # Never grow past the positional table when the model exposes one.
        pos_table = getattr(getattr(self.model, 'pos_encoder', None), 'pe', None)
        max_positions = None if pos_table is None else pos_table.size(0)

        with torch.no_grad():
            for _ in range(max_length):
                if max_positions is not None and ids.size(1) >= max_positions:
                    break
                last_logits = self.model(ids)[:, -1, :]
                next_id = self._sample_next_token(
                    last_logits, temperature=0.7, top_k=50, top_p=0.95
                )
                if next_id.item() in stop_ids:
                    break
                ids = torch.cat([ids, next_id], dim=1)

        response = tokenizer.decode(ids[0, prompt_len:], skip_special_tokens=True)
        return response

    @staticmethod
    def _sample_next_token(logits, temperature=0.7, top_k=50, top_p=0.95):
        """Draw one token id per row of ``logits`` with top-k / nucleus filtering.

        Args:
            logits: Float tensor whose last dimension is the vocabulary.
            temperature: Divisor applied to the logits before the softmax.
            top_k: Keep only the ``top_k`` highest logits (0 disables).
            top_p: Keep the smallest set of most likely tokens whose
                probability mass reaches ``top_p`` (1.0 disables).

        Returns:
            Long tensor with one sampled id per row (last dimension of size 1).
        """
        logits = logits / max(temperature, 1e-8)

        if top_k > 0:
            k = min(top_k, logits.size(-1))
            kth_best = torch.topk(logits, k).values[..., -1, None]
            logits = logits.masked_fill(logits < kth_best, float('-inf'))

        probs = torch.softmax(logits, dim=-1)

        if top_p < 1.0:
            sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
            cumulative = torch.cumsum(sorted_probs, dim=-1)
            # Drop a token only when the mass *before* it already exceeds
            # top_p, so the most likely token always survives.
            drop = (cumulative - sorted_probs) > top_p
            sorted_probs = sorted_probs.masked_fill(drop, 0.0)
            probs = torch.zeros_like(probs).scatter(-1, sorted_idx, sorted_probs)
            probs = probs / probs.sum(dim=-1, keepdim=True)

        return torch.multinomial(probs, num_samples=1)

    def _needs_web_search(self, text: str) -> bool:
        """Determine if a query needs web search.

        True when the text contains a question mark or one of the question
        words as a whole word; words that merely contain one (for example
        "show" or "whole") do not count.
        """
        import re

        lowered = text.lower()
        if '?' in lowered:
            return True
        words = re.findall(r"[a-z]+", lowered)
        return not self._QUESTION_WORDS.isdisjoint(words)