|
|
"""
|
|
|
Step 4: Test your trained multilingual model
|
|
|
"""
|
|
|
|
|
|
import torch
|
|
|
from transformers import GPT2LMHeadModel
|
|
|
import sentencepiece as spm
|
|
|
import os
|
|
|
from pathlib import Path
|
|
|
|
|
|
class MultilingualModel:
    """Convenience wrapper around a trained multilingual GPT-2 checkpoint.

    Loads a SentencePiece tokenizer plus a GPT2LMHeadModel from disk and
    exposes text generation, batch generation, perplexity scoring, and an
    interactive REPL.
    """

    # Characters used for cheap script detection when the prompt carries no
    # explicit language tag (Devanagari for Hindi, Gurmukhi for Punjabi).
    _HINDI_CHARS = 'अआइईउऊएऐओऔकखगघचछजझटठडढणतथदधनपफबभमयरलवशषसह'
    _PUNJABI_CHARS = 'ਅਆਇਈਉਊਏਐਓਔਕਖਗਘਚਛਜਝਟਠਡਢਣਤਥਦਧਨਪਫਬਭਮਯਰਲਵਸ਼ਸਹ'

    def __init__(self, model_path="./checkpoints_tiny/final"):
        """Load the tokenizer and model.

        Args:
            model_path: Checkpoint directory. If it does not exist, the most
                recently modified checkpoint directory under
                ./checkpoints_tiny is used instead.

        Raises:
            FileNotFoundError: If no checkpoint directory can be found.
        """
        print("="*60)
        print("LOADING MULTILINGUAL MODEL")
        print("="*60)

        if not os.path.exists(model_path):
            print(f"❌ Model not found at: {model_path}")
            print("Available checkpoints:")
            base = Path("./checkpoints_tiny")
            candidates = []
            for pattern in ("checkpoint-*", "step*", "final"):
                candidates.extend(base.glob(pattern))
            # Only directories can be loaded by from_pretrained(); the old
            # code printed directories but could still *select* a plain file.
            checkpoints = [cp for cp in candidates if cp.is_dir()]

            for cp in checkpoints:
                print(f" - {cp}")

            if checkpoints:
                # glob() order is arbitrary, so "checkpoints[-1]" was
                # non-deterministic; pick the newest by modification time.
                model_path = str(max(checkpoints, key=lambda p: p.stat().st_mtime))
                print(f"Using: {model_path}")
            else:
                raise FileNotFoundError("No checkpoints found!")

        # Prefer the tokenizer saved inside the checkpoint; fall back to the
        # corpus-level SentencePiece model.
        tokenizer_path = os.path.join(model_path, "tokenizer", "spiece.model")
        if not os.path.exists(tokenizer_path):
            tokenizer_path = "./final_corpus/multilingual_spm.model"

        print(f"Loading tokenizer from: {tokenizer_path}")
        self.tokenizer = spm.SentencePieceProcessor()
        self.tokenizer.load(tokenizer_path)

        print(f"Loading model from: {model_path}")
        self.model = GPT2LMHeadModel.from_pretrained(model_path)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()  # inference only: disables dropout

        print(f"✅ Model loaded on: {self.device}")
        print(f" Parameters: {sum(p.numel() for p in self.model.parameters())/1e6:.1f}M")
        print("="*60)

    def generate(self, prompt, max_length=100, temperature=0.7, top_k=50, top_p=0.95):
        """Generate a continuation for *prompt*.

        If the prompt carries no explicit [EN]/[HI]/[PA] tag, one is
        prepended based on a simple script-detection heuristic.

        Args:
            prompt: Input text, optionally prefixed with a language tag.
            max_length: Total sequence length cap (prompt tokens included).
            temperature: Sampling temperature.
            top_k: Top-k sampling cutoff.
            top_p: Nucleus sampling cutoff.

        Returns:
            The generated text with the echoed prompt stripped when possible.
        """
        if not any(prompt.startswith(tag) for tag in ['[EN]', '[HI]', '[PA]']):
            if any(char in prompt for char in self._HINDI_CHARS):
                prompt = f"[HI] {prompt}"
            elif any(char in prompt for char in self._PUNJABI_CHARS):
                prompt = f"[PA] {prompt}"
            else:
                prompt = f"[EN] {prompt}"

        input_ids = self.tokenizer.encode(prompt)
        input_tensor = torch.tensor([input_ids], device=self.device)

        with torch.no_grad():
            output = self.model.generate(
                input_ids=input_tensor,
                # All prompt tokens are real, so an all-ones mask is correct;
                # passing it explicitly avoids the HF "attention mask not
                # set" warning and ambiguity with pad_token_id.
                attention_mask=torch.ones_like(input_tensor),
                max_length=max_length,
                temperature=temperature,
                do_sample=True,
                top_k=top_k,
                top_p=top_p,
                # SentencePiece returns -1 for unset special ids.
                pad_token_id=self.tokenizer.pad_id() if self.tokenizer.pad_id() > 0 else 0,
                eos_token_id=self.tokenizer.eos_id() if self.tokenizer.eos_id() > 0 else 2,
                repetition_penalty=1.1,
            )

        generated = self.tokenizer.decode(output[0].tolist())

        # Strip the echoed prompt so only new text is returned; decoding may
        # normalize whitespace, in which case the full text is returned.
        if generated.startswith(prompt):
            result = generated[len(prompt):].strip()
        else:
            result = generated

        return result

    def batch_generate(self, prompts, **kwargs):
        """Generate a continuation for each prompt in *prompts*.

        Keyword arguments are forwarded to :meth:`generate`.
        """
        results = []
        for prompt in prompts:
            result = self.generate(prompt, **kwargs)
            results.append(result)
        return results

    def calculate_perplexity(self, text):
        """Return the model's perplexity on *text*.

        Returns ``inf`` for texts of fewer than two tokens, since a
        next-token loss is undefined there.
        """
        input_ids = self.tokenizer.encode(text)
        if len(input_ids) < 2:
            return float('inf')

        input_tensor = torch.tensor([input_ids], device=self.device)

        with torch.no_grad():
            # Labels == inputs gives the standard shifted LM loss.
            outputs = self.model(input_ids=input_tensor, labels=input_tensor)
            loss = outputs.loss

        perplexity = torch.exp(loss).item()
        return perplexity

    def interactive_mode(self):
        """Interactive chat loop on stdin/stdout.

        Reads prompts until /quit or Ctrl-C; supports /temp, /len, /help.
        """
        print("\n" + "="*60)
        print("INTERACTIVE MODE")
        print("="*60)
        print("Enter prompts in any language (add [EN], [HI], [PA] tags)")
        print("Commands: /temp X, /len X, /quit, /help")
        print("="*60)

        temperature = 0.7
        max_length = 100

        while True:
            try:
                user_input = input("\nYou: ").strip()

                if not user_input:
                    continue

                if user_input.startswith('/'):
                    if user_input == '/quit':
                        break
                    elif user_input == '/help':
                        print("Commands:")
                        print(" /temp X - Set temperature (0.1 to 2.0)")
                        print(" /len X - Set max length (20 to 500)")
                        print(" /quit - Exit")
                        print(" /help - Show this help")
                        continue
                    elif user_input.startswith('/temp'):
                        try:
                            temp = float(user_input.split()[1])
                            if 0.1 <= temp <= 2.0:
                                temperature = temp
                                print(f"Temperature set to {temperature}")
                            else:
                                print("Temperature must be between 0.1 and 2.0")
                        except (IndexError, ValueError):
                            # Narrowed from a bare `except:` so Ctrl-C and
                            # SystemExit are no longer swallowed here.
                            print("Usage: /temp 0.7")
                        continue
                    elif user_input.startswith('/len'):
                        try:
                            length = int(user_input.split()[1])
                            if 20 <= length <= 500:
                                max_length = length
                                print(f"Max length set to {max_length}")
                            else:
                                print("Length must be between 20 and 500")
                        except (IndexError, ValueError):
                            print("Usage: /len 100")
                        continue
                    else:
                        # Previously unknown commands fell through and were
                        # fed to the model as generation prompts.
                        print(f"Unknown command: {user_input} (try /help)")
                        continue

                print("Model: ", end="", flush=True)
                response = self.generate(user_input, max_length=max_length, temperature=temperature)
                print(response)

            except KeyboardInterrupt:
                print("\n\nExiting...")
                break
            except Exception as e:
                print(f"Error: {e}")
