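# Hugging Face page header captured with this file (kept for provenance):
#   Tags: Text Generation | Transformers | English | Russian | legal
#   File: CyberFuture-3 / model.py (author: SkillForge45)
#   Commit: 088ce82 "Create model.py" (verified)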
import math
import re
import warnings
from typing import Dict, List, Union

import pyttsx3
import speech_recognition as sr
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from googlesearch import search
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
# Silence every warning, of any category and from any module, for the whole
# process. NOTE(review): this also hides deprecation notices raised by the
# libraries imported above; consider narrowing it to specific categories.
warnings.simplefilter("ignore")
class WebSearchWrapper:
    """Web search through ``googlesearch.search`` with a small in-memory LRU cache.

    Cache keys are the lower-cased query; values are the result lists returned
    by the search backend.
    """

    def __init__(self, cache_size: int = 100):
        """
        Args:
            cache_size: Maximum number of queries kept in the cache. Zero or a
                negative value disables caching.
        """
        # Insertion-ordered dict used as an LRU: cache hits are re-inserted at
        # the end, so the first key is always the least recently used one.
        self.cache: Dict[str, List[str]] = {}
        self.cache_size = cache_size

    def search(self, query: str, num_results: int = 3) -> List[str]:
        """Return up to ``num_results`` web results for ``query``.

        A query seen before (case-insensitively) is answered from the cache.
        Backend errors are printed and an empty list is returned, so callers
        can treat web search as best-effort.
        """
        key = query.lower()
        if key in self.cache:
            # Re-insert to mark the entry as most recently used.
            cached = self.cache.pop(key)
            self.cache[key] = cached
            return list(cached)
        try:
            search_results = self._run_search(query, num_results)
        except Exception as e:
            print(f"Web search error: {e}")
            return []
        self._add_to_cache(query, search_results)
        # Return a copy so callers cannot mutate the cached list.
        return list(search_results)

    @staticmethod
    def _run_search(query: str, num_results: int) -> List[str]:
        """Call the installed ``googlesearch`` backend using its own signature.

        Two incompatible packages provide ``googlesearch.search``:
        ``googlesearch-python`` accepts ``num_results=``, while the older
        ``google`` package accepts ``num=``/``stop=``/``pause=``. Mixing the
        keywords of both raises ``TypeError`` with either package, so try the
        newer API first and fall back to the older one.
        """
        try:
            results = list(search(query, num_results=num_results))
        except TypeError:
            results = list(search(query, num=num_results, stop=num_results, pause=2))
        return results[:num_results]

    def _add_to_cache(self, query: str, results: List[str]):
        """Store ``results`` under the lower-cased ``query``, evicting LRU entries."""
        if self.cache_size <= 0:
            # Caching disabled; evicting from an empty dict would otherwise
            # raise StopIteration.
            return
        key = query.lower()
        self.cache.pop(key, None)
        while len(self.cache) >= self.cache_size:
            # First key in insertion order is the least recently used.
            self.cache.pop(next(iter(self.cache)))
        self.cache[key] = list(results)
class FullChatDataset(Dataset):
    """Several Hugging Face chat datasets concatenated and tokenized on the fly.

    Each row is reduced to a list of text turns. All turns except the last are
    joined with ``" [SEP] "`` to form the context, and the last turn is the
    response. The (context, response) pair is tokenized with the
    ``bert-base-uncased`` tokenizer and padded/truncated to ``max_length``.
    """

    def __init__(
        self,
        dataset_names=("blended_skill_talk", "conv_ai_2", "social_i_qa"),
        max_length=256,
    ):
        """
        Args:
            dataset_names: Hugging Face dataset ids; the ``train`` split of each
                is loaded. Datasets that fail to load are reported and skipped.
            max_length: Fixed token length of every returned sequence.
        """
        self.datasets = []
        for name in dataset_names:
            try:
                dataset = load_dataset(name, split="train")
            except Exception as e:
                # Best effort: an unavailable dataset is skipped, not fatal.
                print(f"Failed to load dataset {name}: {e}")
                continue
            self.datasets.append(dataset)
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.max_length = max_length

    def __len__(self):
        return sum(len(d) for d in self.datasets)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` as 1-D tensors.

        ``labels`` is an independent copy of ``input_ids``; any shift for
        next-token prediction is left to the training loop.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        item = self._locate(idx)
        dialog = self._extract_dialog(item)
        context = " [SEP] ".join(dialog[:-1])
        # A row with no text yields an example without text tokens instead of
        # crashing the DataLoader on dialog[-1].
        response = dialog[-1] if dialog else ""
        inputs = self.tokenizer(
            context,
            text_pair=response,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )
        input_ids = inputs['input_ids'].flatten()
        return {
            'input_ids': input_ids,
            'attention_mask': inputs['attention_mask'].flatten(),
            # clone(): flatten() shares storage with input_ids, so editing
            # labels in place would otherwise corrupt the inputs.
            'labels': input_ids.clone(),
        }

    def _locate(self, idx):
        """Return the row at global index ``idx``; negative indices count from the end."""
        total = len(self)
        position = idx + total if idx < 0 else idx
        if not 0 <= position < total:
            raise IndexError(f"index {idx} is out of range for a dataset of {total} items")
        for dataset in self.datasets:
            if position < len(dataset):
                return dataset[position]
            position -= len(dataset)
        raise IndexError(f"index {idx} is out of range for a dataset of {total} items")

    @staticmethod
    def _extract_dialog(item):
        """Turn one dataset row into a list of text turns.

        ``dialog`` / ``messages`` entries may be plain strings or dicts with a
        ``text`` key (conv_ai_2 stores dict turns under ``dialog``). Rows with
        neither key fall back to all of their string-valued fields.
        """
        if 'dialog' in item:
            turns = item['dialog']
        elif 'messages' in item:
            turns = item['messages']
        else:
            return [v for v in item.values() if isinstance(v, str)]
        return [turn.get('text', '') if isinstance(turn, dict) else str(turn) for turn in turns]
class SimpleTransformerModel(nn.Module):
    """Transformer encoder over token ids with a vocabulary-sized output head.

    Inputs are batch-first ``(batch, seq)`` token ids. Outputs are
    ``(batch, seq, vocab_size)`` logits.
    """

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3, causal=True):
        """
        Args:
            vocab_size: Size of the token vocabulary (embedding rows / logits).
            d_model: Hidden size.
            nhead: Attention heads per layer.
            num_layers: Number of encoder layers.
            causal: If True (the default, needed for autoregressive text
                generation), position t may only attend to positions <= t.
        """
        super().__init__()
        self.causal = causal
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        # batch_first=True: the embeddings (and the positional encoding) are
        # (batch, seq, d_model). The default seq-first layout would make
        # attention mix tokens from different examples in the batch.
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, seq) token ids.
            mask: optional (batch, seq) attention mask in Hugging Face
                convention: nonzero/True marks real tokens, 0/False padding.

        Returns:
            (batch, seq, vocab_size) logits.
        """
        seq_len = x.size(1)
        hidden = self.pos_encoder(self.embedding(x))
        # nn.TransformerEncoder expects True at positions to IGNORE, i.e. the
        # inverse of the Hugging Face mask. The per-token mask belongs in
        # src_key_padding_mask; the positional `mask` argument is (seq, seq).
        padding_mask = None if mask is None else (mask == 0)
        causal_mask = None
        if self.causal:
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            )
        hidden = self.transformer(hidden, mask=causal_mask, src_key_padding_mask=padding_mask)
        return self.fc(hidden)
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to batch-first embeddings.

    ``forward`` takes ``x`` of shape (batch, seq, d_model) and returns a tensor
    of the same shape and dtype.
    """

    def __init__(self, d_model, max_len=500):
        super().__init__()
        # Precomputed table covering the common case. Longer inputs get a
        # table built on the fly in forward().
        self.register_buffer('pe', self._build_table(max_len, d_model))

    @staticmethod
    def _build_table(length, d_model, device=None):
        """Return a (length, d_model) table: sin on even columns, cos on odd columns."""
        position = torch.arange(length, dtype=torch.float32, device=device).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32, device=device)
            * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(length, d_model, device=device)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe

    def forward(self, x):
        seq_len = x.size(1)
        if seq_len <= self.pe.size(0):
            pe = self.pe[:seq_len]
        else:
            pe = self._build_table(seq_len, self.pe.size(1), device=x.device)
        # Match x's dtype so half-precision activations are not upcast.
        return x + pe.to(dtype=x.dtype)
class VoiceInterface:
    """Speech input through ``speech_recognition`` and speech output through ``pyttsx3``."""

    def __init__(self):
        self.recognizer = sr.Recognizer()
        self.engine = pyttsx3.init()

    def listen(self) -> Union[str, None]:
        """Record one utterance from the microphone and transcribe it.

        Returns:
            The text from ``recognize_google``, or ``None`` if recognition
            raised any exception (the error is printed).
        """
        with sr.Microphone() as mic:
            print("Listening...")
            recording = self.recognizer.listen(mic)
        try:
            transcript = self.recognizer.recognize_google(recording)
            print(f"You said: {transcript}")
        except Exception as err:
            print(f"Error recognizing speech: {err}")
            return None
        return transcript

    def speak(self, text: str):
        """Print *text* prefixed with ``Bot: ``, then queue it on the TTS engine and run it."""
        print(f"Bot: {text}")
        engine = self.engine
        engine.say(text)
        engine.runAndWait()
class ChatBot:
    """Chatbot tying together the chat dataset/tokenizer, the Transformer model,
    voice I/O and optional web search."""

    # Whole words that mark a prompt as a question worth a web lookup.
    _QUESTION_WORDS = frozenset({'what', 'when', 'where', 'who', 'why', 'how', 'which'})

    def __init__(self):
        self.dataset = FullChatDataset()
        self.model = SimpleTransformerModel(len(self.dataset.tokenizer))
        self.voice_interface = VoiceInterface()
        self.web_searcher = WebSearchWrapper()

    def train(self, epochs=3, lr=3e-4):
        """Train the model on ``self.dataset`` with a next-token cross-entropy loss.

        Args:
            epochs: Number of passes over the dataset.
            lr: Adam learning rate.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(device)
        pad_id = self.dataset.tokenizer.pad_token_id
        # Padding carries no signal. -100 is CrossEntropyLoss's own default
        # ignore_index, used only if the tokenizer has no pad token.
        criterion = nn.CrossEntropyLoss(ignore_index=pad_id if pad_id is not None else -100)
        optimizer = optim.Adam(self.model.parameters(), lr=lr)
        dataloader = DataLoader(self.dataset, batch_size=8, shuffle=True)
        for epoch in range(epochs):
            self.model.train()
            total_loss = 0.0
            pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
            for batch in pbar:
                inputs = batch['input_ids'].to(device)
                masks = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                optimizer.zero_grad()
                outputs = self.model(inputs, masks)
                # Shift by one: logits at position t are scored against token
                # t+1. Unshifted, the target is the input token itself and the
                # model only learns to copy its input.
                logits = outputs[:, :-1, :]
                targets = labels[:, 1:]
                # reshape (not view): the slices above are non-contiguous.
                loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                pbar.set_postfix({'loss': loss.item()})
            print(f"Epoch {epoch+1} - Avg loss: {total_loss/len(dataloader):.4f}")

    def generate_response(self, prompt: str, max_length: int = 100, use_web: bool = True, *,
                          top_k: int = 50, top_p: float = 0.95, temperature: float = 0.7) -> str:
        """Generate a reply to ``prompt`` by autoregressive sampling.

        ``self.model`` is a plain ``nn.Module`` without a ``generate`` method,
        so the decoding loop is implemented here.

        Args:
            prompt: User message.
            max_length: Maximum number of new tokens to generate.
            use_web: If True and the prompt looks like a question, prepend up
                to three web-search results to the prompt as context.
            top_k: Keep only the ``top_k`` most likely tokens (0 disables).
            top_p: Nucleus probability mass to keep (1.0 disables).
            temperature: Softmax temperature; ``<= 0`` means greedy decoding.

        Returns:
            The decoded generated text, without the prompt or special tokens.
        """
        device = next(self.model.parameters()).device
        self.model.eval()
        # Add web context if needed
        if use_web and self._needs_web_search(prompt):
            web_results = self.web_searcher.search(prompt)
            if web_results:
                prompt = f"Web context: {', '.join(web_results[:3])}. User question: {prompt}"
        tokenizer = self.dataset.tokenizer
        # Feed back only the most recent `window` tokens (the training length).
        # This keeps the end of the prompt, where the user question sits, in
        # view even when the web context is long. No padding: new tokens must
        # follow the real prompt tokens directly.
        window = max(1, self.dataset.max_length)
        ids = tokenizer(prompt, return_tensors="pt")['input_ids'].to(device)
        prompt_len = ids.size(1)
        stop_ids = {tokenizer.sep_token_id, tokenizer.pad_token_id} - {None}
        with torch.no_grad():
            for _ in range(max_length):
                logits = self.model(ids[:, -window:])[:, -1, :]
                next_id = self._sample_next_token(
                    logits, top_k=top_k, top_p=top_p, temperature=temperature
                )
                if next_id.item() in stop_ids:
                    break
                ids = torch.cat([ids, next_id.unsqueeze(-1)], dim=1)
        return tokenizer.decode(ids[0, prompt_len:].tolist(), skip_special_tokens=True)

    @staticmethod
    def _sample_next_token(logits, top_k=50, top_p=0.95, temperature=0.7):
        """Pick one token id per row of ``logits`` (shape ``(batch, vocab)``).

        Applies temperature, then top-k filtering, then top-p (nucleus)
        filtering, and samples from what remains. ``temperature <= 0`` selects
        the argmax. Returns a ``(batch,)`` integer tensor.
        """
        if temperature <= 0:
            return logits.argmax(dim=-1)
        logits = logits / temperature
        if top_k > 0:
            k = min(top_k, logits.size(-1))
            kth_best = torch.topk(logits, k, dim=-1).values[..., -1:]
            logits = logits.masked_fill(logits < kth_best, float('-inf'))
        if 0.0 < top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
            sorted_probs = torch.softmax(sorted_logits, dim=-1)
            # Drop a token once the mass ranked above it already exceeds
            # top_p, so the most likely token is always kept.
            mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
            sorted_logits = sorted_logits.masked_fill(mass_before > top_p, float('-inf'))
            logits = logits.scatter(-1, sorted_idx, sorted_logits)
        probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)

    def _needs_web_search(self, text: str) -> bool:
        """Heuristically decide whether ``text`` is a question worth a web lookup.

        True if the text contains ``?`` or one of the question words as a whole
        word. Substring matching would make "show" match "how" and "whole"
        match "who".
        """
        lowered = text.lower()
        if '?' in lowered:
            return True
        return not self._QUESTION_WORDS.isdisjoint(re.findall(r"[a-z]+", lowered))