|
|
|
|
|
|
def run_tests():
    """Run the comprehensive multilingual generation test suites.

    Loads the model once, then generates a short continuation for every
    prompt across English, Hindi, Punjabi, language-switching, and
    code-mixing suites, printing a perplexity score when it can be computed.
    """
    print("\n" + "="*60)
    print("COMPREHENSIVE MODEL TESTS")
    print("="*60)

    model = MultilingualModel()

    test_suites = {
        "English": [
            "[EN] The weather today is",
            "[EN] I want to learn",
            "[EN] Artificial intelligence",
            "[EN] The capital of India is",
            "[EN] Once upon a time",
        ],
        "Hindi": [
            "[HI] आज का मौसम",
            "[HI] मैं सीखना चाहता हूं",
            "[HI] कृत्रिम बुद्धिमत्ता",
            "[HI] भारत की राजधानी है",
            "[HI] एक बार की बात है",
        ],
        "Punjabi": [
            "[PA] ਅੱਜ ਦਾ ਮੌਸਮ",
            "[PA] ਮੈਂ ਸਿੱਖਣਾ ਚਾਹੁੰਦਾ ਹਾਂ",
            "[PA] ਕ੍ਰਿਤਰਿਮ ਬੁੱਧੀ",
            "[PA] ਭਾਰਤ ਦੀ ਰਾਜਧਾਨੀ ਹੈ",
            "[PA] ਇੱਕ ਵਾਰ ਦੀ ਗੱਲ ਹੈ",
        ],
        "Language Switching": [
            "[EN] Hello [HI] नमस्ते",
            "[HI] यह अच्छा है [EN] this is good",
            "[PA] ਸਤਿ ਸ੍ਰੀ ਅਕਾਲ [EN] Hello everyone",
        ],
        # No tags here on purpose: exercises the script-detection heuristic.
        "Code Mixing": [
            "Hello दुनिया",
            "मेरा name है",
            "Today मौसम is good",
        ]
    }

    for suite_name, prompts in test_suites.items():
        print(f"\n{'='*40}")
        print(f"{suite_name.upper()} TESTS")
        print('='*40)

        for i, prompt in enumerate(prompts):
            print(f"\nTest {i+1}:")
            print(f"Prompt: {prompt}")

            response = model.generate(prompt, max_length=50, temperature=0.7)
            print(f"Response: {response}")

            # Perplexity is best-effort diagnostics; a failure here should
            # not abort the whole run. Narrowed from a bare `except:` so
            # KeyboardInterrupt still stops the tests.
            try:
                perplexity = model.calculate_perplexity(response)
                print(f"Perplexity: {perplexity:.2f}")
            except Exception:
                pass

            print("-" * 40)
|
|
|
|
|
|
def benchmark_model():
    """Benchmark raw generation latency and (if on CUDA) memory usage."""
    print("\n" + "="*60)
    print("MODEL BENCHMARK")
    print("="*60)

    model = MultilingualModel()

    import time

    test_prompt = "[EN] The quick brown fox"

    # Warm-up run so one-time costs (CUDA context creation, kernel/weight
    # caches) do not inflate the first measured iteration.
    model.generate(test_prompt, max_length=50)

    times = []
    for _ in range(10):
        # perf_counter() is monotonic and high-resolution; time.time() can
        # have coarse granularity and jump with system clock adjustments.
        start = time.perf_counter()
        model.generate(test_prompt, max_length=50)
        times.append(time.perf_counter() - start)

    avg_time = sum(times) / len(times)
    print(f"Average generation time (50 tokens): {avg_time:.3f}s")
    # NOTE(review): max_length=50 includes the prompt tokens, so this is an
    # upper bound on true tokens/second — confirm against output lengths.
    print(f"Tokens per second: {50/avg_time:.1f}")

    if torch.cuda.is_available():
        memory_allocated = torch.cuda.memory_allocated() / 1e9
        memory_reserved = torch.cuda.memory_reserved() / 1e9
        print(f"GPU Memory allocated: {memory_allocated:.2f} GB")
        print(f"GPU Memory reserved: {memory_reserved:.2f} GB")
|
|
|
|
|
|
def create_web_interface():
    """Write a static HTML demo page for the model to model_demo.html.

    The page is self-contained (inline CSS/JS) and POSTs JSON of the form
    {prompt, temp, maxlen} to a `/generate` endpoint, which must be served
    by a separate backend process.
    """
    page = """
<!DOCTYPE html>
<html>
<head>
    <title>Multilingual LM Demo</title>
    <style>
        body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
        .container { display: flex; flex-direction: column; gap: 20px; }
        textarea { width: 100%; height: 100px; padding: 10px; font-size: 16px; }
        button { padding: 10px 20px; background: #4CAF50; color: white; border: none; cursor: pointer; }
        button:hover { background: #45a049; }
        .output { border: 1px solid #ccc; padding: 15px; min-height: 100px; background: #f9f9f9; }
        .language-tag { display: inline-block; margin: 5px; padding: 5px 10px; background: #e0e0e0; cursor: pointer; }
    </style>
</head>
<body>
    <div class="container">
        <h1>Multilingual Language Model Demo</h1>

        <div>
            <strong>Language:</strong>
            <span class="language-tag" onclick="setLanguage('[EN] ')">English</span>
            <span class="language-tag" onclick="setLanguage('[HI] ')">Hindi</span>
            <span class="language-tag" onclick="setLanguage('[PA] ')">Punjabi</span>
        </div>

        <textarea id="prompt" placeholder="Enter your prompt here..."></textarea>

        <div>
            <label>Temperature: <input type="range" id="temp" min="0.1" max="2.0" step="0.1" value="0.7"></label>
            <label>Max Length: <input type="number" id="maxlen" min="20" max="500" value="100"></label>
        </div>

        <button onclick="generate()">Generate</button>

        <div class="output" id="output">Response will appear here...</div>
    </div>

    <script>
    function setLanguage(tag) {
        document.getElementById('prompt').value = tag;
    }

    async function generate() {
        const prompt = document.getElementById('prompt').value;
        const temp = document.getElementById('temp').value;
        const maxlen = document.getElementById('maxlen').value;

        document.getElementById('output').innerHTML = 'Generating...';

        try {
            const response = await fetch('/generate', {
                method: 'POST',
                headers: {'Content-Type': 'application/json'},
                body: JSON.stringify({prompt, temp, maxlen})
            });

            const data = await response.json();
            document.getElementById('output').innerHTML = data.response;
        } catch (error) {
            document.getElementById('output').innerHTML = 'Error: ' + error;
        }
    }
    </script>
</body>
</html>
"""

    Path("model_demo.html").write_text(page, encoding="utf-8")

    print("Web interface saved as model_demo.html")
    print("To use it, you need a backend server (see create_server.py)")
|
|
|
|
|
|
def main():
    """Show the playground menu and dispatch user choices until exit."""
    banner = "=" * 60
    print("\n" + banner)
    print("MULTILINGUAL MODEL PLAYGROUND")
    print(banner)
    print("\nOptions:")
    print("1. Interactive chat")
    print("2. Run comprehensive tests")
    print("3. Benchmark model")
    print("4. Create web interface")
    print("5. Quick generation test")
    print("6. Exit")

    # Loaded lazily: only the options that need the model trigger loading.
    model = None

    while True:
        try:
            choice = input("\nSelect option (1-6): ").strip()

            if choice == '6':
                print("Goodbye!")
                break

            if choice == '1':
                if model is None:
                    model = MultilingualModel()
                model.interactive_mode()
            elif choice == '2':
                run_tests()
            elif choice == '3':
                benchmark_model()
            elif choice == '4':
                create_web_interface()
            elif choice == '5':
                if model is None:
                    model = MultilingualModel()
                prompt = input("Enter prompt: ").strip()
                if prompt:
                    response = model.generate(prompt)
                    print(f"\nResponse: {response}")
            else:
                print("Invalid choice. Please enter 1-6.")

        except KeyboardInterrupt:
            print("\n\nExiting...")
            break
        except Exception as e:
            print(f"Error: {e}")
            import traceback
            traceback.print_exc()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
main() |