Spaces:
Runtime error
CaesarCloudSync committed on
Commit ·
8ce55a3
Parent(s):
Caesar ShowCase Start
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +3 -0
- .gitignore +26 -0
- Caesar Analysis/caesaranalysis.ipynb +0 -0
- Caesar Analysis/spectrogram0.png +0 -0
- Caesar Analysis/spectrogram1.png +0 -0
- Caesar Analysis/spectrogram2.png +0 -0
- Caesar Analysis/test/spectrogram0.png +0 -0
- Caesar Analysis/test/spectrogram1.png +0 -0
- Caesar Analysis/test/spectrogram2.png +0 -0
- CaesarAI Logo.png +0 -0
- CaesarAIAPP/.expo/README.md +15 -0
- CaesarAIAPP/App copy.js +162 -0
- CaesarAIAPP/Recording-Audio-in-React-Native-Expo +1 -0
- CaesarAIAPP/backend/CaesarAIAPI +1 -0
- CaesarAIAPP/backend/temp.3pg +0 -0
- CaesarAIAPP/frontend +1 -0
- CaesarAIAPP/frontend-main/.gitignore +14 -0
- CaesarAIAPP/frontend-main/App.js +20 -0
- CaesarAIAPP/frontend-main/assets/adaptive-icon.png +0 -0
- CaesarAIAPP/frontend-main/assets/favicon.png +0 -0
- CaesarAIAPP/frontend-main/assets/icon.png +0 -0
- CaesarAIAPP/frontend-main/assets/splash.png +0 -0
- CaesarAIAPP/frontend-main/babel.config.js +6 -0
- CaesarAIAPP/translateWithWhisper +1 -0
- CaesarAIGPT/test.ipynb +647 -0
- CaesarAINL/Procfile +1 -0
- CaesarAINL/amari@172.20.10.197/caesarReminder.py +52 -0
- CaesarAINL/amari@172.20.10.197/caesarapis/caesarReminder.py +52 -0
- CaesarAINL/app.py +29 -0
- CaesarAINL/bert tutorial/caesarbert.py +476 -0
- CaesarAINL/bert tutorial/intent_classification_with_bert.ipynb +1239 -0
- CaesarAINL/caesar_tensorflow_install.md +12 -0
- CaesarAINL/caesarapis.py +26 -0
- CaesarAINL/caesarapis/caesarReminder.py +52 -0
- CaesarAINL/caesarbackground.md +5 -0
- CaesarAINL/caesarcomplete/berttest.py +51 -0
- CaesarAINL/caesarcomplete/caesar_tensorflow_install.md +12 -0
- CaesarAINL/caesarcomplete/caesarapis.py +17 -0
- CaesarAINL/caesarcomplete/caesarnlexamples.py +69 -0
- CaesarAINL/caesarcomplete/data_aggregation.ipynb +493 -0
- CaesarAINL/caesarinfer.py +69 -0
- CaesarAINL/caesarintro.mp3 +0 -0
- CaesarAINL/caesarnl.py +70 -0
- CaesarAINL/caesarnlexamples.py +69 -0
- CaesarAINL/caesarnlrasp.py +70 -0
- CaesarAINL/caesartrain.py +224 -0
- CaesarAINL/caesartrainperformance/.png +0 -0
- CaesarAINL/caesartrainperformance/history.png +0 -0
- CaesarAINL/data_aggregation.ipynb +277 -0
- CaesarAINL/runcaesarnl.bat +1 -0
.gitattributes
ADDED
@@ -0,0 +1,3 @@
+ # Auto detect text files and perform LF normalization
+ * text=auto
+ *.h5 filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,26 @@
+ node_modules
+ *.log
+ .next
+ app
+ notionsrenv
+ package-lock.json
+ Caesar Old
+ build
+ CaesarAudioWAVs
+ *.wav
+ Miscallenouse Translate
+ Miscallenouse
+ caesaraienv
+ caesaraineuralnetenv
+ caesarsummarize.exe
+ test.exe
+ caesarai
+ dist
+ glove.6B.300d_should_be_here.txt
+ glove.6B.300d.txt
+ *.csv
+ *.json
+ *.txt
+ *.xlsx
+ *.pyc
+ caesarmodel
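The entries above mix literal directory names (e.g. node_modules, build) with glob-style wildcard patterns (*.wav, *.csv, ...). As a rough illustration only (not Git's exact semantics, which also handle directory scoping and negation), Python's fnmatch can preview which files a wildcard pattern would catch; the file names below are hypothetical:

```python
import fnmatch

# Hypothetical file list; patterns taken from the .gitignore above.
files = ["notes.txt", "clip.wav", "model.h5", "app.py", "data.csv"]
patterns = ["*.wav", "*.csv", "*.txt"]

# A file is ignored if any pattern matches its name.
ignored = [f for f in files if any(fnmatch.fnmatch(f, p) for p in patterns)]
print(ignored)  # → ['notes.txt', 'clip.wav', 'data.csv']
```

Note that model.h5 is not ignored by these patterns; per the .gitattributes above, *.h5 files are instead tracked through Git LFS.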
Caesar Analysis/caesaranalysis.ipynb
ADDED
The diff for this file is too large to render. See raw diff.

Caesar Analysis/spectrogram0.png
ADDED

Caesar Analysis/spectrogram1.png
ADDED

Caesar Analysis/spectrogram2.png
ADDED

Caesar Analysis/test/spectrogram0.png
ADDED

Caesar Analysis/test/spectrogram1.png
ADDED

Caesar Analysis/test/spectrogram2.png
ADDED

CaesarAI Logo.png
ADDED
CaesarAIAPP/.expo/README.md
ADDED
@@ -0,0 +1,15 @@
+ > Why do I have a folder named ".expo" in my project?
+
+ The ".expo" folder is created when an Expo project is started using the "expo start" command.
+
+ > What do the files contain?
+
+ - "devices.json": contains information about devices that have recently opened this project. This is used to populate the "Development sessions" list in your development builds.
+ - "packager-info.json": contains port numbers and process PIDs that are used to serve the application to the mobile device/simulator.
+ - "settings.json": contains the server configuration that is used to serve the application manifest.
+
+ > Should I commit the ".expo" folder?
+
+ No, you should not share the ".expo" folder. It does not contain any information that is relevant for other developers working on the project; it is specific to your machine.
+
+ Upon project creation, the ".expo" folder is already added to your ".gitignore" file.
CaesarAIAPP/App copy.js
ADDED
@@ -0,0 +1,162 @@
+ import * as React from 'react';
+ import { Text, View, StyleSheet, Button, TouchableHighlight, ActivityIndicator, Alert, Image } from 'react-native';
+ import { Audio } from 'expo-av';
+ import axios from 'axios';
+ import { Feather } from '@expo/vector-icons';
+ import { EvilIcons } from '@expo/vector-icons';
+ import { FontAwesome } from '@expo/vector-icons';
+
+ export default function App() {
+   const [recording, setRecording] = React.useState();
+   const [recognition, setRecognition] = React.useState();
+   const [caesarson, setCaesarOn] = React.useState(false);
+   const [recognizing, setRecognizing] = React.useState(false);
+   async function turncaesaron() {
+     setRecognizing(true)
+     const response = await axios.get("https://palondomus-caesarai.hf.space")
+
+     if (response.data === "Welcome to CaesarAI's API's and CaesarAINL.") {
+       console.log(response.data)
+       setCaesarOn(true)
+       setRecognizing(false)
+     }
+
+   }
+   const toCapitalize = (str) => {
+     return str.charAt(0).toUpperCase() + str.slice(1);
+   };
+
+
+   async function startRecording() {
+     try {
+       if (caesarson === true) {
+         console.log('Requesting permissions..');
+         await Audio.requestPermissionsAsync();
+         await Audio.setAudioModeAsync({
+           allowsRecordingIOS: true,
+           playsInSilentModeIOS: true,
+         });
+         console.log('Starting recording..');
+         const recording = new Audio.Recording();
+         await recording.prepareToRecordAsync(Audio.RECORDING_OPTIONS_PRESET_HIGH_QUALITY);
+         await recording.startAsync();
+         setRecording(recording);
+         console.log('Recording started');
+       }
+       else {
+         Alert.alert("Turn CaesarON")
+       }
+     } catch (err) {
+       console.error('Failed to start recording', err);
+     }
+   }
+
+   async function stopRecording() {
+     // utility function to convert a Blob to base64
+     const blobToBase64 = (blob) => {
+       const reader = new FileReader();
+       reader.readAsDataURL(blob);
+       return new Promise((resolve) => {
+         reader.onloadend = () => {
+           resolve(reader.result);
+         };
+       });
+     };
+     console.log('Stopping recording..');
+     setRecording(undefined);
+     await recording.stopAndUnloadAsync();
+     const uri = recording.getURI();
+     const blob = await new Promise((resolve, reject) => {
+       const xhr = new XMLHttpRequest();
+       xhr.onload = function () {
+         resolve(xhr.response);
+       };
+       xhr.onerror = function (e) {
+         reject(new TypeError("Network request failed"));
+       };
+       xhr.responseType = "blob";
+       xhr.open("GET", uri, true);
+       xhr.send(null);
+     });
+     const audioBase64 = await blobToBase64(blob);
+     setRecognizing(true)
+     const response = await axios.post("https://palondomus-caesarai.hf.space/caesarsr", {"audio_data": audioBase64})
+     //console.log(response.data)
+     setRecognition(response.data.message)
+     setRecognizing(false)
+     //console.log('Recording stopped and stored at', uri);
+     blob.close()
+
+   }
+
+   return (
+
+     <View
+       style={[
+         styles.container,
+         {
+           // Try setting `flexDirection` to `"row"`.
+           flexDirection: 'column',
+         },
+       ]}>
+       <View style={{position:"absolute",width:100,height:100}} >
+         <Image style={{position:"absolute",width:"100%",height:"100%",top:30}} source={require("./CaesarAILogo.png")}></Image>
+       </View>
+       <View style={{flex: 1}} >
+
+         <View style={{display:"flex",flexDirection:"row",justifyContent:"flex-end",alignItems:"flex-end"}}>
+
+           <TouchableHighlight style={{position:"relative",top:30,right:20}} onPress={turncaesaron}>
+             <Feather name="power" size={34} style={{color:caesarson === false ? "red":"blue"}}/>
+           </TouchableHighlight>
+           <TouchableHighlight style={{position:"relative",top:30,right:0}} >
+             <EvilIcons name="navicon" size={44} color="red" />
+           </TouchableHighlight>
+         </View>
+       </View>
+
+       <View style={{flex: 2}} >
+         <View style={{display:"flex",flexDirection:"row",justifyContent:"center",alignItems:"center"}}>
+           <View style={{position:"relative",top:30}}>
+             { recognition !== undefined &&
+               <Text style={{fontSize:30}}>{toCapitalize(recognition+".")}</Text>
+             }
+
+             {recognizing === true && <ActivityIndicator style={{marginTop:30}} size={55} /> }
+           </View>
+         </View>
+
+       </View>
+       <View style={{flex: 3}} >
+         <View style={{display:"flex",flexDirection:"row",justifyContent:"center",alignItems:"center"}}>
+           <TouchableHighlight style={{position:"relative",top:70}} onPress={recording ? stopRecording : startRecording}>
+             <FontAwesome name="microphone" size={94} style={{color:recording ? 'blue' : 'black'}} />
+           </TouchableHighlight>
+         </View>
+       </View>
+     </View>
+
+   );
+ }
+ /* <Button
+   title={recording ? 'Stop Recording' : 'Start Recording'}
+   onPress={recording ? stopRecording : startRecording}
+ /> */
+
+
+ const styles = StyleSheet.create({
+   container: {
+     flex: 1,
+     padding: 20,
+   },
+ });
+ /*
+ <View style={{"flex":""}}>
+   <Text>{recognition}</Text>
+   <TouchableHighlight onPress={turncaesaron}>
+     <Feather name="power" size={24} style={{color:caesarson === false ? "red":"blue"}}/>
+   </TouchableHighlight>
+   <Button
+     title={recording ? 'Stop Recording' : 'Start Recording'}
+     onPress={recording ? stopRecording : startRecording}
+   />
+ </View> */
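The client above posts {"audio_data": audioBase64} to the /caesarsr endpoint, where audioBase64 is a data URL produced by FileReader.readAsDataURL (i.e. a "data:<mime>;base64,<payload>" string). The commit does not include the receiving backend, so as a minimal sketch of what such an endpoint might do with the field (the helper name and its behavior are assumptions, not the actual CaesarAI code):

```python
import base64

def decode_audio_data(audio_data: str) -> bytes:
    """Hypothetical server-side helper: strip the 'data:<mime>;base64,'
    prefix if present, then decode the base64 payload to raw audio bytes."""
    if "," in audio_data:
        audio_data = audio_data.split(",", 1)[1]
    return base64.b64decode(audio_data)

# Example: a data URL carrying the bytes b"RIFF" (the start of a WAV header).
payload = "data:audio/wav;base64," + base64.b64encode(b"RIFF").decode()
print(decode_audio_data(payload))  # → b'RIFF'
```

The decoded bytes could then be written to a temporary file and handed to a speech-recognition model, which would explain the {"message": ...} shape the client reads back.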
CaesarAIAPP/Recording-Audio-in-React-Native-Expo
ADDED
@@ -0,0 +1 @@
+ Subproject commit 1cfc6b5ae2ab37cf217e0eed524894594d4ade37

CaesarAIAPP/backend/CaesarAIAPI
ADDED
@@ -0,0 +1 @@
+ Subproject commit 1f039900883be91f728313490cfe16e7165d8723

CaesarAIAPP/backend/temp.3pg
ADDED
Binary file (6.59 kB). View file.

CaesarAIAPP/frontend
ADDED
@@ -0,0 +1 @@
+ Subproject commit 13ac3048649c9a7a8520a4612e62f780ec04d403
CaesarAIAPP/frontend-main/.gitignore
ADDED
@@ -0,0 +1,14 @@
+ node_modules/
+ .expo/
+ dist/
+ npm-debug.*
+ *.jks
+ *.p8
+ *.p12
+ *.key
+ *.mobileprovision
+ *.orig.*
+ web-build/
+
+ # macOS
+ .DS_Store
CaesarAIAPP/frontend-main/App.js
ADDED
@@ -0,0 +1,20 @@
+ import { StatusBar } from 'expo-status-bar';
+ import { StyleSheet, Text, View } from 'react-native';
+
+ export default function App() {
+   return (
+     <View style={styles.container}>
+       <Text>Open up App.js to start working on your app!</Text>
+       <StatusBar style="auto" />
+     </View>
+   );
+ }
+
+ const styles = StyleSheet.create({
+   container: {
+     flex: 1,
+     backgroundColor: '#fff',
+     alignItems: 'center',
+     justifyContent: 'center',
+   },
+ });
CaesarAIAPP/frontend-main/assets/adaptive-icon.png
ADDED

CaesarAIAPP/frontend-main/assets/favicon.png
ADDED

CaesarAIAPP/frontend-main/assets/icon.png
ADDED

CaesarAIAPP/frontend-main/assets/splash.png
ADDED
CaesarAIAPP/frontend-main/babel.config.js
ADDED
@@ -0,0 +1,6 @@
+ module.exports = function(api) {
+   api.cache(true);
+   return {
+     presets: ['babel-preset-expo'],
+   };
+ };
CaesarAIAPP/translateWithWhisper
ADDED
@@ -0,0 +1 @@
+ Subproject commit 03798a52483e9a322390df7c307f880e17c7e9c7
CaesarAIGPT/test.ipynb
ADDED
|
@@ -0,0 +1,647 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"## makemore: part 5"
|
| 8 |
+
]
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"cell_type": "code",
|
| 12 |
+
"execution_count": 1,
|
| 13 |
+
"metadata": {},
|
| 14 |
+
"outputs": [],
|
| 15 |
+
"source": [
|
| 16 |
+
"import torch\n",
|
| 17 |
+
"import torch.nn.functional as F\n",
|
| 18 |
+
"import matplotlib.pyplot as plt # for making figures\n",
|
| 19 |
+
"%matplotlib inline"
|
| 20 |
+
]
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"cell_type": "code",
|
| 24 |
+
"execution_count": 2,
|
| 25 |
+
"metadata": {},
|
| 26 |
+
"outputs": [
|
| 27 |
+
{
|
| 28 |
+
"name": "stdout",
|
| 29 |
+
"output_type": "stream",
|
| 30 |
+
"text": [
|
| 31 |
+
"32033\n",
|
| 32 |
+
"15\n",
|
| 33 |
+
"['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']\n"
|
| 34 |
+
]
|
| 35 |
+
}
|
| 36 |
+
],
|
| 37 |
+
"source": [
|
| 38 |
+
"# read in all the words\n",
|
| 39 |
+
"words = open('names.txt', 'r').read().splitlines()\n",
|
| 40 |
+
"print(len(words))\n",
|
| 41 |
+
"print(max(len(w) for w in words))\n",
|
| 42 |
+
"print(words[:8])"
|
| 43 |
+
]
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"cell_type": "code",
|
| 47 |
+
"execution_count": 3,
|
| 48 |
+
"metadata": {},
|
| 49 |
+
"outputs": [
|
| 50 |
+
{
|
| 51 |
+
"name": "stdout",
|
| 52 |
+
"output_type": "stream",
|
| 53 |
+
"text": [
|
| 54 |
+
"{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}\n",
|
| 55 |
+
"27\n"
|
| 56 |
+
]
|
| 57 |
+
}
|
| 58 |
+
],
|
| 59 |
+
"source": [
|
| 60 |
+
"# build the vocabulary of characters and mappings to/from integers\n",
|
| 61 |
+
"chars = sorted(list(set(''.join(words))))\n",
|
| 62 |
+
"stoi = {s:i+1 for i,s in enumerate(chars)}\n",
|
| 63 |
+
"stoi['.'] = 0\n",
|
| 64 |
+
"itos = {i:s for s,i in stoi.items()}\n",
|
| 65 |
+
"vocab_size = len(itos)\n",
|
| 66 |
+
"print(itos)\n",
|
| 67 |
+
"print(vocab_size)"
|
| 68 |
+
]
|
| 69 |
+
},
|
| 70 |
+
{
|
| 71 |
+
"cell_type": "code",
|
| 72 |
+
"execution_count": 4,
|
| 73 |
+
"metadata": {},
|
| 74 |
+
"outputs": [],
|
| 75 |
+
"source": [
|
| 76 |
+
"# shuffle up the words\n",
|
| 77 |
+
"import random\n",
|
| 78 |
+
"random.seed(42)\n",
|
| 79 |
+
"random.shuffle(words)"
|
| 80 |
+
]
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"cell_type": "code",
|
| 84 |
+
"execution_count": 5,
|
| 85 |
+
"metadata": {},
|
| 86 |
+
"outputs": [
|
| 87 |
+
{
|
| 88 |
+
"name": "stdout",
|
| 89 |
+
"output_type": "stream",
|
| 90 |
+
"text": [
|
| 91 |
+
"torch.Size([182625, 8]) torch.Size([182625])\n",
|
| 92 |
+
"torch.Size([22655, 8]) torch.Size([22655])\n",
|
| 93 |
+
"torch.Size([22866, 8]) torch.Size([22866])\n"
|
| 94 |
+
]
|
| 95 |
+
}
|
| 96 |
+
],
|
| 97 |
+
"source": [
|
| 98 |
+
"# build the dataset\n",
|
| 99 |
+
"block_size = 8 # context length: how many characters do we take to predict the next one?\n",
|
| 100 |
+
"\n",
|
| 101 |
+
"def build_dataset(words): \n",
|
| 102 |
+
" X, Y = [], []\n",
|
| 103 |
+
" \n",
|
| 104 |
+
" for w in words:\n",
|
| 105 |
+
" context = [0] * block_size\n",
|
| 106 |
+
" for ch in w + '.':\n",
|
| 107 |
+
" ix = stoi[ch]\n",
|
| 108 |
+
" X.append(context)\n",
|
| 109 |
+
" Y.append(ix)\n",
|
| 110 |
+
" context = context[1:] + [ix] # crop and append\n",
|
| 111 |
+
"\n",
|
| 112 |
+
" X = torch.tensor(X)\n",
|
| 113 |
+
" Y = torch.tensor(Y)\n",
|
| 114 |
+
" print(X.shape, Y.shape)\n",
|
| 115 |
+
" return X, Y\n",
|
| 116 |
+
"\n",
|
| 117 |
+
"n1 = int(0.8*len(words))\n",
|
| 118 |
+
"n2 = int(0.9*len(words))\n",
|
| 119 |
+
"Xtr, Ytr = build_dataset(words[:n1]) # 80%\n",
|
| 120 |
+
"Xdev, Ydev = build_dataset(words[n1:n2]) # 10%\n",
|
| 121 |
+
"Xte, Yte = build_dataset(words[n2:]) # 10%"
|
| 122 |
+
]
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"cell_type": "code",
|
| 126 |
+
"execution_count": 6,
|
| 127 |
+
"metadata": {},
|
| 128 |
+
"outputs": [
|
| 129 |
+
{
|
| 130 |
+
"name": "stdout",
|
| 131 |
+
"output_type": "stream",
|
| 132 |
+
"text": [
|
| 133 |
+
"........ --> y\n",
|
| 134 |
+
".......y --> u\n",
|
| 135 |
+
"......yu --> h\n",
|
| 136 |
+
".....yuh --> e\n",
|
| 137 |
+
"....yuhe --> n\n",
|
| 138 |
+
"...yuhen --> g\n",
|
| 139 |
+
"..yuheng --> .\n",
|
| 140 |
+
"........ --> d\n",
|
| 141 |
+
".......d --> i\n",
|
| 142 |
+
"......di --> o\n",
|
| 143 |
+
".....dio --> n\n",
|
| 144 |
+
"....dion --> d\n",
|
| 145 |
+
"...diond --> r\n",
|
| 146 |
+
"..diondr --> e\n",
|
| 147 |
+
".diondre --> .\n",
|
| 148 |
+
"........ --> x\n",
|
| 149 |
+
".......x --> a\n",
|
| 150 |
+
"......xa --> v\n",
|
| 151 |
+
".....xav --> i\n",
|
| 152 |
+
"....xavi --> e\n"
|
| 153 |
+
]
|
| 154 |
+
}
|
| 155 |
+
],
|
| 156 |
+
"source": [
|
| 157 |
+
"for x,y in zip(Xtr[:20], Ytr[:20]):\n",
|
| 158 |
+
" print(''.join(itos[ix.item()] for ix in x), '-->', itos[y.item()])"
|
| 159 |
+
]
|
| 160 |
+
},
|
| 161 |
+
{
|
| 162 |
+
"cell_type": "code",
|
| 163 |
+
"execution_count": 7,
|
| 164 |
+
"metadata": {},
|
| 165 |
+
"outputs": [],
|
| 166 |
+
"source": [
|
| 167 |
+
"# Near copy paste of the layers we have developed in Part 3\n",
|
| 168 |
+
"\n",
|
| 169 |
+
"# -----------------------------------------------------------------------------------------------\n",
|
| 170 |
+
"class Linear:\n",
|
| 171 |
+
" \n",
|
| 172 |
+
" def __init__(self, fan_in, fan_out, bias=True):\n",
|
| 173 |
+
" self.weight = torch.randn((fan_in, fan_out)) / fan_in**0.5 # note: kaiming init\n",
|
| 174 |
+
" self.bias = torch.zeros(fan_out) if bias else None\n",
|
| 175 |
+
" \n",
|
| 176 |
+
" def __call__(self, x):\n",
|
| 177 |
+
" self.out = x @ self.weight\n",
|
| 178 |
+
" if self.bias is not None:\n",
|
| 179 |
+
" self.out += self.bias\n",
|
| 180 |
+
" return self.out\n",
|
| 181 |
+
" \n",
|
| 182 |
+
" def parameters(self):\n",
|
| 183 |
+
" return [self.weight] + ([] if self.bias is None else [self.bias])\n",
|
| 184 |
+
"\n",
|
| 185 |
+
"# -----------------------------------------------------------------------------------------------\n",
|
| 186 |
+
"class BatchNorm1d:\n",
|
| 187 |
+
" \n",
|
| 188 |
+
" def __init__(self, dim, eps=1e-5, momentum=0.1):\n",
|
| 189 |
+
" self.eps = eps\n",
|
| 190 |
+
" self.momentum = momentum\n",
|
| 191 |
+
" self.training = True\n",
|
| 192 |
+
" # parameters (trained with backprop)\n",
|
| 193 |
+
" self.gamma = torch.ones(dim)\n",
|
| 194 |
+
" self.beta = torch.zeros(dim)\n",
|
| 195 |
+
" # buffers (trained with a running 'momentum update')\n",
|
| 196 |
+
" self.running_mean = torch.zeros(dim)\n",
|
| 197 |
+
" self.running_var = torch.ones(dim)\n",
|
| 198 |
+
" \n",
|
| 199 |
+
" def __call__(self, x):\n",
|
| 200 |
+
" # calculate the forward pass\n",
|
| 201 |
+
" if self.training:\n",
|
| 202 |
+
" if x.ndim == 2:\n",
|
| 203 |
+
" dim = 0\n",
|
| 204 |
+
" elif x.ndim == 3:\n",
|
| 205 |
+
" dim = (0,1)\n",
|
| 206 |
+
" xmean = x.mean(dim, keepdim=True) # batch mean\n",
|
| 207 |
+
" xvar = x.var(dim, keepdim=True) # batch variance\n",
|
| 208 |
+
" else:\n",
|
| 209 |
+
" xmean = self.running_mean\n",
|
| 210 |
+
" xvar = self.running_var\n",
|
| 211 |
+
" xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance\n",
|
| 212 |
+
" self.out = self.gamma * xhat + self.beta\n",
|
| 213 |
+
" # update the buffers\n",
|
| 214 |
+
" if self.training:\n",
|
| 215 |
+
" with torch.no_grad():\n",
|
| 216 |
+
" self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean\n",
|
| 217 |
+
" self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar\n",
|
| 218 |
+
" return self.out\n",
|
| 219 |
+
" \n",
|
| 220 |
+
" def parameters(self):\n",
|
| 221 |
+
" return [self.gamma, self.beta]\n",
|
| 222 |
+
"\n",
|
| 223 |
+
"# -----------------------------------------------------------------------------------------------\n",
|
| 224 |
+
"class Tanh:\n",
|
| 225 |
+
" def __call__(self, x):\n",
|
| 226 |
+
" self.out = torch.tanh(x)\n",
|
| 227 |
+
" return self.out\n",
|
| 228 |
+
" def parameters(self):\n",
|
| 229 |
+
" return []\n",
|
| 230 |
+
"\n",
|
| 231 |
+
"# -----------------------------------------------------------------------------------------------\n",
|
| 232 |
+
"class Embedding:\n",
|
| 233 |
+
" \n",
|
| 234 |
+
" def __init__(self, num_embeddings, embedding_dim):\n",
|
| 235 |
+
" self.weight = torch.randn((num_embeddings, embedding_dim))\n",
|
| 236 |
+
" \n",
|
| 237 |
+
" def __call__(self, IX):\n",
|
| 238 |
+
" self.out = self.weight[IX]\n",
|
| 239 |
+
" return self.out\n",
|
| 240 |
+
" \n",
|
| 241 |
+
" def parameters(self):\n",
|
| 242 |
+
" return [self.weight]\n",
|
| 243 |
+
"\n",
|
| 244 |
+
"# -----------------------------------------------------------------------------------------------\n",
|
| 245 |
+
"class FlattenConsecutive:\n",
|
| 246 |
+
" \n",
|
| 247 |
+
" def __init__(self, n):\n",
|
| 248 |
+
" self.n = n\n",
|
| 249 |
+
" \n",
|
| 250 |
+
" def __call__(self, x):\n",
|
| 251 |
+
" B, T, C = x.shape\n",
|
| 252 |
+
" x = x.view(B, T//self.n, C*self.n)\n",
|
| 253 |
+
" if x.shape[1] == 1:\n",
|
| 254 |
+
" x = x.squeeze(1)\n",
|
| 255 |
+
" self.out = x\n",
|
| 256 |
+
" return self.out\n",
|
| 257 |
+
" \n",
|
| 258 |
+
" def parameters(self):\n",
|
| 259 |
+
" return []\n",
|
| 260 |
+
"\n",
|
| 261 |
+
"# -----------------------------------------------------------------------------------------------\n",
|
| 262 |
+
"class Sequential:\n",
|
| 263 |
+
" \n",
|
| 264 |
+
" def __init__(self, layers):\n",
|
| 265 |
+
" self.layers = layers\n",
|
| 266 |
+
" \n",
|
| 267 |
+
" def __call__(self, x):\n",
|
| 268 |
+
" for layer in self.layers:\n",
|
| 269 |
+
" x = layer(x)\n",
|
| 270 |
+
" self.out = x\n",
|
| 271 |
+
" return self.out\n",
|
| 272 |
+
" \n",
|
| 273 |
+
" def parameters(self):\n",
|
| 274 |
+
" # get parameters of all layers and stretch them out into one list\n",
|
| 275 |
+
" return [p for layer in self.layers for p in layer.parameters()]\n"
|
| 276 |
+
]
|
| 277 |
+
},
|
| 278 |
+
{
|
| 279 |
+
"cell_type": "code",
|
| 280 |
+
"execution_count": 8,
|
| 281 |
+
"metadata": {},
|
| 282 |
+
"outputs": [],
|
| 283 |
+
"source": [
|
| 284 |
+
"torch.manual_seed(42); # seed rng for reproducibility"
|
| 285 |
+
]
|
| 286 |
+
},
|
| 287 |
+
{
|
| 288 |
+
"cell_type": "code",
|
| 289 |
+
"execution_count": 9,
|
| 290 |
+
"metadata": {},
|
| 291 |
+
"outputs": [
|
| 292 |
+
{
|
| 293 |
+
"name": "stdout",
|
| 294 |
+
"output_type": "stream",
|
| 295 |
+
"text": [
|
| 296 |
+
"76579\n"
|
| 297 |
+
]
|
| 298 |
+
}
|
| 299 |
+
],
|
| 300 |
+
"source": [
|
| 301 |
+
"# original network\n",
|
| 302 |
+
"# n_embd = 10 # the dimensionality of the character embedding vectors\n",
|
| 303 |
+
"# n_hidden = 300 # the number of neurons in the hidden layer of the MLP\n",
|
| 304 |
+
"# model = Sequential([\n",
|
| 305 |
+
"# Embedding(vocab_size, n_embd),\n",
|
| 306 |
+
"# FlattenConsecutive(8), Linear(n_embd * 8, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),\n",
|
| 307 |
+
"# Linear(n_hidden, vocab_size),\n",
"# ])\n",
"\n",
"# hierarchical network\n",
"n_embd = 24 # the dimensionality of the character embedding vectors\n",
"n_hidden = 128 # the number of neurons in the hidden layer of the MLP\n",
"model = Sequential([\n",
" Embedding(vocab_size, n_embd),\n",
" FlattenConsecutive(2), Linear(n_embd * 2, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),\n",
" FlattenConsecutive(2), Linear(n_hidden*2, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),\n",
" FlattenConsecutive(2), Linear(n_hidden*2, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),\n",
" Linear(n_hidden, vocab_size),\n",
"])\n",
"\n",
"# parameter init\n",
"with torch.no_grad():\n",
" model.layers[-1].weight *= 0.1 # last layer make less confident\n",
"\n",
"parameters = model.parameters()\n",
"print(sum(p.nelement() for p in parameters)) # number of parameters in total\n",
"for p in parameters:\n",
" p.requires_grad = True"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 0/ 200000: 3.3167\n",
" 10000/ 200000: 2.0576\n",
" 20000/ 200000: 2.0723\n",
" 30000/ 200000: 2.5134\n",
" 40000/ 200000: 2.1476\n",
" 50000/ 200000: 1.7836\n",
" 60000/ 200000: 2.2592\n",
" 70000/ 200000: 1.9331\n",
" 80000/ 200000: 1.6875\n",
" 90000/ 200000: 2.0395\n",
" 100000/ 200000: 1.7736\n",
" 110000/ 200000: 1.9570\n",
" 120000/ 200000: 1.7465\n",
" 130000/ 200000: 1.8126\n",
" 140000/ 200000: 1.7406\n",
" 150000/ 200000: 1.7466\n",
" 160000/ 200000: 1.8806\n",
" 170000/ 200000: 1.6266\n",
" 180000/ 200000: 1.6476\n",
" 190000/ 200000: 1.8555\n"
]
}
],
"source": [
"# same optimization as last time\n",
"max_steps = 200000\n",
"batch_size = 32\n",
"lossi = []\n",
"\n",
"for i in range(max_steps):\n",
" \n",
" # minibatch construct\n",
" ix = torch.randint(0, Xtr.shape[0], (batch_size,))\n",
" Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y\n",
" \n",
" # forward pass\n",
" logits = model(Xb)\n",
" loss = F.cross_entropy(logits, Yb) # loss function\n",
" \n",
" # backward pass\n",
" for p in parameters:\n",
" p.grad = None\n",
" loss.backward()\n",
" \n",
" # update: simple SGD\n",
" lr = 0.1 if i < 150000 else 0.01 # step learning rate decay\n",
" for p in parameters:\n",
" p.data += -lr * p.grad\n",
"\n",
" # track stats\n",
" if i % 10000 == 0: # print every once in a while\n",
" print(f'{i:7d}/{max_steps:7d}: {loss.item():.4f}')\n",
" lossi.append(loss.log10().item())\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7fb5a03e3b50>]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAA0n0lEQVR4nO3deXxU1f3/8ddnJnvInhBCFpIAASGEgGFVcMEF3IBqcau2amvtV1trF6u168/altbWblZqXWqrdce6L4gLIGuAsIYlLNnJvu+TOb8/ZggTSGDQJBMmn+fjwcOZM/fOfO7N+M7JufeeK8YYlFJKeS+LpwtQSinVvzTolVLKy2nQK6WUl9OgV0opL6dBr5RSXs7H0wX0JDo62iQnJ3u6DKWUOmNs3ry50hgT09NrgzLok5OTyc7O9nQZSil1xhCR/N5e06EbpZTychr0Sinl5TTolVLKy2nQK6WUl3Mr6EVkvojsFZE8Ebmvh9cXish2EckRkWwROdfltXtEZJeI7BSR50UkoC83QCml1MmdMuhFxAo8CiwAJgDXi8iE4xZbCUw2xmQCtwJPONeNB74DZBlj0gErcF2fVa+UUuqU3OnRTwfyjDEHjTHtwAvAQtcFjDGN5tg0mMGA65SYPkCgiPgAQUDJFy9bKaWUu9wJ+nig0OV5kbOtGxFZLCJ7gLdx9OoxxhQDDwMFQClQZ4z5oKcPEZHbncM+2RUVFae3FU5/WbmfT/d9vnWVUspbuRP00kPbCZPYG2NeM8aMBxYBDwKISASO3n8KMBIIFpGv9PQhxpjHjTFZxpismJgeL+46pX98eoBVGvRKKdWNO0FfBCS6PE/gJMMvxphVwGgRiQYuAg4ZYyqMMR3AcmD2F6j3pAL9fGhut/XX2yul1BnJnaDfBIwVkRQR8cNxMPUN1wVEZIyIiPPxVMAPqMIxZDNTRIKcr88DcvtyA1wF+Vlpbu/sr7dXSqkz0innujHG2ETkLuB9HGfNPGWM2SUidzhfXwZcDdwsIh1AC3Ct8+DsBhF5BdgC2ICtwOP9syka9Eop1RO3JjUzxrwDvHNc2zKXx0uBpb2s+3Pg51+gRrcF+llp0aBXSqluvOrKWEePXsfolVLKlVcFfaCvjw7dKKXUcbwq6IP8rLR0aNArpZQrrwt67dErpVR3XhX0ejBWKaVO5FVBH+y8YOrYtDtKKaW8KugD/azYDbTZ7J4uRSmlBg2vCvogPyuADt8opZQLrwz6Zj3zRimlunhV0Af6OS70bdGLppRSqotXBX2Qr7NHr0M3SinVxbuC3k+DXimljudVQR/YFfQ6dKOUUkd5VdAHOcfotUevlFLHeFnQ69CNUkodz6uCPlDPo1dKqRN4VdBrj14ppU7kVUEf4HO0R68HY5VS6iivCnqLRQj01amKlVLKlVcFPUCwv1WnQFBKKRdeF/Q6J71SSnXnVtCLyHwR2SsieSJyXw+vLxSR7SKSIyLZInKuy2vhIvKKiOwRkVwRmdWXG3C8IF8fvWBKKaVc+JxqARGxAo8CFwNFwCYRecMYs9tlsZXAG8YYIyIZwEvAeOdrfwbeM8ZcIyJ+QFCfbsFxAvV2gkop1Y07PfrpQJ4x5qAxph14AVjouoAxptEcu61TMGAARCQUmAs86Vyu3RhT20e19yhIh26UUqobd4I+Hih0eV7kbOtGRBaLyB7gbeBWZ3MqUAE8LSJbReQJEQnu6UNE5HbnsE92RUXFaW2EK71BuFJKdedO0EsPbSfclNUY85oxZjywCHjQ2ewDTAUeM8ZMAZqAE8b4nes/bozJMsZkxcTEuFN7jwL9fGjRs26UUqqLO0FfBCS6PE8ASnpb2BizChgtItHOdYuMMRucL7+CI/j7TZCvVQ/GKqWUC3eCfhMwVkRSnAdTrwPecF1ARMaIiDgfTwX8gCpjzBGgUETGORedB7gexO1zgX5Wmtu0R6+UUked8qwbY4xNRO4C3geswFPGmF0ico
fz9WXA1cDNItIBtADXuhyc/TbwnPOXxEHgln7Yji5Bfo4LpowxOH/3KKXUkHbKoAcwxrwDvHNc2zKXx0uBpb2smwNkff4ST0+Qn5VOu6G9046/c+4bpZQayrzuythgf8fvrsZWHadXSinwwqCPCwsEoKS21cOVKKXU4OB1QZ8U6bjwtrCm2cOVKKXU4OB1QZ8Y6ejRF1Rr0CulFHhh0IcE+BIR5KtBr5RSTl4X9OAYvinUoFdKKcBLgz5Bg14ppbp4ZdAnRQZRVNNCp/2EKXmUUmrI8dqgt9kNpXUtni5FKaU8ziuDPjHCeYpltQa9Ukp5ZdB3nUuv4/RKKeWdQR8XHoDVInqKpVJK4aVB72u1kBQZxL6yBk+XopRSHueVQQ+QkRDGtqJaT5ehlFIe57VBn5kYTll9G0fqdHIzpdTQ5rVBPzkxHICcwlqP1qGUUp7mtUE/IS4UX6to0CulhjyvDfoAXysT4kLZpkGvlBrivDbowTF8s72oVqdCUEoNaV4d9FnJkTS1d7Imr9LTpSillMd4ddBfOjGWkWEB/GXlfozRXr1Samjy6qD397HyrQvGsDm/hs/yqjxdjlJKeYRbQS8i80Vkr4jkich9Pby+UES2i0iOiGSLyLnHvW4Vka0i8lZfFe6uJVkJDA/x59/rDg/0Ryul1KBwyqAXESvwKLAAmABcLyITjltsJTDZGJMJ3Ao8cdzrdwO5X7jaz8Hfx8qcsTFk59fo8I1Sakhyp0c/Hcgzxhw0xrQDLwALXRcwxjSaYykaDHQlqogkAJdzYvgPmGnJEVQ3tXOwsslTJSillMe4E/TxQKHL8yJnWzcislhE9gBv4+jVH/Un4F7AfrIPEZHbncM+2RUVFW6U5b6s5AgANh+u6dP3VUqpM4E7QS89tJ0wBmKMec0YMx5YBDwIICJXAOXGmM2n+hBjzOPGmCxjTFZMTIwbZbkvNXoY4UG+ZOdX9+n7KqXUmcCdoC8CEl2eJwAlvS1sjFkFjBaRaOAc4CoROYxjyOdCEXn285f7+VgsQtaoCLK1R6+UGoLcCfpNwFgRSRERP+A64A3XBURkjIiI8/FUwA+oMsbcb4xJMMYkO9f7yBjzlT7dAjedPSqSg5VNVDW2eeLjlVLKY04Z9MYYG3AX8D6OM2deMsbsEpE7ROQO52JXAztFJAfHGTrXmkF2ikumczbLXSX1ni1EKaUGmI87Cxlj3gHeOa5tmcvjpcDSU7zHJ8Anp11hH0mLHQbAvrIG5qb17TEApZQazLz6ylhXUcP8iQr2Y39Zo6dLUUqpATVkgh5gbOww9pXrfWSVUkPLkAr6tNgQ8soa9QpZpdSQMqSCfmxsCA1tNkr1PrJKqSFkSAV92vBjB2SVUmqoGFpBHxsCoAdklVJDypAK+ohgP6KH+ZOdX41dby+olBoihlTQA8xPj+X9XWUs+cc6Gttsni5HKaX63ZAL+gcXpvPgwolk59fw0Z5yT5ejlFL9bsgFvYhw3fQk/Hws7Ciq9XQ5SinV74Zc0AP4Wi1MiAtle1Gdp0tRSql+NySDHiAjIYydxXV06kFZpZSXG7JBPyk+jKb2Tg5V6qmWSinvNmSDfrJz2mIdvlFKebshG/SjY4YR6GvVoFdKeb0hG/RWizA5MYzXc4pZsbvM0+UopVS/GbJBD/CrRemMCAvkG//OZtW+Ck+Xo5RS/WJIB/2Y4SH8787ZRAb78WJ2oafLUUqpfjGkgx7A38fKlRlxfLi7jPrWDk+Xo5RSfW7IBz3AoinxtNnsvLfziKdLUUqpPqdBD2QmhpMcFcQ/Pj3A5vwaT5ejlFJ9yq2gF5H5IrJXRPJE5L4eXl8oIttFJEdEskXkXGd7ooh8LCK5IrJLRO7u6w3oCyLCT6+YQF1LB1c/tpZ3dpR6uiSllOozpwx6EbECjwILgAnA9SIy4bjFVgKTjTGZwK3AE852G/B9Y8xZwEzgzh7WHRTmnRXLqnsvYMzwYfztozy9r6xSymu406OfDu
QZYw4aY9qBF4CFrgsYYxrNsWQMBoyzvdQYs8X5uAHIBeL7qvi+FuTnw+1zUtldWs/aA1UANLR2UFjd7OHKlFLq83Mn6OMB13MPi+ghrEVksYjsAd7G0as//vVkYAqwoacPEZHbncM+2RUVnjunfeGUkUQP8+c37+by0qZCLnlkFZf9eTVttk6P1aSUUl+EO0EvPbSdMK5hjHnNGDMeWAQ82O0NRIYBrwLfNcbU9/QhxpjHjTFZxpismJgYN8rqH/4+Vn5y+VnkVzZz76vbaWy10dBm06kSlFJnLB83likCEl2eJwAlvS1sjFklIqNFJNoYUykivjhC/jljzPIvVu7AWDQlnvnpI9heVMfI8ADOXfoxGw5WkRIdzLbCWuadFevpEpVSym3uBP0mYKyIpADFwHXADa4LiMgY4IAxxojIVMAPqBIRAZ4Eco0xf+zb0vtXgK+V6SmRAIwfEcKGQ9XsLWvkre0l5PzsEsICfT1coVJKueeUQW+MsYnIXcD7gBV4yhizS0TucL6+DLgauFlEOoAW4Fpn6J8L3ATsEJEc51v+2BjzTj9sS7+ZkRLJC5sKsdkNxkBeeQNnj4r0dFlKKeUWd3r0OIP5nePalrk8Xgos7WG9NfQ8xn9GmZ4SxTPr8rue7ytr1KBXSp0x9MpYN8xIjUQEFmaOJNDXyr6yBk+XpJRSbnOrRz/URQ/z55lbpjMpPoxDlU3sL9PbDyqlzhzao3fT3LQYIoL9GDs8RHv0Sqkzigb9aUqLHUZ5Qxub82v444p9dHTaPV2SUkqdlA7dnKa02BAAbnl6I/WtNgC+d3GaJ0tSSqmT0h79aRobOwyA+lYbGQlhPPpxHpvzq2nt6OSXb+7i9ZxiD1eolFLdaY/+NMWHBxIb6s95aTE8cPkELvvzaq57fD3JUcHsL29k4shQFmYO2nnblFJDkPboT5OIsPL75/PbL2UQFujLG3edw4L0OAprmpmREkluaT0NrR3c9d8tXP6X1azNq/R0yUqpIU6D/nMY5u+DxeK4DixqmD9/uX4KO35xKd++cCx2Ax/mlvHOjlL2lTVwwxMbWH+wysMVK6WGMg36PuJrtTAlKRyrRXj4/X3YDbxw+0wignz512eHPV2eUmoI06DvQ8H+PkwcGUpxbQtJkUFMTYpgSVYiK3LLOFLX6unylFJDlAZ9H8tyzoGzIH0EIsINM5LotBte2FTg4cqUUkOVBn0fm5MWjQhckTESgFFRwZwzJoo3t/U6hb9SSvUrDfo+dsG44ay7bx6TEsK62s4ZE82BiiaqGtt4dn0+dz63hfyqJg9WqZQaSjTo+8GIsIBuz2c4b2Cy6XA1j36cx9s7SrnkkVXkFNZ6oDql1FCjQT8AJsWH4+9j4ck1hyita+WHl47DIsJrW4o8XZpSagjQK2MHgJ+PhczEcDYcqsZqEW6ckcTm/Bo+3VcBgN1uus7LV0qpvqY9+gFydPhmZmok4UF+zB0bzeGqZtYeqGT2bz/ipU2FHq5QKeWtNOgHyPSUKAAunTgCgPPGDQfgW89u4Uh9K795N5fc0noW//0zXtvqGNKxddoxxnimYKWU19ChmwEye3QUf7o2kwWTHEGfHBVEUmQQBdXNXJERx9s7Srnqb2vo6DTsL2skPMiPH768jW/MSeWb5432cPVKqTOZ9ugHiMUiLJoSj7+PFXBMjnZ5Rhyp0cH8/prJXDctCV+rhT9dm0lHp51bnt5EZWM77+064uHKlVJnOreCXkTmi8heEckTkft6eH2hiGwXkRwRyRaRc91ddyi799JxrPjeeQT6WXloUTrr7p/HoinxPHD5WaREB3N5Rhw7iupobred8r1K61oGoGKl1JnolEEvIlbgUWABMAG4XkQmHLfYSmCyMSYTuBV44jTWHbJEBKvzbBuLRQgL9AXg5lnJfPyD81mSlYjNbth4qJor/7qGpe/t6bb+0fH71fsrmPWbj1izX6dEVkqdyJ0e/XQgzxhz0BjTDrwALHRdwB
jTaI4dNQwGjLvrqt6dPSoCq0V46O1cdhTX8eSaQ5TVOyZHW72/gvN+/wl7jzTw7Pp8AJ7X+XSUUj1wJ+jjAddz/4qcbd2IyGIR2QO8jaNX7/a6zvVvdw77ZFdUVLhTu9cb5u9DenwY+8sbSYoMotNu+MenB2lut3HfqzsoqG7m/uXbWZlbTqCvlRW7yqhr7vB02UqpQcadoO/pSp4TzvkzxrxmjBkPLAIePJ11nes/bozJMsZkxcTEuFHW0HD0/PsfXjqOxVPieXZ9Pl/6+1qKa1tYmDmSLQW12OyG33xpEu2ddt7Y3n3ytM35NXR02j1RulJqkHAn6IuARJfnCUCvUzEaY1YBo0Uk+nTXVSe6aeYovndxGpdNiuNH88dzVeZI2mx2vjk3ld9fM5mU6GCmJ0eyMHMk40eE8N8NBV1j97ml9Vz92FqW61QLSg1p7pxHvwkYKyIpQDFwHXCD6wIiMgY4YIwxIjIV8AOqgNpTratOLjEyiO/MGwtATIg/D395crfXl39rNhaLICJ8Y04q3395Gx/sLuPSiSO6pljYnF/DtdOSBrx2pdTgcMoevTHGBtwFvA/kAi8ZY3aJyB0icodzsauBnSKSg+Msm2uNQ4/r9sN2DFkRwX5dZ+sszBxJSnQwj6zYh91uus7C2VZY17X8q5uLuPVfm/SKW6WGELeujDXGvAO8c1zbMpfHS4Gl7q6r+oeP1cJ35o3hnhe38cKmQjYeribA18K+8gYa22wE+1l59JM8DlY0Ud7QRmzosemU61o62HiomosnxHpwC5RS/UGvjPUyV02OZ0pSOD99fSftNjvXTUvCGNhRVEdOYS0HKxw3PNlRVNdtvec25PONf2dzqFJviKKUt9Gg9zJWi7D06gwsAn5WC7fPTQVgW1Etr24pwt/HggjsLOke9LtL6gHYdKh6wGtWSvUvndTMC6XFhvDLq9Ipb2hlZHggo6KCeD2nhKLqZhakj2BHcR07i+u7rbP3SAMAGw9Xs2RaYk9vq5Q6Q2mP3kvdMCOJ716UBsDUpAhyS+sZHurPnReMIT0+jJ3Fx3r0bbZODjqHbDb20qPfdLiaLQU1/V+4UqrPaY9+CLj/svEsyUpkRkokFoswKT6M13NKqGhoIybEn7zyRjrthszEcHIKaymrbyV6mH/XPDwA9y/fQV1LB6vvvYAAX6sHt0Ypdbq0Rz8EDA8JYNboqK7bFU4cGQbA+7uOUFbfyp5Sx7DNTTNHAfDVpzYy+ZcfdM2I2drRycGKRioa2nh+YwE7i+vYnF+DMYblW4r07lhKDXLaox+CJsaHYrUIP/nfTv7fW7uZPToKPx8Ll2fE8fM3drHf2cP/cHcZN81KJq+8EbuBID8rD7+/l+aOToyB8SNC2HOkAT+rhYsnxBIR7Nfj5328p5x/rzvM4zdn4WvVvoVSA03/rxuCQgN8+e/XZ/C3G6YwzN+HT/ZWkBY7jABfK/+6ZRrv3j2H5KggVu4pBxxTKQD8/MoJdNgNX5kxirvnjeVIfSvXT090zLGzrfeZLZ5Zd5iP91awap9OVqeUJ2iPfoiakeq4h63dwHee38q42FAAspIdk6hdMH44/91QQEt7J3uONBDga+GasxO5emoCPs5e+XcvGouIsK2wjpc3FyICDa027rxgTNfnNLbZWJtXBcDyLcXMO0svyFJqoGnQD3FXZsRRXt/KTGfwH3Xh+OE8/dlh1h6oZM+ResbFhjgPzh47QCviePzlrAR++eZudhbvwmoRbj0nhUA/xwHbT/dW0N5pZ1J8GCtyy6hsbMNuDMNDAlBKDQwduhniRISvz0klPT6sW/v0lEiC/ay8trWY3NIGxo8I7fU9Fk+JZ1pyBFdkxNFpN+wsqcPWaaektoX3dx0hMtiP/7dwIu02OzN/vZJzfvtR13CQUqr/aY9e9cjfx8ot56Twt4/zABgfF9LrsuFBfrx8x2wqGtp4a3spOQW1rMwtZ9mnBwC45uwEMhPDuXFGEp12w4rdZT
zw2g5euWN215lASqn+o0GvenXPxWlsK6pl9f7Kk/boj4oJ8Sc+PJCthTVsLahlckIY546N5pqzExERHlo8CYBpyUV8/+VtvJhdyLVZidzyr02kxQ7jgcv1dsJK9QcNetUrq0X42/VT+V9OMdOSI9xaJzMpnBW7ymjvtHPv/HEsnpJwwjJfmhrPS9mF/PbdPbR2dPLpvgrW5FVy08xk6ls7CPC1Mmb4sL7eHKWGLB2jVycVFuTLV2cnd51pcypTEsNp77Tj52Phol7OsHH07tNpbrfxyzd3M2b4MKwW4e4Xt7L4759x05MbaGnv7Fq+sLqZJcvW8dWnNna15Vc18cBrO8grb/xiG6jUEKBBr/rU5MRwAOaOjSEkwLfX5cYMD+Gbc0cD8MurJnLD9CS2FtSSEBFEaV0rT312CIDyhlYu/8tqNh6uZtX+ChpaO3hv5xEueWQVz20o4Nn1+f2+TUqd6XToRvWpSfFhnD0qgq/NTj7lst+/JI0vZyUwKiqY9JFhxIYGcMOMJH7w8jYe++QA109PYt2BKupbbdxzURqPfLiPrQW1PL7qAPERgYQH+rL2QOVJP+P1nGJGxww74awipYYS7dGrPhXga+XVb83m3LHRp1xWRBgVFQw4hoi+df5owgJ9uXveWBrbbHyYW0ZOYS0Bvha+NjsZi8BKZ9sVGSO5ZOII9pU55uABOFjRyDNrD9Npd9wmsaqxje+/tI0/fbiv/zZYqTOA9ujVoDNxZChRwX6sP1BFfnUzk+LDCAvyZfyIUJ7fVIjdwPnjYvBxnpq57mAVc8ZEc/NTGymqaWFzfg1/XDKZN7eVYLMbcgprMcZ0XeDV2Gbj7e0lRAT5cfGE2K52pbyVBr0adESEmaOjWJNXSV1LB19xzqqZlRzB7tJ6IoJ8mZwQDkBIgA9vbSvh2fX5lNe3ceOMJJ7bUIBF6LotYmVjO0U1LSRGBrH+YBXfeCabhjYbABeMi+HP108h9CTHE5Q60+nQjRqUZqVGUd7QRpvN3nWA9+xRjlM856bFYLUIVoswMzWKD3Y7hnN+d00GDy2exA8uSeN/OSVsK6pjYeZIAHIKa6loaOPbz28lJsSfV781m59eMYFV+yu5/9UdGGO6PttuNzy55hBPf3ao2w1alDpTaY9eDUqzRh+beyfT2XuflRpFkJ+VKzJGdr12+9xURoQGcPvcVBIjgwC484IxVDS0sXxrMffOH897O4+wtaCWl7ILaWjt4D+3TWf8iFDOHhVBm62T3723l3M2RnPDjCQA1h+s4sG3dgOOvxg2/+Ri/Hz6pk/U0NrBfct38N15Yxkb2/vVxkr1Jbe+vSIyX0T2ikieiNzXw+s3ish257+1IjLZ5bV7RGSXiOwUkedFRGezUqeUGh3M8BB/IoJ8SYwMBGB4aAA5P7uEiyccOz9/WnIkDy5K7wp5cAz9/HJhOpseuIj48EAmxYfx/MYCVu+v5IHLJ3S7yveOuaOZMzaaB9/aTX6VY6jnvV1HCPC18Mclk2lotbH+YFW32l7YWEDGL95n6Xt72HiomrL6Vre3K/twDW9vL+X/ntvS7VoBoNtfFUr1pVMGvYhYgUeBBcAE4HoROf5a9UPAecaYDOBB4HHnuvHAd4AsY0w6YAWu67vylbcSEW6fm8rXZqd0O1h6Oj3ro7c8zEwMp6Wjk2nJEdw4PanbMhaL8LtrMvCxCPe+sh1bp533dh7h/LThXDYpjiA/Kx/sPtJtnWfW5WOAZZ8eYMk/1jFn6ccU17bQ0NrB6znFJw3sPc6bsOdVNPLQO7u72utaOpj20Ie8s6PU7e1Tyl3u/F8zHcgzxhw0xrQDLwALXRcwxqw1xhy9c/R6wPW6dx8gUER8gCCg9ztUKOXi63NSufuisV/4fS4cP5zoYf785kuTepxELS4skJ9eMYENh6q55V+bKG9oY376CAJ8rZyXFsOK3WXYnads7jlST25pPT+4ZByr772AZV85mw67nec3FPCHD/Zx9ws5ZO
cfu4n6ytyybjdc2VfWQFxYAIunxPP61pKu9113oJLKxvZerwtYmVvGuUs/otF5EFmp0+FO0McDrjcFLXK29eY24F0AY0wx8DBQAJQCdcaYD3paSURuF5FsEcmuqNA7Eam+M3tMNJsemMeY4b2PiX85K4E7zhvN6v2V+FqFC88aDsAlE2Mpq29ju/Og7P+2lmC1CJdnxJEQEcT89BHMGz+c5zbk89+NBQB8mFsGOO61+72XtnHPizm0djiGafYeaSAtNoSZKVE0tNk46DwzaPX+yq7Xe7JidxlFNS1sLajp8fVNh6vJKaw9zT2jhgp3gr6nk4x7/NtURC7AEfQ/cj6PwNH7TwFGAsEi8pWe1jXGPG6MyTLGZMXExLhTu1JuO9W58iLCfQvG89DidH40f3zX6ZYXjovFz2ph+ZYibJ123sgpZu7YaKKH+Xet+5WZo6hp7sAYw/gRIazMddyCcWVuOXUtHVQ1tfPa1mJsnXbyKhoZPyKEjETHlbrbi2oBWJPnCPo9Rxq6hn52FNXx0Nu7McZ0hXj24ROD3hjDd1/I4f7lO054raPT3vVLxh25pfV876UcbJ12t9dRg587QV8EJLo8T6CH4RcRyQCeABYaY44evboIOGSMqTDGdADLgdlfrGSl+s+NM0bx9TmpXc/Dgny5IiOOVzcX8Z/1+ZTUtXL9ceP8c8fGkB4fyi3npLAkK5G88kbyq5p4eXMhcWEBTBwZypNrDnGosol2m5202BDGxAwj0NfK9qI6Cqqaya9qZuzwYTS02iitcxzcXbbqAP9cfYj1B6vZV+bo6W92Dgt12g1ffWojS9/bw4GKJoprW9hzpJ665o5utf3g5W0s+cc6t7d/+ZYilm8pJr+6+XPtPzU4uRP0m4CxIpIiIn44Dqa+4bqAiCThCPGbjDGu15sXADNFJEgcXap5QG7flK7UwLh5djJN7Z08+NZuMhLCup31A44Dum99ew4/vuysrhk7f/vuHlbtq+BLU+P5xpxU8sob+d37ewEYNyIEH6uF9PhQx3z/eY6hyq/PSQEcwzdttk4+cd6c/Q8f7MVuYFRUEFsLaui0G57+7BCf7qvgmbWHeW+n4wCuMbDxcHVXXcW1Lby5rYTtRXUcqWvl5exC7n1l20m3dVuhY4iqqKbli+42NYicMuiNMTbgLuB9HCH9kjFml4jcISJ3OBf7GRAF/F1EckQk27nuBuAVYAuww/l5j/f9ZijVfzITw5mcEIbdwPcvGXfSYaCkqCAmxYfx7s4jhAb6cm1WEldOHknWqAhW7C7DInTNtZ+REM6uknqeWH2I5Kgg5qfHAY7hm7UHqmhq7yQ0wKfr4O4tzl84b24r4eEP9pIWO4zm9k7+9nEeiZGB+PlY2OByKuiz6/NxHutlTV4lj31ygJeyiyiqOdZbL69v7ZobyNZpZ0fx0aDXHr03cetcNWPMO8aYNGPMaGPMQ862ZcaYZc7HXzfGRBhjMp3/slzW/bkxZrwxJt0Yc5Mxpq1/NkWp/vOTKybwnXljmevGZG0v3zGLnJ9dzNafXkxSVBBWi/CHJZMJ8rOSHBXcddpnRkIY7TY7h6ua+O3VGYQF+hIXFsDeI/Ws2F1GsJ+Vb1/oOOsoOSqIec6/Fr77Yg7D/H155tbpjI4JprXDzrzxsUxJDGfdwSoefn8v1z2+jmfX53PxhFiigv14as2hrgO/R48hbDhYxezffsTvnX9p7C9vpMU5nq89eu+iUyAo5YZpyZF87+I0tyZAC/C1Eh7k123ZUVHB/PPmLH5+1cSutrNHRWAR+L/zRzMz1XEl8LgRIazJq+Lt7aWcNy6GKyY7evmZieEkRAQyOiaY8SNC+N+ds4kLC+TaaY7DZ+elxTAjNYpdJfX87eM8qpvaCfC18q3zRzNrdBS7S+uxWoT48EA+zC2juLaF/3tuCza74bkN+TS12boODAf4WnoN+m2Ftewu0Ru7n2l0CgSlBsg5Y7r/NZAQEcTqH13IyLBjF4tPS4
7kk70VpMeHcucFY4gLC+TXiyeRmRiOiPDGXecS4GvF6rwe4OZZyYQH+jE3LYbwIF/++tF+/u/80fzAZYhp35gG3tpeyjljojkrLoSn1hzixn+up91m53fXZHDvK9t5bWsxu0rqCQ3wIT0+jKKaZupbO1i+uYhrshIZ5u9Dc7uNrz29kRFhgbx795yB23HqC5PBeNl1VlaWyc7O9nQZSg04u93Q1G476d25TqamqZ2IYL9ubSW1LZz/8Cf84cuTiQsL4Jpl6xjm78Mzt05nalI4V/5tDdWN7XQaQ1psCCPDAvlobzm3nZvCb9/dw5jhw1j2lbP5ZG85v3rbcS7Fhh/PIzY0gKY2G797bw93XjiG4SE6u4knichm12FzVzp0o9QgYrHI5w554ISQBxgZHsjmn1zElZNHMiUpgu9eNJb/fmMGZ4+KQES4f8FZ+PpYqGnuYN744SRGBlLR0MZHe8qJDfWnuqmdK/66mr9+lEeSc06ho1f7vp5TwjPr8nllcxEA9a0dXdcBbCusPWE+H+UZGvRKDQFHf3lYLcJ3L0ojwzkjKDiGlD794QXsfXA+XzsnhYQIR5hvPFTNgvQ43r17DjNTo6hv7WDp1RnEhPjzqTPoX9nsuGj+kz0VHKpsIutXH/LW9lIKqppZ9PfPeHFTwcBuqOqRjtErpYBjVw8nRAR2tc0aHUVsaABPf20aFQ1tDA8N6Jr/Z39ZA1sKaokJ8WdzQQ2PfZJHu83O6zklVDS0YQwcrtLTNAcD7dErpbo52qMXgZkpUc7HwvBQxxj8eWkx1LV0cP0/N2C1CA8unEin3fBSdhEisHp/BW9td1w8X1yrp2kOBhr0Sqluhof442sVJo4MJSzoxOMFl02K44eXjmNUVBA3zxrFxRNGEBboWO6uC8bQZrOzpaAWcBwIVp6nQzdKqW4sFuGqyfFMHRXe4+tWi3DnBWO484IxXW0L0keQU1jLd+aN5d/r8qlr6WD8iBDt0Q8SGvRKqRP8YcnkUy/k4leL0rHZDb5WCwvSR/Dx3nKuyIjj4Q/20dRmI9hfo8aTdOhGKfWF+VgtXVM7/PzKibz17Tldt3csrdNevadp0Cul+lSgn5WYEH/iwx1n7xTXun9PXdU/NOiVUv1i5NGg1wnSPE6DXinVL4aH+GO1iJ55Mwho0Cul+oWP1cKI0AAN+kFAg14p1W/iwwMpqmlh0+FqGlo7Tr2C6hca9EqpfhMfEcjGw9V8edk6nlxzyNPlDFka9EqpfpOVHMGI0ABCA3y6bnCuBp4GvVKq39w4YxTrfzyPrORIDlY0ebqcIUuDXinV71Kigzlc1YTdPvhudDQUaNArpfpdqvMm5kfq9eIpT3Ar6EVkvojsFZE8Ebmvh9dvFJHtzn9rRWSyy2vhIvKKiOwRkVwRmdWXG6CUGvxSooMBdPjGQ04Z9CJiBR4FFgATgOtFZMJxix0CzjPGZAAPAo+7vPZn4D1jzHhgMpDbF4Urpc4cqdHDADhU2ejhSoYmd3r004E8Y8xBY0w78AKw0HUBY8xaY0yN8+l6IAFAREKBucCTzuXajTG1fVS7UuoMERvqT5CflYOV2qP3BHeCPh4odHle5GzrzW3Au87HqUAF8LSIbBWRJ0QkuKeVROR2EckWkeyKigo3ylJKnSlEhJTo4G5DN5/3xuHGGJZvKaKxzdZX5Xk9d4Jeemjr8dC5iFyAI+h/5GzyAaYCjxljpgBNwAlj/ADGmMeNMVnGmKyYmBg3ylJKnUlSooM55OzR/+PTA2T9agV1Lb1fLWu3G3YW12FM97jZVVLP917axn835Pdrvd7EnaAvAhJdnicAJccvJCIZwBPAQmNMlcu6RcaYDc7nr+AIfqXUEDN2eAiFNc08s/Ywf1ixj6b2TnaV1PW4bF1LB7c9s4kr/rqGf6/rHuhbChyjxBsPVfd7zd7CnaDfBIwVkRQR8QOuA95wXUBEkoDlwE3GmH1H240xR4BCERnnbJoH7O6TypVSZ5SbZo
1i4shQfv7GLvysjujZXVJ/wnKF1c0s/vtnrN5fSWpMMA+/v5dyl9MytzrvR7vpcI2el++mUwa9McYG3AW8j+OMmZeMMbtE5A4RucO52M+AKODvIpIjItkub/Ft4DkR2Q5kAr/uyw1QSp0ZIoP9eP4bM/ny2Qn8cclkhof4s7u0nq0FNcz/0yo2HqpmS0ENX3psLZUNbTz39Rk8cXMWbTY7l/91DRf/8VN2lziW9/OxUNfSwf5y98/i2VVSx41PrOdARfd1jDHsK2s4YYhooK3ZX8kLGwvo6LT3+XuLpzeuJ1lZWSY7O/vUCyqlzlhfe3ojR+payUgI46XsIvx8LNjthhFhATz1tWmkxYYA8Pb2Ut7aXsKa/ZWkx4ex7mAV109P4vmNBTy4KJ2bZo7q9TMqG9tYvb+CRZnx3PNiDv/LKSEuLICXvjmr61aHj686wK/f2cOzt83g3LHRA7LtPVnyj3WU1Lbw6Q8vwGrp6dDoyYnIZmNMVk+v6ZWxSimPmBAXSl55Ix/mlnNeWgwzUiKZnz6Ct78zpyvkAS7PiOOxr5zNTbNGse6g4/DfwsyRxIb6s+kU4/T3vrKde17cxkvZhby36whzxkbT1Gbjl286RpDXH6xi6Xt7Afh0X/lp1W+Mod3WN73v3NJ6Nh6q5uZZoz5XyJ+KBr1SyiMmjAzFZjdUN7Xzpanx/Oe2GfzthqmEBfr2uPzXZifjZ7VgtQgZCWFMT4li7YFKWjs6WbG7jFv/tYl/fHqAz/IqKa5t4cPdZXy0pxx/HwsPvLaT1g4791ycxuIp8azJq6C1o5Of/m8nSZFBTE4MZ+2BKkpqW1jw59XsPdJ9ps3Kxjbm/2kVK3PLutr+sz6f6b/+kKKa5q62Z9Ye5t5XtvHnD/fT2uH+6aP/XncYfx8LS7IST73w5+DTL++qlFKnMHFkGAA+FuH8ccNPufzw0ABuPTeF/Komgvx8uGF6Em9uK+GxTw7w3IZ8mts7+WhP9155akwwP15wFl//dzap0cFMSQynrqWDZ9bl88zaw+wvb+RXi9KpaWrnjx/u45EV+8gtreejPeWMG3Hsr4pfv5PLniMNvLipkHlnxQLwek4Jtc0dPPDaTv51yzTabHYeejsXX6vQ1N5Jbmk9j9449ZQ99LqWDl7bWsyizHjCg/xOdze6RYNeKeURoyKDCPazkpkU3msv/nj3LRjf9XjW6CjOHRPNn1fuRwTeuPNcYsP8yStr5FBVEwXVzVyZMZKJI0P55nmpZCaEIyLMSo3C38fCIx/uw8ciXDYpjkOVTfxhxT5e3lwEwI7i2q7PWXegiuVbigkL9GX1fsdfEE1tNrYU1JAaE8yn+yp4Y1sJI0IDaO+08+iNWRRUN/PgW7v5y8r93HNxWo/bYoxBRHhjWwmtHXZunJn0+XfmKejQjVLKIywW4ZFrM/nJ5cdPneW+H1zqOHP7ummJTEoIY3hIALPHRHPjjFHcv+As0uPDEBHuX3AWCybFARDga2XW6ChaO+zMGRtNZLAfGQlhBPtZAUiOCmJ7keP8fmMMv3k3l/jwQH53TQYtHZ18llfJx3srMAYeWZLJ6JhgnltfQHa+4/z+rFER3HZuCueMieL9XUe6ajXGkH24mnabnbV5lUx7aCXv7zrCK9mFjB8RwqT4sM+9H05Fe/RKKY+5ZOKIL7R+ZmI4H9wzt2t2THednxbDJ3srWJjpmM3F12rhwrNiOVLXwryzYvntu3uoaWpnS0EN24vq+N3VGZw/LoZh/j58mFtGbXMHsaH+ZCSEcdXkeP60ch+NbTbGDh9GRLBj+GVaciR/XrmfxjYbw/x9WLG7jNv/s5n0+FDyq5ppaLXx/Ze20dhm4yeXn4VI3x+EPUp79EqpM1pabAi+1tOLsqvPTuBH88ezYNKxXzSPLJnMc1+fSYazZ729uI4/fbifpMggFk+Nx9/HynlpMTy/sZB3dx7hwvGxiAiXZ4zAGNhdWs+0lMiu95uSFIExsL2oFoA3t5
cSEuBDUU0LAb5W/nPbdGx2Oz4WYfGUk00f9sVpj14pNeSEBPjyrfNHd2vzcf6ymOgM+gff2k1eeSMPf3ly1y+S712SRnJ0EAE+Vq7JSgBgzPAQxo8IYc+RBqYnHwv6zIRwwHEl75TECFbmlrFoSjw/unQ8ncYQGezHn67NpKy+jahh/v26vRr0SinlIizQl5ToYPLKG7norFiunnqstz06Zhg/vHT8CetcOXkk+8r2duvRhwX5khoTzNaCWj6JLqe5vZPLJ8URFnTswPP89Lj+3RgnDXqllDrOzNRI2m12Hv5yhltj59+Yk8p5aTHEhwd2a5+SGMEne8upb+kgKtiPGS6/CAaSBr1SSh3n/y1Mp6PTTpCfexHp52MhvYezZjKTwnl1SxF1LTX84qqJXcNDA02DXimljuNrtZz2Ad6eXJkRx6GKJq6bnthtWoeBpkGvlFL9JDzIj59d+fmvE+grenqlUkp5OQ16pZTychr0Sinl5TTolVLKy2nQK6WUl9OgV0opL6dBr5RSXk6DXimlvJwYYzxdwwlEpALI/5yrRwOVfVhOX9G6Tt9grU3rOj1a1+n7PLWNMsbE9PTCoAz6L0JEso0xWZ6u43ha1+kbrLVpXadH6zp9fV2bDt0opZSX06BXSikv541B/7inC+iF1nX6BmttWtfp0bpOX5/W5nVj9Eoppbrzxh69UkopFxr0Sinl5bwm6EVkvojsFZE8EbnPg3UkisjHIpIrIrtE5G5n+y9EpFhEcpz/LvNQfYdFZIezhmxnW6SIrBCR/c7/RgxwTeNc9kuOiNSLyHc9sc9E5CkRKReRnS5tve4fEbnf+Z3bKyKXeqC234vIHhHZLiKviUi4sz1ZRFpc9t2yAa6r15/dQO2zXup60aWmwyKS42wfyP3VW0b03/fMGHPG/wOswAEgFfADtgETPFRLHDDV+TgE2AdMAH4B/GAQ7KvDQPRxbb8D7nM+vg9Y6uGf5RFglCf2GTAXmArsPNX+cf5ctwH+QIrzO2gd4NouAXycj5e61JbsupwH9lmPP7uB3Gc91XXc638AfuaB/dVbRvTb98xbevTTgTxjzEFjTDvwArDQE4UYY0qNMVucjxuAXCDeE7WchoXAM87HzwCLPFcK84ADxpjPe2X0F2KMWQVUH9fc2/5ZCLxgjGkzxhwC8nB8FwesNmPMB8YYm/PpeiChvz7/dOo6iQHbZyerS0QEWAI83x+ffTInyYh++555S9DHA4Uuz4sYBOEqIsnAFGCDs+ku55/YTw308IgLA3wgIptF5HZnW6wxphQcX0JguIdqA7iO7v/zDYZ91tv+GWzfu1uBd12ep4jIVhH5VETmeKCenn52g2WfzQHKjDH7XdoGfH8dlxH99j3zlqCXHto8et6oiAwDXgW+a4ypBx4DRgOZQCmOPxs94RxjzFRgAXCniMz1UB0nEBE/4CrgZWfTYNlnvRk03zsReQCwAc85m0qBJGPMFOB7wH9FJHQAS+rtZzdY9tn1dO9QDPj+6iEjel20h7bT2mfeEvRFQKLL8wSgxEO1ICK+OH6AzxljlgMYY8qMMZ3GGDvwT/rxT/yTMcaUOP9bDrzmrKNMROKctccB5Z6oDccvny3GmDJnjYNin9H7/hkU3zsR+SpwBXCjcQ7qOv/Mr3I+3oxjXDdtoGo6yc/O4/tMRHyALwEvHm0b6P3VU0bQj98zbwn6TcBYEUlx9gqvA97wRCHOsb8ngVxjzB9d2uNcFlsM7Dx+3QGoLVhEQo4+xnEgbyeOffVV52JfBV4f6NqcuvWyBsM+c+pt/7wBXCci/iKSAowFNg5kYSIyH/gRcJUxptmlPUZErM7Hqc7aDg5gXb397Dy+z4CLgD3GmKKjDQO5v3rLCPrzezYQR5kH6Ej2ZTiOXh8AHvBgHefi+LNqO5Dj/HcZ8B9gh7P9DSDOA7Wl4jh6vw3YdXQ/AVHASmC/87+RHqgtCKgCwlzaBnyf4fhFUwp04OhJ3Xay/QM84P
zO7QUWeKC2PBzjt0e/a8ucy17t/BlvA7YAVw5wXb3+7AZqn/VUl7P9X8Adxy07kPurt4zot++ZToGglFJezluGbpRSSvVCg14ppbycBr1SSnk5DXqllPJyGvRKKeXlNOiVUsrLadArpZSX+/8hNOvuG3L/RgAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(torch.tensor(lossi).view(-1, 1000).mean(1))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# put layers into eval mode (needed for batchnorm especially)\n",
"for layer in model.layers:\n",
" layer.training = False"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train 1.7690284252166748\n",
"val 1.9936515092849731\n"
]
}
],
"source": [
"# evaluate the loss\n",
"@torch.no_grad() # this decorator disables gradient tracking inside pytorch\n",
"def split_loss(split):\n",
" x,y = {\n",
" 'train': (Xtr, Ytr),\n",
" 'val': (Xdev, Ydev),\n",
" 'test': (Xte, Yte),\n",
" }[split]\n",
" logits = model(x)\n",
" loss = F.cross_entropy(logits, y)\n",
" print(split, loss.item())\n",
"\n",
"split_loss('train')\n",
"split_loss('val')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### performance log\n",
"\n",
"- original (3 character context + 200 hidden neurons, 12K params): train 2.058, val 2.105\n",
"- context: 3 -> 8 (22K params): train 1.918, val 2.027\n",
"- flat -> hierarchical (22K params): train 1.941, val 2.029\n",
"- fix bug in batchnorm: train 1.912, val 2.022\n",
"- scale up the network: n_embd 24, n_hidden 128 (76K params): train 1.769, val 1.993\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"arlij.\n",
"chetta.\n",
"heago.\n",
"rocklei.\n",
"hendrix.\n",
"jamylie.\n",
"broxin.\n",
"denish.\n",
"anslibt.\n",
"marianah.\n",
"astavia.\n",
"annayve.\n",
"aniah.\n",
"jayce.\n",
"nodiel.\n",
"remita.\n",
"niyelle.\n",
"jaylene.\n",
"aiyan.\n",
"aubreana.\n"
]
}
],
"source": [
"# sample from the model\n",
"for _ in range(20):\n",
" \n",
" out = []\n",
" context = [0] * block_size # initialize with all ...\n",
" while True:\n",
" # forward pass the neural net\n",
" logits = model(torch.tensor([context]))\n",
" probs = F.softmax(logits, dim=1)\n",
" # sample from the distribution\n",
" ix = torch.multinomial(probs, num_samples=1).item()\n",
" # shift the context window and track the samples\n",
" context = context[1:] + [ix]\n",
" out.append(ix)\n",
" # if we sample the special '.' token, break\n",
" if ix == 0:\n",
" break\n",
" \n",
" print(''.join(itos[i] for i in out)) # decode and print the generated word"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Next time:\n",
"Why convolutions? Brief preview/hint"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"........ --> d\n",
".......d --> i\n",
"......di --> o\n",
".....dio --> n\n",
"....dion --> d\n",
"...diond --> r\n",
"..diondr --> e\n",
".diondre --> .\n"
]
}
],
"source": [
"for x,y in zip(Xtr[7:15], Ytr[7:15]):\n",
" print(''.join(itos[ix.item()] for ix in x), '-->', itos[y.item()])"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 27])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# forward a single example:\n",
"logits = model(Xtr[[7]])\n",
"logits.shape"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([8, 27])"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# forward all of them\n",
"logits = torch.zeros(8, 27)\n",
"for i in range(8):\n",
" logits[i] = model(Xtr[[7+i]])\n",
"logits.shape"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# convolution is a \"for loop\"\n",
"# allows us to forward Linear layers efficiently over space"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "caesaranalysis",
"language": "python",
"name": "caesaranalysis"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
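The hierarchical model in this notebook stacks `FlattenConsecutive(2)` layers so that each `Linear` sees pairs of adjacent time steps fused together, halving the sequence length while doubling the channel width at every level. The `FlattenConsecutive` class itself is defined earlier in the file (outside this diff); the sketch below is not that implementation, just a dependency-free illustration of the shape transformation it performs, `(B, T, C) -> (B, T // n, C * n)`, using plain nested lists:

```python
# Illustration only: what FlattenConsecutive(n) does to a (B, T, C) batch.
# Consecutive groups of n time steps are merged by concatenating their
# channel vectors, so (B, T, C) becomes (B, T // n, C * n). The notebook's
# real layer does the same thing on torch tensors with a single view().

def flatten_consecutive(batch, n):
    """batch: nested lists shaped (B, T, C); returns shape (B, T//n, C*n)."""
    out = []
    for example in batch:                    # each example is (T, C)
        grouped = []
        for t in range(0, len(example), n):  # step through time in chunks of n
            merged = []
            for step in example[t:t + n]:    # concatenate the n channel vectors
                merged.extend(step)
            grouped.append(merged)
        out.append(grouped)
    return out

# (B=1, T=8, C=2) -> (1, 4, 4), mirroring the first FlattenConsecutive(2)
x = [[[t, t + 0.5] for t in range(8)]]
y = flatten_consecutive(x, 2)
print(len(y[0]), len(y[0][0]))  # 4 4
```

Applied three times with n=2, an 8-character context collapses 8 -> 4 -> 2 -> 1, which is why the model needs exactly three such blocks before the output `Linear`.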
CaesarAINL/Procfile
ADDED
@@ -0,0 +1 @@
web: gunicorn app:app
CaesarAINL/amari@172.20.10.197/caesarReminder.py
ADDED
@@ -0,0 +1,52 @@
import os
import json
import time
import requests
from datetime import datetime, timedelta

class CaesarReminder:
    @staticmethod
    def reminder():
        if "CaesarReminders" in os.listdir():
            if "caesarreminders.json" in os.listdir("CaesarReminders"):
                with open("CaesarReminders/caesarreminders.json", "r") as f:
                    reminders = json.load(f)
                message = ""
                for reminder in reminders["reminders"]:
                    message += f"{reminder['subject']}"
                    message += "<br>"
                    message += f"{reminder['message']}"
                    message += "<br>"
                    message += f"Reminder: {datetime.fromisoformat(reminder['timestep']).strftime('%m/%d/%Y, %H:%M:%S')}\n"
                    message += "<br>"
                    message += "<br>"
                # send email containing the reminders
                sendjson = {"raspsendemail": {"email": reminders["email"], "message": message, "subject": "Caesar Reminders"}}
                response = requests.post("https://revisionbank-email.onrender.com/raspsendemail", json=sendjson)
                print(response.text)
            else:
                # send email saying no reminders are scheduled
                sendjson = {"raspsendemail": {"email": "amari.lawal@gmail.com", "message": "No Reminders Scheduled", "subject": "Caesar Reminders"}}
                response = requests.post("https://revisionbank-email.onrender.com/raspsendemail", json=sendjson)
                print(response.text)
        else:
            # send email saying no reminders are scheduled
            sendjson = {"raspsendemail": {"email": "amari.lawal@gmail.com", "message": "No Reminders Scheduled", "subject": "Caesar Reminders"}}
            response = requests.post("https://revisionbank-email.onrender.com/raspsendemail", json=sendjson)
            print(response.text)

if __name__ == "__main__":
    constant = 60 * 60
    duration = 48 * constant  # hours
    while True:
        CaesarReminder.reminder()
        time.sleep(duration)
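The loop in `CaesarReminder.reminder()` concatenates each reminder's subject, body, and a formatted timestamp into one HTML-ish string before POSTing it to the email endpoint. A dependency-free sketch of just that message assembly, using a hypothetical `caesarreminders.json` payload (the real file lives under `CaesarReminders/` and the result is sent via the `raspsendemail` API, which is skipped here):

```python
# Sketch of the message assembly in CaesarReminder.reminder(), with
# hypothetical sample data standing in for CaesarReminders/caesarreminders.json.
from datetime import datetime

reminders = {
    "email": "user@example.com",  # hypothetical recipient
    "reminders": [
        {"subject": "Gym", "message": "Leg day", "timestep": "2023-01-05T18:30:00"},
    ],
}

message = ""
for reminder in reminders["reminders"]:
    message += f"{reminder['subject']}<br>"
    message += f"{reminder['message']}<br>"
    # timestep is stored as an ISO-8601 string and reformatted for display
    message += f"Reminder: {datetime.fromisoformat(reminder['timestep']).strftime('%m/%d/%Y, %H:%M:%S')}\n<br><br>"

print(message)  # -> Gym<br>Leg day<br>Reminder: 01/05/2023, 18:30:00 ...
```

Because `timestep` must parse with `datetime.fromisoformat`, any writer of the JSON file should store timestamps with `datetime.now().isoformat()`, as the commented-out line in the script hints.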
CaesarAINL/amari@172.20.10.197/caesarapis/caesarReminder.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+import os
+import json
+import time
+import requests
+from datetime import datetime, timedelta
+class CaesarReminder:
+    @staticmethod
+    def reminder():
+        if "CaesarReminders" in os.listdir():
+            if "caesarreminders.json" in os.listdir("CaesarReminders"):
+                with open("CaesarReminders/caesarreminders.json", "r") as f:
+                    reminders = json.load(f)
+                message = ""
+                for reminder in reminders["reminders"]:
+                    message += f"{reminder['subject']}"
+                    message += "<br>"
+                    message += f"{reminder['message']}"
+                    message += "<br>"
+                    message += f"Reminder: {datetime.fromisoformat(reminder['timestep']).strftime('%m/%d/%Y, %H:%M:%S')}\n"
+                    message += "<br>"
+                    message += "<br>"
+                sendjson = {"raspsendemail": {"email": reminders["email"], "message": message, "subject": "Caesar Reminders"}}
+                response = requests.post("https://revisionbank-email.onrender.com/raspsendemail", json=sendjson)
+                print(response.text)
+
+            elif "caesarreminders.json" not in os.listdir("CaesarReminders"):
+                sendjson = {"raspsendemail": {"email": "amari.lawal@gmail.com", "message": "No Reminders Scheduled", "subject": "Caesar Reminders"}}
+                response = requests.post("https://revisionbank-email.onrender.com/raspsendemail", json=sendjson)
+                print(response.text)
+
+        # send email saying reminder
+        elif "CaesarReminders" not in os.listdir():
+            sendjson = {"raspsendemail": {"email": "amari.lawal@gmail.com", "message": "No Reminders Scheduled", "subject": "Caesar Reminders"}}
+            response = requests.post("https://revisionbank-email.onrender.com/raspsendemail", json=sendjson)
+            print(response.text)
+        # Send email saying No reminders
+
+
+if __name__ == "__main__":
+    constant = 60 * 60
+    duration = 48 * constant  # hours
+
+    #print(datetime.now().isoformat())
+    while True:
+        CaesarReminder.reminder()
+        time.sleep(duration)
+        #pass
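For context, `CaesarReminder.reminder()` only fires a real reminder when a `CaesarReminders/caesarreminders.json` file exists with an `email` field and a `reminders` list whose entries carry `subject`, `message`, and an ISO-8601 `timestep`. A minimal sketch of producing and re-reading such a file (the email address and reminder values here are hypothetical; only the field names come from the code above):

```python
import json
import os
from datetime import datetime, timedelta

# Hypothetical sample data; only the field names are taken from caesarReminder.py.
reminders = {
    "email": "user@example.com",
    "reminders": [
        {
            "subject": "Revision session",
            "message": "Go over the BERT intent-classifier notes.",
            # caesarReminder.py parses this with datetime.fromisoformat(...)
            "timestep": (datetime(2023, 1, 1, 9, 0) + timedelta(days=2)).isoformat(),
        }
    ],
}

os.makedirs("CaesarReminders", exist_ok=True)
with open("CaesarReminders/caesarreminders.json", "w") as f:
    json.dump(reminders, f, indent=2)

# Read it back the same way CaesarReminder.reminder() does.
with open("CaesarReminders/caesarreminders.json") as f:
    loaded = json.load(f)
when = datetime.fromisoformat(loaded["reminders"][0]["timestep"])
print(when.strftime("%m/%d/%Y, %H:%M:%S"))  # 01/03/2023, 09:00:00
```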
CaesarAINL/app.py
ADDED
|
@@ -0,0 +1,29 @@
+from flask import Flask, request
+from flask_cors import cross_origin
+import os
+from caesarinfer import CaesarNL
+app = Flask(__name__)
+
+@app.route("/", methods=["GET"])
+@cross_origin()
+def caesarhome():
+    return "Caesar: How can I help you sir?"
+
+@app.route("/caesarapi", methods=["POST", "GET"])
+@cross_origin()
+def caesarapi():
+    if request.method == "GET":
+        return "Caesar: Hello sir, this is the CaesarAIAPI"
+    elif request.method == "POST":
+        user_input_json = request.get_json()
+
+        print("Caesar Processing...")
+        caesarResponse, intents = CaesarNL.run([user_input_json["caesarapi"]])
+        print("Caesar Processed.")
+        print(caesarResponse, "intent:", intents)
+
+        return {"caesarmessage": {"caesarResponse": caesarResponse, "intent": intents}}
+
+if __name__ == "__main__":
+    port = int(os.environ.get('PORT', 5000))  # 80
+    app.run(debug=True, host="0.0.0.0", port=port)
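A caller of the `/caesarapi` route POSTs a body shaped as `{"caesarapi": <text>}` and unpacks `caesarResponse` and `intent` from the returned `caesarmessage` envelope. A small client-side sketch of those two shapes (the sample response values are hypothetical; the key layout comes from app.py above):

```python
# Sketch of the request/response shapes used by the /caesarapi route.
# The sample response values below are hypothetical; the key layout comes from app.py.

def build_request(text: str) -> dict:
    # app.py reads user_input_json["caesarapi"]
    return {"caesarapi": text}

def parse_response(body: dict) -> tuple:
    envelope = body["caesarmessage"]
    return envelope["caesarResponse"], envelope["intent"]

request_json = build_request("play the song little robin redbreast")
sample_body = {"caesarmessage": {"caesarResponse": "Playing it now.", "intent": "PlayMusic"}}
response_text, intent = parse_response(sample_body)
print(request_json["caesarapi"], "->", intent)  # play the song little robin redbreast -> PlayMusic
```

With `requests`, this would become `requests.post(url, json=build_request(text))` followed by `parse_response(resp.json())`.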
CaesarAINL/bert tutorial/caesarbert.py
ADDED
|
@@ -0,0 +1,476 @@
+# -*- coding: utf-8 -*-
+"""intent_classification_with_bert.ipynb
+
+Automatically generated by Colaboratory.
+
+Original file is located at
+    https://colab.research.google.com/drive/1gTNCoFmqJslnlc3lSbOTqsBz_UjJarS5
+
+# Intent Classification with BERT
+
+This notebook demonstrates the fine-tuning of BERT to perform intent classification.
+Intent classification tries to map given instructions (sentences in natural language) to a set of predefined intents.
+
+## What you will learn
+
+- Load data from csv and preprocess it for training and test
+- Load a BERT model from TensorFlow Hub
+- Build your own model by combining BERT with a classifier
+- Train your own model, fine-tuning BERT as part of that
+- Save your model and use it to recognize the intent of instructions
+
+## About BERT
+
+[BERT](https://arxiv.org/abs/1810.04805) and other Transformer encoder architectures have been shown to be successful on a variety of tasks in NLP (natural language processing). They compute vector-space representations of natural language that are suitable for use in deep learning models. The BERT family of models uses the Transformer encoder architecture to process each token of input text in the full context of all tokens before and after, hence the name: Bidirectional Encoder Representations from Transformers.
+
+BERT models are usually pre-trained on a large corpus of text, then fine-tuned for specific tasks.
+
+## Setup
+"""
+
+import os
+#import shutil
+import pandas as pd
+
+import tensorflow as tf
+import tensorflow_hub as hub
+import tensorflow_text as text
+import seaborn as sns
+from pylab import rcParams
+
+import matplotlib.pyplot as plt
+tf.get_logger().setLevel('ERROR')
+
+sns.set(style='whitegrid', palette='muted', font_scale=1.2)
+HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#ADFF02", "#8F00FF"]
+sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))
+rcParams['figure.figsize'] = 12, 8
+import warnings
+warnings.filterwarnings("ignore")
+
+"""## Data Access
+The data contains various user queries categorized into seven intents. It is hosted on [GitHub](https://github.com/snipsco/nlu-benchmark/tree/master/2017-06-custom-intent-engines) and is first presented in [this paper](https://arxiv.org/abs/1805.10190). In the list below the classes and an example for each class is given:
+
+* `class`: SearchCreativeWork - `example`: *play hell house song*
+* `class`: GetWeather - `example`: *is it windy in boston, mb right now*
+* `class`: BookRestaurant - `example`: *book a restaurant for eight people in six years*
+* `class`: PlayMusic - `example`: *play the song little robin redbreast*
+* `class`: AddToPlaylist - `example`: *add step to me to the 50 clásicos playlist*
+* `class`: RateBook - `example`: *give 6 stars to of mice and men*
+* `class`: SearchScreeningEvent - `example`: *find fish story*
+
+Data can be downloaded from Google Drive using [gdown](https://pypi.org/project/gdown/). In the following code cells the download is invoked only if the corresponding file does not yet exist at the corresponding location.
+"""
+
+datafolder=""
+
+trainfile=datafolder+"train.csv"
+testfile=datafolder+"test.csv"
+validfile=datafolder+"valid.csv"
+
+"""Next, the downloaded .csv-files for training, validation and test are imported into pandas dataframes:"""
+
+traindf = pd.read_csv(trainfile)
+validdf = pd.read_csv(validfile)
+testdf = pd.read_csv(testfile)
+
+traindf.head()
+
+"""Training data contains 13084 instructions:"""
+
+traindf.shape
+
+trainfeatures=traindf.copy()
+trainlabels=trainfeatures.pop("intent")
+
+trainfeatures=trainfeatures.values
+
+"""Distribution of class-labels in training-data:"""
+
+#chart = sns.countplot(trainlabels, palette=HAPPY_COLORS_PALETTE)
+#plt.title("Number of texts per intent")
+#chart
+#chart.set_xticklabels(chart.get_xticklabels(), rotation=30, horizontalalignment='right')
+
+"""One-Hot-Encoding of class-labels:"""
+
+from sklearn.preprocessing import LabelBinarizer
+
+binarizer=LabelBinarizer()
+trainlabels=binarizer.fit_transform(trainlabels.values)
+
+trainlabels.shape
+
+"""Preprocess test- and validation data in the same way as it has been done for training-data:"""
+
+testfeatures=testdf.copy()
+testlabels=testfeatures.pop("intent")
+validfeatures=validdf.copy()
+validlabels=validfeatures.pop("intent")
+
+testfeatures=testfeatures.values
+validfeatures=validfeatures.values
+
+testlabels=binarizer.transform(testlabels.values)
+validlabels=binarizer.transform(validlabels.values)
+
+"""## Loading models from TensorFlow Hub
+
+Here you can choose which BERT model you will load from TensorFlow Hub and fine-tune. There are multiple BERT models available.
+
+- [BERT-Base](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3), [Uncased](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3) and [seven more models](https://tfhub.dev/google/collections/bert/1) with trained weights released by the original BERT authors.
+- [Small BERTs](https://tfhub.dev/google/collections/bert/1) have the same general architecture but fewer and/or smaller Transformer blocks, which lets you explore tradeoffs between speed, size and quality.
+- [ALBERT](https://tfhub.dev/google/collections/albert/1): four different sizes of "A Lite BERT" that reduces model size (but not computation time) by sharing parameters between layers.
+- [BERT Experts](https://tfhub.dev/google/collections/experts/bert/1): eight models that all have the BERT-base architecture but offer a choice between different pre-training domains, to align more closely with the target task.
+- [Electra](https://tfhub.dev/google/collections/electra/1) has the same architecture as BERT (in three different sizes), but gets pre-trained as a discriminator in a set-up that resembles a Generative Adversarial Network (GAN).
+- BERT with Talking-Heads Attention and Gated GELU [[base](https://tfhub.dev/tensorflow/talkheads_ggelu_bert_en_base/1), [large](https://tfhub.dev/tensorflow/talkheads_ggelu_bert_en_large/1)] has two improvements to the core of the Transformer architecture.
+
+The model documentation on TensorFlow Hub has more details and references to the
+research literature. Follow the links above, or click on the [`tfhub.dev`](http://tfhub.dev) URL
+printed after the next cell execution.
+
+The suggestion is to start with a Small BERT (with fewer parameters) since they are faster to fine-tune. If you like a small model but with higher accuracy, ALBERT might be your next option. If you want even better accuracy, choose
+one of the classic BERT sizes or their recent refinements like Electra, Talking Heads, or a BERT Expert.
+
+Aside from the models available below, there are [multiple versions](https://tfhub.dev/google/collections/transformer_encoders_text/1) of the models that are larger and can yield even better accuracy, but they are too big to be fine-tuned on a single GPU. You will be able to do that on the [Solve GLUE tasks using BERT on a TPU colab](https://www.tensorflow.org/tutorials/text/solve_glue_tasks_using_bert_on_tpu).
+
+You'll see in the code below that switching the tfhub.dev URL is enough to try any of these models, because all the differences between them are encapsulated in the SavedModels from TF Hub.
+"""
+
+bert_model_name = 'small_bert/bert_en_uncased_L-8_H-512_A-8'
+map_name_to_handle = {
+    'bert_en_uncased_L-12_H-768_A-12':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3',
+    'bert_en_cased_L-12_H-768_A-12':
+        'https://tfhub.dev/tensorflow/bert_en_cased_L-12_H-768_A-12/3',
+    'bert_multi_cased_L-12_H-768_A-12':
+        'https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/3',
+    'small_bert/bert_en_uncased_L-2_H-128_A-2':
+        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/1',
+    'small_bert/bert_en_uncased_L-2_H-256_A-4':
+        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-256_A-4/1',
+    'small_bert/bert_en_uncased_L-2_H-512_A-8':
+        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-512_A-8/1',
+    'small_bert/bert_en_uncased_L-2_H-768_A-12':
+        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-768_A-12/1',
+    'small_bert/bert_en_uncased_L-4_H-128_A-2':
+        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-128_A-2/1',
+    'small_bert/bert_en_uncased_L-4_H-256_A-4':
+        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-256_A-4/1',
+    'small_bert/bert_en_uncased_L-4_H-512_A-8':
+        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1',
+    'small_bert/bert_en_uncased_L-4_H-768_A-12':
+        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-768_A-12/1',
+    'small_bert/bert_en_uncased_L-6_H-128_A-2':
+        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-128_A-2/1',
+    'small_bert/bert_en_uncased_L-6_H-256_A-4':
+        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-256_A-4/1',
+    'small_bert/bert_en_uncased_L-6_H-512_A-8':
+        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-512_A-8/1',
+    'small_bert/bert_en_uncased_L-6_H-768_A-12':
+        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-768_A-12/1',
+    'small_bert/bert_en_uncased_L-8_H-128_A-2':
+        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-128_A-2/1',
+    'small_bert/bert_en_uncased_L-8_H-256_A-4':
+        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-256_A-4/1',
+    'small_bert/bert_en_uncased_L-8_H-512_A-8':
+        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-512_A-8/1',
+    'small_bert/bert_en_uncased_L-8_H-768_A-12':
+        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-768_A-12/1',
+    'small_bert/bert_en_uncased_L-10_H-128_A-2':
+        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-10_H-128_A-2/1',
+    'small_bert/bert_en_uncased_L-10_H-256_A-4':
+        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-10_H-256_A-4/1',
+    'small_bert/bert_en_uncased_L-10_H-512_A-8':
+        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-10_H-512_A-8/1',
+    'small_bert/bert_en_uncased_L-10_H-768_A-12':
+        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-10_H-768_A-12/1',
+    'small_bert/bert_en_uncased_L-12_H-128_A-2':
+        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-128_A-2/1',
+    'small_bert/bert_en_uncased_L-12_H-256_A-4':
+        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-256_A-4/1',
+    'small_bert/bert_en_uncased_L-12_H-512_A-8':
+        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-512_A-8/1',
+    'small_bert/bert_en_uncased_L-12_H-768_A-12':
+        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-768_A-12/1',
+    'albert_en_base':
+        'https://tfhub.dev/tensorflow/albert_en_base/2',
+    'electra_small':
+        'https://tfhub.dev/google/electra_small/2',
+    'electra_base':
+        'https://tfhub.dev/google/electra_base/2',
+    'experts_pubmed':
+        'https://tfhub.dev/google/experts/bert/pubmed/2',
+    'experts_wiki_books':
+        'https://tfhub.dev/google/experts/bert/wiki_books/2',
+    'talking-heads_base':
+        'https://tfhub.dev/tensorflow/talkheads_ggelu_bert_en_base/1',
+}
+
+map_model_to_preprocess = {
+    'bert_en_uncased_L-12_H-768_A-12':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'bert_en_cased_L-12_H-768_A-12':
+        'https://tfhub.dev/tensorflow/bert_en_cased_preprocess/2',
+    'small_bert/bert_en_uncased_L-2_H-128_A-2':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'small_bert/bert_en_uncased_L-2_H-256_A-4':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'small_bert/bert_en_uncased_L-2_H-512_A-8':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'small_bert/bert_en_uncased_L-2_H-768_A-12':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'small_bert/bert_en_uncased_L-4_H-128_A-2':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'small_bert/bert_en_uncased_L-4_H-256_A-4':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'small_bert/bert_en_uncased_L-4_H-512_A-8':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'small_bert/bert_en_uncased_L-4_H-768_A-12':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'small_bert/bert_en_uncased_L-6_H-128_A-2':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'small_bert/bert_en_uncased_L-6_H-256_A-4':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'small_bert/bert_en_uncased_L-6_H-512_A-8':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'small_bert/bert_en_uncased_L-6_H-768_A-12':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'small_bert/bert_en_uncased_L-8_H-128_A-2':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'small_bert/bert_en_uncased_L-8_H-256_A-4':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'small_bert/bert_en_uncased_L-8_H-512_A-8':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'small_bert/bert_en_uncased_L-8_H-768_A-12':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'small_bert/bert_en_uncased_L-10_H-128_A-2':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'small_bert/bert_en_uncased_L-10_H-256_A-4':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'small_bert/bert_en_uncased_L-10_H-512_A-8':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'small_bert/bert_en_uncased_L-10_H-768_A-12':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'small_bert/bert_en_uncased_L-12_H-128_A-2':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'small_bert/bert_en_uncased_L-12_H-256_A-4':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'small_bert/bert_en_uncased_L-12_H-512_A-8':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'small_bert/bert_en_uncased_L-12_H-768_A-12':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'bert_multi_cased_L-12_H-768_A-12':
+        'https://tfhub.dev/tensorflow/bert_multi_cased_preprocess/2',
+    'albert_en_base':
+        'https://tfhub.dev/tensorflow/albert_en_preprocess/2',
+    'electra_small':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'electra_base':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'experts_pubmed':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'experts_wiki_books':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+    'talking-heads_base':
+        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',
+}
+
+tfhub_handle_encoder = map_name_to_handle[bert_model_name]
+tfhub_handle_preprocess = map_model_to_preprocess[bert_model_name]
+
+print(f'BERT model selected : {tfhub_handle_encoder}')
+print(f'Preprocess model auto-selected: {tfhub_handle_preprocess}')
+
| 285 |
+
"""## The preprocessing model
|
| 286 |
+
|
| 287 |
+
Text inputs need to be transformed to numeric token ids and arranged in several Tensors before being input to BERT. TensorFlow Hub provides a matching preprocessing model for each of the BERT models discussed above, which implements this transformation using TF ops from the TF.text library. It is not necessary to run pure Python code outside your TensorFlow model to preprocess text.
|
| 288 |
+
|
| 289 |
+
The preprocessing model must be the one referenced by the documentation of the BERT model, which you can read at the URL printed above. For BERT models from the drop-down above, the preprocessing model is selected automatically.
|
| 290 |
+
|
| 291 |
+
Note: You will load the preprocessing model into a [hub.KerasLayer](https://www.tensorflow.org/hub/api_docs/python/hub/KerasLayer) to compose your fine-tuned model. This is the preferred API to load a TF2-style SavedModel from TF Hub into a Keras model.
|
| 292 |
+
"""
|
| 293 |
+
|
| 294 |
+
bert_preprocess_model = hub.KerasLayer(tfhub_handle_preprocess)
|
| 295 |
+
|
| 296 |
+
"""Let's try the preprocessing model on some text and see the output:"""
|
| 297 |
+
|
| 298 |
+
trainfeatures[0]
|
| 299 |
+
|
| 300 |
+
text_test = trainfeatures[0]
|
| 301 |
+
text_preprocessed = bert_preprocess_model(text_test)
|
| 302 |
+
|
| 303 |
+
print(f'Keys : {list(text_preprocessed.keys())}')
|
| 304 |
+
print(f'Shape : {text_preprocessed["input_word_ids"].shape}')
|
| 305 |
+
print(f'Word Ids : {text_preprocessed["input_word_ids"][0, :12]}')
|
| 306 |
+
print(f'Input Mask : {text_preprocessed["input_mask"][0, :12]}')
|
| 307 |
+
print(f'Type Ids : {text_preprocessed["input_type_ids"][0, :12]}')
|
| 308 |
+
|
| 309 |
+
"""As can be seen, there are 3 outputs from the preprocessing that a BERT model would use (`input_words_id`, `input_mask` and `input_type_ids`).
|
| 310 |
+
|
| 311 |
+
Some other important points:
|
| 312 |
+
- The input is truncated to 128 tokens. The number of tokens can be customized and you can see more details on the [Solve GLUE tasks using BERT on a TPU colab](https://www.tensorflow.org/tutorials/text/solve_glue_tasks_using_bert_on_tpu).
|
| 313 |
+
- The `input_type_ids` only have one value (0) because this is a single sentence input. For a multiple sentence input, it would have one number for each input.
|
| 314 |
+
|
| 315 |
+
Since this text preprocessor is a TensorFlow model, It can be included in your model directly.
|
| 316 |
+
|
| 317 |
+
## Using the BERT model
|
| 318 |
+
|
| 319 |
+
Before putting BERT into an own model, let's take a look at its outputs. You will load it from TF Hub and see the returned values.
|
| 320 |
+
"""
|
| 321 |
+
|
| 322 |
+
bert_model = hub.KerasLayer(tfhub_handle_encoder)
|
| 323 |
+
|
| 324 |
+
bert_results = bert_model(text_preprocessed)
|
| 325 |
+
|
| 326 |
+
print(f'Loaded BERT: {tfhub_handle_encoder}')
|
| 327 |
+
print(f'Pooled Outputs Shape:{bert_results["pooled_output"].shape}')
|
| 328 |
+
print(f'Pooled Outputs Values:{bert_results["pooled_output"][0, :12]}')
|
| 329 |
+
print(f'Sequence Outputs Shape:{bert_results["sequence_output"].shape}')
|
| 330 |
+
print(f'Sequence Outputs Values:{bert_results["sequence_output"][0, :12]}')
|
| 331 |
+
|
| 332 |
+
"""The BERT models return a map with 3 important keys: `pooled_output`, `sequence_output`, `encoder_outputs`:
|
| 333 |
+
|
| 334 |
+
- `pooled_output` to represent each input sequence as a whole. The shape is `[batch_size, H]`. You can think of this as an embedding for the entire movie review.
|
| 335 |
+
- `sequence_output` represents each input token in the context. The shape is `[batch_size, seq_length, H]`. You can think of this as a contextual embedding for every token in the movie review.
|
| 336 |
+
- `encoder_outputs` are the intermediate activations of the `L` Transformer blocks. `outputs["encoder_outputs"][i]` is a Tensor of shape `[batch_size, seq_length, 1024]` with the outputs of the i-th Transformer block, for `0 <= i < L`. The last value of the list is equal to `sequence_output`.
|
| 337 |
+
|
| 338 |
+
For the fine-tuning you are going to use the `pooled_output` array.
|
| 339 |
+
|
| 340 |
+
## Define your model
|
| 341 |
+
|
| 342 |
+
You will create a very simple fine-tuned model, with the preprocessing model, the selected BERT model, one Dense and a Dropout layer.
|
| 343 |
+
|
| 344 |
+
Note: for more information about the base model's input and output you can use just follow the model's url for documentation. Here specifically you don't need to worry about it because the preprocessing model will take care of that for you.
|
| 345 |
+
"""
|
| 346 |
+
|
| 347 |
+
def build_classifier_model():
|
| 348 |
+
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
|
| 349 |
+
preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')
|
| 350 |
+
encoder_inputs = preprocessing_layer(text_input)
|
| 351 |
+
encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='BERT_encoder')
|
| 352 |
+
outputs = encoder(encoder_inputs)
|
| 353 |
+
net = outputs['pooled_output']
|
| 354 |
+
net = tf.keras.layers.Dropout(0.1)(net)
|
| 355 |
+
net = tf.keras.layers.Dense(7, activation=None, name='classifier')(net)
|
| 356 |
+
return tf.keras.Model(text_input, net)
|
| 357 |
+
|
| 358 |
+
"""Let's check that the model runs with the output of the preprocessing model."""
|
| 359 |
+
|
| 360 |
+
classifier_model = build_classifier_model()
|
| 361 |
+
bert_raw_result = classifier_model(tf.constant(trainfeatures[0]))
|
| 362 |
+
print(tf.keras.activations.softmax(bert_raw_result))
|
| 363 |
+
|
| 364 |
+
"""The output is meaningless, of course, because the model has not been trained yet.
|
| 365 |
+
|
| 366 |
+
Let's take a look at the model's structure.
|
| 367 |
+
"""
|
| 368 |
+
|
| 369 |
+
classifier_model.summary()
|
| 370 |
+
|
| 371 |
+
"""## Model training
|
| 372 |
+
|
| 373 |
+
You now have all the pieces to train a model, including the preprocessing module, BERT encoder, data, and classifier.
|
| 374 |
+
|
| 375 |
+
Since this is a non-binary classification problem and the model outputs probabilities, you'll use `losses.CategoricalCrossentropy` loss function.
|
| 376 |
+
"""
|
| 377 |
+
|
| 378 |
+
loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
|
| 379 |
+
metrics = tf.metrics.CategoricalAccuracy()
|
| 380 |
+
|
| 381 |
+
"""### Loading the BERT model and training
|
| 382 |
+
|
| 383 |
+
Using the `classifier_model` you created earlier, you can compile the model with the loss, metric and optimizer.
|
| 384 |
+
"""
|
| 385 |
+
|
| 386 |
+
epochs=5
|
| 387 |
+
optimizer=tf.keras.optimizers.Adam(1e-5)
|
| 388 |
+
classifier_model.compile(optimizer=optimizer,
|
| 389 |
+
loss=loss,
|
| 390 |
+
metrics=metrics)
|
| 391 |
+
|
| 392 |
+
"""Note: training time will vary depending on the complexity of the BERT model you have selected."""
|
| 393 |
+
|
| 394 |
+
print(f'Training model with {tfhub_handle_encoder}')
|
| 395 |
+
history = classifier_model.fit(x=trainfeatures,y=trainlabels,
|
| 396 |
+
validation_data=(validfeatures,validlabels),
|
| 397 |
+
batch_size=32,
|
| 398 |
+
epochs=epochs)
|
| 399 |
+
classifier_model.save("CaesarAI.h5")
|
| 400 |
+
|
| 401 |
+
"""### Evaluate the model
|
| 402 |
+
|
| 403 |
+
Let's see how the model performs. Two values will be returned: loss (a number which represents the error; lower values are better) and accuracy.
"""

loss, accuracy = classifier_model.evaluate(testfeatures, testlabels)

print(f'Loss: {loss}')
print(f'Accuracy: {accuracy}')

"""### Plot the accuracy and loss over time

Based on the `History` object returned by `model.fit()`, you can plot the training and validation loss for comparison, as well as the training and validation accuracy:
"""

history_dict = history.history
print(history_dict.keys())

acc = history_dict['categorical_accuracy']
val_acc = history_dict['val_categorical_accuracy']
loss = history_dict['loss']
val_loss = history_dict['val_loss']

epochs = range(1, len(acc) + 1)
fig = plt.figure(figsize=(10, 8))
fig.tight_layout()

plt.subplot(2, 1, 1)
# 'r' is for "solid red line"
plt.plot(epochs, loss, 'r', label='Training loss')
# 'b' is for "solid blue line"
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.grid(True)
# plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

plt.subplot(2, 1, 2)
plt.plot(epochs, acc, 'r', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.grid(True)
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')

"""In this plot, the red lines represent the training loss and accuracy, and the blue lines are the validation loss and accuracy.

Classifying arbitrary instructions:
"""

def print_my_examples(inputs, results):
    result_for_printing = \
        [f'input: {inputs[i]:<30} : estimated intent: {results[i]}'
         for i in range(len(inputs))]
    print(*result_for_printing, sep='\n')
    print()


examples = [
    'play a song from U2',  # this is the same sentence tried earlier
    'Will it rain tomorrow',
    'I like to hear greatest hits from beastie boys',
    'I like to book a table for 3 persons',
    '5 stars for machines like me'
]

results = tf.nn.softmax(classifier_model(tf.constant(examples)))

binarizer.classes_

intents = binarizer.inverse_transform(results.numpy())

print_my_examples(examples, intents)
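The code above maps the model's softmax scores back to intent names with the fitted `LabelBinarizer`. A minimal, self-contained sketch of that round trip (toy labels standing in for the notebook's seven intents, no BERT model involved; `inverse_transform` effectively does an argmax over the class columns):

```python
import numpy as np
from sklearn.preprocessing import LabelBinarizer

# Toy stand-in for the intent labels used in the notebook.
labels = ["PlayMusic", "GetWeather", "BookRestaurant"]
binarizer = LabelBinarizer()
onehot = binarizer.fit_transform(labels)  # one column per class, sorted alphabetically

# Pretend these are softmax scores from the classifier for two inputs.
scores = np.array([[0.1, 0.7, 0.2],
                   [0.8, 0.1, 0.1]])

# inverse_transform picks the highest-scoring column and returns its label.
intents = binarizer.inverse_transform(scores)
print(list(intents))  # → ['GetWeather', 'BookRestaurant']
```

Note that `binarizer.classes_` is sorted alphabetically (`BookRestaurant`, `GetWeather`, `PlayMusic`), so column order does not follow the order labels were first seen in.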
CaesarAINL/bert tutorial/intent_classification_with_bert.ipynb
ADDED
|
@@ -0,0 +1,1239 @@
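The notebook added in this file mentions fetching its CSV splits with gdown only when the target file is missing. A minimal sketch of that guard (the `file_id` argument is a hypothetical placeholder, not an id from this repo; the gdown import is deferred so it only runs when a download is actually needed):

```python
import os


def download_if_missing(path: str, file_id: str) -> bool:
    """Download `path` from Google Drive via gdown only if it doesn't exist.

    Returns True if a download was triggered, False if the file was already there.
    """
    if os.path.exists(path):
        return False
    import gdown  # deferred import: only required when a download happens
    gdown.download(f"https://drive.google.com/uc?id={file_id}", path, quiet=False)
    return True
```

Usage would look like `download_if_missing(datafolder + "train.csv", TRAIN_FILE_ID)`, with `TRAIN_FILE_ID` supplied by whoever hosts the data.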
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IZ6SNYq_tVVC"
   },
   "source": [
    "# Intent Classification with BERT\n",
    "\n",
    "This notebook demonstrates the fine-tuning of BERT to perform intent classification.\n",
    "Intent classification tries to map given instructions (sentences in natural language) to a set of predefined intents.\n",
    "\n",
    "## What you will learn\n",
    "\n",
    "- Load data from csv and preprocess it for training and test\n",
    "- Load a BERT model from TensorFlow Hub\n",
    "- Build your own model by combining BERT with a classifier\n",
    "- Train your own model, fine-tuning BERT as part of that\n",
    "- Save your model and use it to recognize the intent of instructions\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2PHBpLPuQdmK"
   },
   "source": [
    "## About BERT\n",
    "\n",
    "[BERT](https://arxiv.org/abs/1810.04805) and other Transformer encoder architectures have been shown to be successful on a variety of tasks in NLP (natural language processing). They compute vector-space representations of natural language that are suitable for use in deep learning models. The BERT family of models uses the Transformer encoder architecture to process each token of input text in the full context of all tokens before and after, hence the name: Bidirectional Encoder Representations from Transformers.\n",
    "\n",
    "BERT models are usually pre-trained on a large corpus of text, then fine-tuned for specific tasks.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SCjmX4zTCkRK"
   },
   "source": [
    "## Setup\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-01-13T03:07:33.561260Z",
     "iopub.status.busy": "2021-01-13T03:07:33.560567Z",
     "iopub.status.idle": "2021-01-13T03:07:40.852309Z",
     "shell.execute_reply": "2021-01-13T03:07:40.851601Z"
    },
    "id": "_XgTpm9ZxoN9"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "#import shutil\n",
    "import pandas as pd\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow_hub as hub\n",
    "import tensorflow_text as text\n",
    "import seaborn as sns\n",
    "from pylab import rcParams\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "tf.get_logger().setLevel('ERROR')\n",
    "\n",
    "sns.set(style='whitegrid', palette='muted', font_scale=1.2)\n",
    "HAPPY_COLORS_PALETTE = [\"#01BEFE\", \"#FFDD00\", \"#FF7D00\", \"#FF006D\", \"#ADFF02\", \"#8F00FF\"]\n",
    "sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))\n",
    "rcParams['figure.figsize'] = 12, 8\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Access\n",
    "The data contains various user queries categorized into seven intents. It is hosted on [GitHub](https://github.com/snipsco/nlu-benchmark/tree/master/2017-06-custom-intent-engines) and is first presented in [this paper](https://arxiv.org/abs/1805.10190). In the list below, the classes and an example for each class are given:\n",
    "\n",
    "* `class`: SearchCreativeWork - `example`: *play hell house song*\n",
    "* `class`: GetWeather - `example`: *is it windy in boston, mb right now*\n",
    "* `class`: BookRestaurant - `example`: *book a restaurant for eight people in six years*\n",
    "* `class`: PlayMusic - `example`: *play the song little robin redbreast*\n",
    "* `class`: AddToPlaylist - `example`: *add step to me to the 50 clásicos playlist*\n",
    "* `class`: RateBook - `example`: *give 6 stars to of mice and men*\n",
    "* `class`: SearchScreeningEvent - `example`: *find fish story*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Data can be downloaded from a Google Drive by applying [gdown](https://pypi.org/project/gdown/). In the following code cells the download is invoked only if the corresponding file does not yet exist at the corresponding location."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "datafolder=\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainfile=datafolder+\"train.csv\"\n",
    "testfile=datafolder+\"test.csv\"\n",
    "validfile=datafolder+\"valid.csv\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, the downloaded .csv-files for training, validation and test are imported into pandas dataframes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "traindf = pd.read_csv(trainfile)\n",
    "validdf = pd.read_csv(validfile)\n",
    "testdf = pd.read_csv(testfile)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       " .dataframe tbody tr th:only-of-type {\n",
       " vertical-align: middle;\n",
       " }\n",
       "\n",
       " .dataframe tbody tr th {\n",
       " vertical-align: top;\n",
       " }\n",
       "\n",
       " .dataframe thead th {\n",
       " text-align: right;\n",
       " }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       " <thead>\n",
       " <tr style=\"text-align: right;\">\n",
       " <th></th>\n",
       " <th>text</th>\n",
       " <th>intent</th>\n",
       " </tr>\n",
       " </thead>\n",
       " <tbody>\n",
       " <tr>\n",
       " <th>0</th>\n",
       " <td>listen to westbam alumb allergic on google music</td>\n",
       " <td>PlayMusic</td>\n",
       " </tr>\n",
       " <tr>\n",
       " <th>1</th>\n",
       " <td>add step to me to the 50 clásicos playlist</td>\n",
       " <td>AddToPlaylist</td>\n",
       " </tr>\n",
       " <tr>\n",
       " <th>2</th>\n",
       " <td>i give this current textbook a rating value of...</td>\n",
       " <td>RateBook</td>\n",
       " </tr>\n",
       " <tr>\n",
       " <th>3</th>\n",
       " <td>play the song little robin redbreast</td>\n",
       " <td>PlayMusic</td>\n",
       " </tr>\n",
       " <tr>\n",
       " <th>4</th>\n",
       " <td>please add iris dement to my playlist this is ...</td>\n",
       " <td>AddToPlaylist</td>\n",
       " </tr>\n",
       " </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       " text intent\n",
       "0 listen to westbam alumb allergic on google music PlayMusic\n",
       "1 add step to me to the 50 clásicos playlist AddToPlaylist\n",
       "2 i give this current textbook a rating value of... RateBook\n",
       "3 play the song little robin redbreast PlayMusic\n",
       "4 please add iris dement to my playlist this is ... AddToPlaylist"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "traindf.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training data contains 13084 instructions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(13084, 2)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "traindf.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainfeatures=traindf.copy()\n",
    "trainlabels=trainfeatures.pop(\"intent\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainfeatures=trainfeatures.values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Distribution of class-labels in training-data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "#chart = sns.countplot(trainlabels, palette=HAPPY_COLORS_PALETTE)\n",
    "#plt.title(\"Number of texts per intent\")\n",
    "#chart\n",
    "#chart.set_xticklabels(chart.get_xticklabels(), rotation=30, horizontalalignment='right')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One-Hot-Encoding of class-labels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import LabelBinarizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "binarizer=LabelBinarizer()\n",
    "trainlabels=binarizer.fit_transform(trainlabels.values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(13084, 7)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainlabels.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Preprocess test- and validation data in the same way as it has been done for training-data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "testfeatures=testdf.copy()\n",
    "testlabels=testfeatures.pop(\"intent\")\n",
    "validfeatures=validdf.copy()\n",
    "validlabels=validfeatures.pop(\"intent\")\n",
    "\n",
    "testfeatures=testfeatures.values\n",
    "validfeatures=validfeatures.values\n",
    "\n",
    "testlabels=binarizer.transform(testlabels.values)\n",
    "validlabels=binarizer.transform(validlabels.values)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dX8FtlpGJRE6"
   },
   "source": [
    "## Loading models from TensorFlow Hub\n",
    "\n",
    "Here you can choose which BERT model you will load from TensorFlow Hub and fine-tune. There are multiple BERT models available.\n",
    "\n",
    " - [BERT-Base](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3), [Uncased](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3) and [seven more models](https://tfhub.dev/google/collections/bert/1) with trained weights released by the original BERT authors.\n",
    " - [Small BERTs](https://tfhub.dev/google/collections/bert/1) have the same general architecture but fewer and/or smaller Transformer blocks, which lets you explore tradeoffs between speed, size and quality.\n",
    " - [ALBERT](https://tfhub.dev/google/collections/albert/1): four different sizes of \"A Lite BERT\" that reduces model size (but not computation time) by sharing parameters between layers.\n",
    " - [BERT Experts](https://tfhub.dev/google/collections/experts/bert/1): eight models that all have the BERT-base architecture but offer a choice between different pre-training domains, to align more closely with the target task.\n",
    " - [Electra](https://tfhub.dev/google/collections/electra/1) has the same architecture as BERT (in three different sizes), but gets pre-trained as a discriminator in a set-up that resembles a Generative Adversarial Network (GAN).\n",
    " - BERT with Talking-Heads Attention and Gated GELU [[base](https://tfhub.dev/tensorflow/talkheads_ggelu_bert_en_base/1), [large](https://tfhub.dev/tensorflow/talkheads_ggelu_bert_en_large/1)] has two improvements to the core of the Transformer architecture.\n",
    "\n",
    "The model documentation on TensorFlow Hub has more details and references to the\n",
    "research literature. Follow the links above, or click on the [`tfhub.dev`](http://tfhub.dev) URL\n",
    "printed after the next cell execution.\n",
    "\n",
    "The suggestion is to start with a Small BERT (with fewer parameters) since they are faster to fine-tune. If you like a small model but with higher accuracy, ALBERT might be your next option. If you want even better accuracy, choose\n",
    "one of the classic BERT sizes or their recent refinements like Electra, Talking Heads, or a BERT Expert.\n",
    "\n",
    "Aside from the models available below, there are [multiple versions](https://tfhub.dev/google/collections/transformer_encoders_text/1) of the models that are larger and can yield even better accuracy, but they are too big to be fine-tuned on a single GPU. You will be able to do that on the [Solve GLUE tasks using BERT on a TPU colab](https://www.tensorflow.org/tutorials/text/solve_glue_tasks_using_bert_on_tpu).\n",
    "\n",
    "You'll see in the code below that switching the tfhub.dev URL is enough to try any of these models, because all the differences between them are encapsulated in the SavedModels from TF Hub."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-01-13T03:08:07.421163Z",
     "iopub.status.busy": "2021-01-13T03:08:07.420022Z",
     "iopub.status.idle": "2021-01-13T03:08:07.423061Z",
     "shell.execute_reply": "2021-01-13T03:08:07.423513Z"
    },
    "id": "y8_ctG55-uTX"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BERT model selected : https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-512_A-8/1\n",
      "Preprocess model auto-selected: https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2\n"
     ]
    }
   ],
   "source": [
    "bert_model_name = 'small_bert/bert_en_uncased_L-8_H-512_A-8'\n",
    "map_name_to_handle = {\n",
    " 'bert_en_uncased_L-12_H-768_A-12':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3',\n",
    " 'bert_en_cased_L-12_H-768_A-12':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_cased_L-12_H-768_A-12/3',\n",
    " 'bert_multi_cased_L-12_H-768_A-12':\n",
    " 'https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/3',\n",
    " 'small_bert/bert_en_uncased_L-2_H-128_A-2':\n",
    " 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/1',\n",
    " 'small_bert/bert_en_uncased_L-2_H-256_A-4':\n",
    " 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-256_A-4/1',\n",
    " 'small_bert/bert_en_uncased_L-2_H-512_A-8':\n",
    " 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-512_A-8/1',\n",
    " 'small_bert/bert_en_uncased_L-2_H-768_A-12':\n",
    " 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-768_A-12/1',\n",
    " 'small_bert/bert_en_uncased_L-4_H-128_A-2':\n",
    " 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-128_A-2/1',\n",
    " 'small_bert/bert_en_uncased_L-4_H-256_A-4':\n",
    " 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-256_A-4/1',\n",
    " 'small_bert/bert_en_uncased_L-4_H-512_A-8':\n",
    " 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1',\n",
    " 'small_bert/bert_en_uncased_L-4_H-768_A-12':\n",
    " 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-768_A-12/1',\n",
    " 'small_bert/bert_en_uncased_L-6_H-128_A-2':\n",
    " 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-128_A-2/1',\n",
    " 'small_bert/bert_en_uncased_L-6_H-256_A-4':\n",
    " 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-256_A-4/1',\n",
    " 'small_bert/bert_en_uncased_L-6_H-512_A-8':\n",
    " 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-512_A-8/1',\n",
    " 'small_bert/bert_en_uncased_L-6_H-768_A-12':\n",
    " 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-768_A-12/1',\n",
    " 'small_bert/bert_en_uncased_L-8_H-128_A-2':\n",
    " 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-128_A-2/1',\n",
    " 'small_bert/bert_en_uncased_L-8_H-256_A-4':\n",
    " 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-256_A-4/1',\n",
    " 'small_bert/bert_en_uncased_L-8_H-512_A-8':\n",
    " 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-512_A-8/1',\n",
    " 'small_bert/bert_en_uncased_L-8_H-768_A-12':\n",
    " 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-768_A-12/1',\n",
    " 'small_bert/bert_en_uncased_L-10_H-128_A-2':\n",
    " 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-10_H-128_A-2/1',\n",
    " 'small_bert/bert_en_uncased_L-10_H-256_A-4':\n",
    " 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-10_H-256_A-4/1',\n",
    " 'small_bert/bert_en_uncased_L-10_H-512_A-8':\n",
    " 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-10_H-512_A-8/1',\n",
    " 'small_bert/bert_en_uncased_L-10_H-768_A-12':\n",
    " 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-10_H-768_A-12/1',\n",
    " 'small_bert/bert_en_uncased_L-12_H-128_A-2':\n",
    " 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-128_A-2/1',\n",
    " 'small_bert/bert_en_uncased_L-12_H-256_A-4':\n",
    " 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-256_A-4/1',\n",
    " 'small_bert/bert_en_uncased_L-12_H-512_A-8':\n",
    " 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-512_A-8/1',\n",
    " 'small_bert/bert_en_uncased_L-12_H-768_A-12':\n",
    " 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-768_A-12/1',\n",
    " 'albert_en_base':\n",
    " 'https://tfhub.dev/tensorflow/albert_en_base/2',\n",
    " 'electra_small':\n",
    " 'https://tfhub.dev/google/electra_small/2',\n",
    " 'electra_base':\n",
    " 'https://tfhub.dev/google/electra_base/2',\n",
    " 'experts_pubmed':\n",
    " 'https://tfhub.dev/google/experts/bert/pubmed/2',\n",
    " 'experts_wiki_books':\n",
    " 'https://tfhub.dev/google/experts/bert/wiki_books/2',\n",
    " 'talking-heads_base':\n",
    " 'https://tfhub.dev/tensorflow/talkheads_ggelu_bert_en_base/1',\n",
    "}\n",
    "\n",
    "map_model_to_preprocess = {\n",
    " 'bert_en_uncased_L-12_H-768_A-12':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'bert_en_cased_L-12_H-768_A-12':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_cased_preprocess/2',\n",
    " 'small_bert/bert_en_uncased_L-2_H-128_A-2':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'small_bert/bert_en_uncased_L-2_H-256_A-4':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'small_bert/bert_en_uncased_L-2_H-512_A-8':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'small_bert/bert_en_uncased_L-2_H-768_A-12':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'small_bert/bert_en_uncased_L-4_H-128_A-2':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'small_bert/bert_en_uncased_L-4_H-256_A-4':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'small_bert/bert_en_uncased_L-4_H-512_A-8':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'small_bert/bert_en_uncased_L-4_H-768_A-12':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'small_bert/bert_en_uncased_L-6_H-128_A-2':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'small_bert/bert_en_uncased_L-6_H-256_A-4':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'small_bert/bert_en_uncased_L-6_H-512_A-8':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'small_bert/bert_en_uncased_L-6_H-768_A-12':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'small_bert/bert_en_uncased_L-8_H-128_A-2':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'small_bert/bert_en_uncased_L-8_H-256_A-4':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'small_bert/bert_en_uncased_L-8_H-512_A-8':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'small_bert/bert_en_uncased_L-8_H-768_A-12':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'small_bert/bert_en_uncased_L-10_H-128_A-2':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'small_bert/bert_en_uncased_L-10_H-256_A-4':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'small_bert/bert_en_uncased_L-10_H-512_A-8':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'small_bert/bert_en_uncased_L-10_H-768_A-12':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'small_bert/bert_en_uncased_L-12_H-128_A-2':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'small_bert/bert_en_uncased_L-12_H-256_A-4':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'small_bert/bert_en_uncased_L-12_H-512_A-8':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'small_bert/bert_en_uncased_L-12_H-768_A-12':\n",
    " 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
    " 'bert_multi_cased_L-12_H-768_A-12':\n",
    " 'https://tfhub.dev/tensorflow/bert_multi_cased_preprocess/2',
|
| 534 |
+
" 'albert_en_base':\n",
|
| 535 |
+
" 'https://tfhub.dev/tensorflow/albert_en_preprocess/2',\n",
|
| 536 |
+
" 'electra_small':\n",
|
| 537 |
+
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
|
| 538 |
+
" 'electra_base':\n",
|
| 539 |
+
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
|
| 540 |
+
" 'experts_pubmed':\n",
|
| 541 |
+
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
|
| 542 |
+
" 'experts_wiki_books':\n",
|
| 543 |
+
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
|
| 544 |
+
" 'talking-heads_base':\n",
|
| 545 |
+
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2',\n",
|
| 546 |
+
"}\n",
|
| 547 |
+
"\n",
|
| 548 |
+
"tfhub_handle_encoder = map_name_to_handle[bert_model_name]\n",
|
| 549 |
+
"tfhub_handle_preprocess = map_model_to_preprocess[bert_model_name]\n",
|
| 550 |
+
"\n",
|
| 551 |
+
"print(f'BERT model selected : {tfhub_handle_encoder}')\n",
|
| 552 |
+
"print(f'Preprocess model auto-selected: {tfhub_handle_preprocess}')"
|
| 553 |
+
]
|
| 554 |
+
},
|
| 555 |
+
{
|
| 556 |
+
"cell_type": "markdown",
|
| 557 |
+
"metadata": {
|
| 558 |
+
"id": "7WrcxxTRDdHi"
|
| 559 |
+
},
|
| 560 |
+
"source": [
|
| 561 |
+
"## The preprocessing model\n",
|
| 562 |
+
"\n",
|
| 563 |
+
"Text inputs need to be transformed to numeric token ids and arranged in several Tensors before being input to BERT. TensorFlow Hub provides a matching preprocessing model for each of the BERT models discussed above, which implements this transformation using TF ops from the TF.text library. It is not necessary to run pure Python code outside your TensorFlow model to preprocess text.\n",
|
| 564 |
+
"\n",
|
| 565 |
+
"The preprocessing model must be the one referenced by the documentation of the BERT model, which you can read at the URL printed above. For BERT models from the drop-down above, the preprocessing model is selected automatically.\n",
|
| 566 |
+
"\n",
|
| 567 |
+
"Note: You will load the preprocessing model into a [hub.KerasLayer](https://www.tensorflow.org/hub/api_docs/python/hub/KerasLayer) to compose your fine-tuned model. This is the preferred API to load a TF2-style SavedModel from TF Hub into a Keras model."
|
| 568 |
+
]
|
| 569 |
+
},
|
| 570 |
+
{
|
| 571 |
+
"cell_type": "code",
|
| 572 |
+
"execution_count": 15,
|
| 573 |
+
"metadata": {
|
| 574 |
+
"execution": {
|
| 575 |
+
"iopub.execute_input": "2021-01-13T03:08:07.428866Z",
|
| 576 |
+
"iopub.status.busy": "2021-01-13T03:08:07.427625Z",
|
| 577 |
+
"iopub.status.idle": "2021-01-13T03:08:10.467605Z",
|
| 578 |
+
"shell.execute_reply": "2021-01-13T03:08:10.468143Z"
|
| 579 |
+
},
|
| 580 |
+
"id": "0SQi-jWd_jzq"
|
| 581 |
+
},
|
| 582 |
+
"outputs": [],
|
| 583 |
+
"source": [
|
| 584 |
+
"bert_preprocess_model = hub.KerasLayer(tfhub_handle_preprocess)"
|
| 585 |
+
]
|
| 586 |
+
},
|
| 587 |
+
{
|
| 588 |
+
"cell_type": "markdown",
|
| 589 |
+
"metadata": {
|
| 590 |
+
"id": "x4naBiEE_cZX"
|
| 591 |
+
},
|
| 592 |
+
"source": [
|
| 593 |
+
"Let's try the preprocessing model on some text and see the output:"
|
| 594 |
+
]
|
| 595 |
+
},
|
| 596 |
+
{
|
| 597 |
+
"cell_type": "code",
|
| 598 |
+
"execution_count": 16,
|
| 599 |
+
"metadata": {},
|
| 600 |
+
"outputs": [
|
| 601 |
+
{
|
| 602 |
+
"data": {
|
| 603 |
+
"text/plain": [
|
| 604 |
+
"array(['listen to westbam alumb allergic on google music'], dtype=object)"
|
| 605 |
+
]
|
| 606 |
+
},
|
| 607 |
+
"execution_count": 16,
|
| 608 |
+
"metadata": {},
|
| 609 |
+
"output_type": "execute_result"
|
| 610 |
+
}
|
| 611 |
+
],
|
| 612 |
+
"source": [
|
| 613 |
+
"trainfeatures[0]"
|
| 614 |
+
]
|
| 615 |
+
},
|
| 616 |
+
{
|
| 617 |
+
"cell_type": "code",
|
| 618 |
+
"execution_count": 17,
|
| 619 |
+
"metadata": {
|
| 620 |
+
"execution": {
|
| 621 |
+
"iopub.execute_input": "2021-01-13T03:08:10.482448Z",
|
| 622 |
+
"iopub.status.busy": "2021-01-13T03:08:10.478061Z",
|
| 623 |
+
"iopub.status.idle": "2021-01-13T03:08:10.652900Z",
|
| 624 |
+
"shell.execute_reply": "2021-01-13T03:08:10.652248Z"
|
| 625 |
+
},
|
| 626 |
+
"id": "r9-zCzJpnuwS"
|
| 627 |
+
},
|
| 628 |
+
"outputs": [
|
| 629 |
+
{
|
| 630 |
+
"name": "stdout",
|
| 631 |
+
"output_type": "stream",
|
| 632 |
+
"text": [
|
| 633 |
+
"Keys : ['input_type_ids', 'input_mask', 'input_word_ids']\n",
|
| 634 |
+
"Shape : (1, 128)\n",
|
| 635 |
+
"Word Ids : [ 101 4952 2000 2225 3676 2213 2632 25438 27395 2006 8224 2189]\n",
|
| 636 |
+
"Input Mask : [1 1 1 1 1 1 1 1 1 1 1 1]\n",
|
| 637 |
+
"Type Ids : [0 0 0 0 0 0 0 0 0 0 0 0]\n"
|
| 638 |
+
]
|
| 639 |
+
}
|
| 640 |
+
],
|
| 641 |
+
"source": [
|
| 642 |
+
"text_test = trainfeatures[0]\n",
|
| 643 |
+
"text_preprocessed = bert_preprocess_model(text_test)\n",
|
| 644 |
+
"\n",
|
| 645 |
+
"print(f'Keys : {list(text_preprocessed.keys())}')\n",
|
| 646 |
+
"print(f'Shape : {text_preprocessed[\"input_word_ids\"].shape}')\n",
|
| 647 |
+
"print(f'Word Ids : {text_preprocessed[\"input_word_ids\"][0, :12]}')\n",
|
| 648 |
+
"print(f'Input Mask : {text_preprocessed[\"input_mask\"][0, :12]}')\n",
|
| 649 |
+
"print(f'Type Ids : {text_preprocessed[\"input_type_ids\"][0, :12]}')"
|
| 650 |
+
]
|
| 651 |
+
},
|
| 652 |
+
{
|
| 653 |
+
"cell_type": "markdown",
|
| 654 |
+
"metadata": {
|
| 655 |
+
"id": "EqL7ihkN_862"
|
| 656 |
+
},
|
| 657 |
+
"source": [
|
| 658 |
+
"As can be seen, there are 3 outputs from the preprocessing that a BERT model would use (`input_words_id`, `input_mask` and `input_type_ids`).\n",
|
| 659 |
+
"\n",
|
| 660 |
+
"Some other important points:\n",
|
| 661 |
+
"- The input is truncated to 128 tokens. The number of tokens can be customized and you can see more details on the [Solve GLUE tasks using BERT on a TPU colab](https://www.tensorflow.org/tutorials/text/solve_glue_tasks_using_bert_on_tpu).\n",
|
| 662 |
+
"- The `input_type_ids` only have one value (0) because this is a single sentence input. For a multiple sentence input, it would have one number for each input.\n",
|
| 663 |
+
"\n",
|
| 664 |
+
"Since this text preprocessor is a TensorFlow model, It can be included in your model directly."
|
| 665 |
+
]
|
| 666 |
+
},
|
| 667 |
+
{
|
| 668 |
+
"cell_type": "markdown",
|
| 669 |
+
"metadata": {
|
| 670 |
+
"id": "DKnLPSEmtp9i"
|
| 671 |
+
},
|
| 672 |
+
"source": [
|
| 673 |
+
"## Using the BERT model\n",
|
| 674 |
+
"\n",
|
| 675 |
+
"Before putting BERT into an own model, let's take a look at its outputs. You will load it from TF Hub and see the returned values."
|
| 676 |
+
]
|
| 677 |
+
},
|
| 678 |
+
{
|
| 679 |
+
"cell_type": "code",
|
| 680 |
+
"execution_count": 18,
|
| 681 |
+
"metadata": {
|
| 682 |
+
"execution": {
|
| 683 |
+
"iopub.execute_input": "2021-01-13T03:08:10.658519Z",
|
| 684 |
+
"iopub.status.busy": "2021-01-13T03:08:10.657556Z",
|
| 685 |
+
"iopub.status.idle": "2021-01-13T03:08:19.674983Z",
|
| 686 |
+
"shell.execute_reply": "2021-01-13T03:08:19.675465Z"
|
| 687 |
+
},
|
| 688 |
+
"id": "tXxYpK8ixL34"
|
| 689 |
+
},
|
| 690 |
+
"outputs": [],
|
| 691 |
+
"source": [
|
| 692 |
+
"bert_model = hub.KerasLayer(tfhub_handle_encoder)"
|
| 693 |
+
]
|
| 694 |
+
},
|
| 695 |
+
{
|
| 696 |
+
"cell_type": "code",
|
| 697 |
+
"execution_count": 19,
|
| 698 |
+
"metadata": {
|
| 699 |
+
"execution": {
|
| 700 |
+
"iopub.execute_input": "2021-01-13T03:08:19.682552Z",
|
| 701 |
+
"iopub.status.busy": "2021-01-13T03:08:19.681441Z",
|
| 702 |
+
"iopub.status.idle": "2021-01-13T03:08:20.297383Z",
|
| 703 |
+
"shell.execute_reply": "2021-01-13T03:08:20.297932Z"
|
| 704 |
+
},
|
| 705 |
+
"id": "_OoF9mebuSZc"
|
| 706 |
+
},
|
| 707 |
+
"outputs": [
|
| 708 |
+
{
|
| 709 |
+
"name": "stdout",
|
| 710 |
+
"output_type": "stream",
|
| 711 |
+
"text": [
|
| 712 |
+
"Loaded BERT: https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-512_A-8/1\n",
|
| 713 |
+
"Pooled Outputs Shape:(1, 512)\n",
|
| 714 |
+
"Pooled Outputs Values:[-0.04969434 -0.16525201 -0.99807066 -0.93279284 -0.6145217 -0.22613084\n",
|
| 715 |
+
" -0.95588505 -0.50678337 0.29122898 0.2631647 0.7982282 0.49405995]\n",
|
| 716 |
+
"Sequence Outputs Shape:(1, 128, 512)\n",
|
| 717 |
+
"Sequence Outputs Values:[[-0.1024768 0.22204846 0.59883934 ... -0.25584042 0.61985433\n",
|
| 718 |
+
" -0.01822574]\n",
|
| 719 |
+
" [ 0.4550366 -0.57238305 0.5542101 ... -0.28608793 1.3628979\n",
|
| 720 |
+
" 0.9131196 ]\n",
|
| 721 |
+
" [ 0.42473704 0.29045174 0.82693 ... 0.28371704 1.7948036\n",
|
| 722 |
+
" -0.36674204]\n",
|
| 723 |
+
" ...\n",
|
| 724 |
+
" [-0.46153253 0.02829356 0.51673454 ... -0.15035403 1.4651561\n",
|
| 725 |
+
" 0.6449582 ]\n",
|
| 726 |
+
" [ 0.7110826 1.0848484 0.66065294 ... 0.4794111 0.723307\n",
|
| 727 |
+
" -0.08312207]\n",
|
| 728 |
+
" [ 0.35558882 -0.3890488 0.5101847 ... 0.19970936 0.86474574\n",
|
| 729 |
+
" 0.12227032]]\n"
|
| 730 |
+
]
|
| 731 |
+
}
|
| 732 |
+
],
|
| 733 |
+
"source": [
|
| 734 |
+
"bert_results = bert_model(text_preprocessed)\n",
|
| 735 |
+
"\n",
|
| 736 |
+
"print(f'Loaded BERT: {tfhub_handle_encoder}')\n",
|
| 737 |
+
"print(f'Pooled Outputs Shape:{bert_results[\"pooled_output\"].shape}')\n",
|
| 738 |
+
"print(f'Pooled Outputs Values:{bert_results[\"pooled_output\"][0, :12]}')\n",
|
| 739 |
+
"print(f'Sequence Outputs Shape:{bert_results[\"sequence_output\"].shape}')\n",
|
| 740 |
+
"print(f'Sequence Outputs Values:{bert_results[\"sequence_output\"][0, :12]}')"
|
| 741 |
+
]
|
| 742 |
+
},
|
| 743 |
+
{
|
| 744 |
+
"cell_type": "markdown",
|
| 745 |
+
"metadata": {
|
| 746 |
+
"id": "sm61jDrezAll"
|
| 747 |
+
},
|
| 748 |
+
"source": [
|
| 749 |
+
"The BERT models return a map with 3 important keys: `pooled_output`, `sequence_output`, `encoder_outputs`:\n",
|
| 750 |
+
"\n",
|
| 751 |
+
"- `pooled_output` to represent each input sequence as a whole. The shape is `[batch_size, H]`. You can think of this as an embedding for the entire movie review.\n",
|
| 752 |
+
"- `sequence_output` represents each input token in the context. The shape is `[batch_size, seq_length, H]`. You can think of this as a contextual embedding for every token in the movie review.\n",
|
| 753 |
+
"- `encoder_outputs` are the intermediate activations of the `L` Transformer blocks. `outputs[\"encoder_outputs\"][i]` is a Tensor of shape `[batch_size, seq_length, 1024]` with the outputs of the i-th Transformer block, for `0 <= i < L`. The last value of the list is equal to `sequence_output`.\n",
|
| 754 |
+
"\n",
|
| 755 |
+
"For the fine-tuning you are going to use the `pooled_output` array."
|
| 756 |
+
]
|
| 757 |
+
},
|
| 758 |
+
{
|
| 759 |
+
"cell_type": "markdown",
|
| 760 |
+
"metadata": {
|
| 761 |
+
"id": "pDNKfAXbDnJH"
|
| 762 |
+
},
|
| 763 |
+
"source": [
|
| 764 |
+
"## Define your model\n",
|
| 765 |
+
"\n",
|
| 766 |
+
"You will create a very simple fine-tuned model, with the preprocessing model, the selected BERT model, one Dense and a Dropout layer.\n",
|
| 767 |
+
"\n",
|
| 768 |
+
"Note: for more information about the base model's input and output you can use just follow the model's url for documentation. Here specifically you don't need to worry about it because the preprocessing model will take care of that for you.\n"
|
| 769 |
+
]
|
| 770 |
+
},
|
| 771 |
+
{
|
| 772 |
+
"cell_type": "code",
|
| 773 |
+
"execution_count": 20,
|
| 774 |
+
"metadata": {
|
| 775 |
+
"execution": {
|
| 776 |
+
"iopub.execute_input": "2021-01-13T03:08:20.306302Z",
|
| 777 |
+
"iopub.status.busy": "2021-01-13T03:08:20.305016Z",
|
| 778 |
+
"iopub.status.idle": "2021-01-13T03:08:20.307988Z",
|
| 779 |
+
"shell.execute_reply": "2021-01-13T03:08:20.307291Z"
|
| 780 |
+
},
|
| 781 |
+
"id": "aksj743St9ga"
|
| 782 |
+
},
|
| 783 |
+
"outputs": [],
|
| 784 |
+
"source": [
|
| 785 |
+
"def build_classifier_model():\n",
|
| 786 |
+
" text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')\n",
|
| 787 |
+
" preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')\n",
|
| 788 |
+
" encoder_inputs = preprocessing_layer(text_input)\n",
|
| 789 |
+
" encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='BERT_encoder')\n",
|
| 790 |
+
" outputs = encoder(encoder_inputs)\n",
|
| 791 |
+
" net = outputs['pooled_output']\n",
|
| 792 |
+
" net = tf.keras.layers.Dropout(0.1)(net)\n",
|
| 793 |
+
" net = tf.keras.layers.Dense(7, activation=None, name='classifier')(net)\n",
|
| 794 |
+
" return tf.keras.Model(text_input, net)"
|
| 795 |
+
]
|
| 796 |
+
},
|
| 797 |
+
{
|
| 798 |
+
"cell_type": "markdown",
|
| 799 |
+
"metadata": {
|
| 800 |
+
"id": "Zs4yhFraBuGQ"
|
| 801 |
+
},
|
| 802 |
+
"source": [
|
| 803 |
+
"Let's check that the model runs with the output of the preprocessing model."
|
| 804 |
+
]
|
| 805 |
+
},
|
| 806 |
+
{
|
| 807 |
+
"cell_type": "code",
|
| 808 |
+
"execution_count": 21,
|
| 809 |
+
"metadata": {
|
| 810 |
+
"execution": {
|
| 811 |
+
"iopub.execute_input": "2021-01-13T03:08:20.317706Z",
|
| 812 |
+
"iopub.status.busy": "2021-01-13T03:08:20.316292Z",
|
| 813 |
+
"iopub.status.idle": "2021-01-13T03:08:27.279676Z",
|
| 814 |
+
"shell.execute_reply": "2021-01-13T03:08:27.279042Z"
|
| 815 |
+
},
|
| 816 |
+
"id": "mGMF8AZcB2Zy"
|
| 817 |
+
},
|
| 818 |
+
"outputs": [
|
| 819 |
+
{
|
| 820 |
+
"name": "stdout",
|
| 821 |
+
"output_type": "stream",
|
| 822 |
+
"text": [
|
| 823 |
+
"tf.Tensor(\n",
|
| 824 |
+
"[[0.15770487 0.21116826 0.03859028 0.04381749 0.03109054 0.0646299\n",
|
| 825 |
+
" 0.4529987 ]], shape=(1, 7), dtype=float32)\n"
|
| 826 |
+
]
|
| 827 |
+
}
|
| 828 |
+
],
|
| 829 |
+
"source": [
|
| 830 |
+
"classifier_model = build_classifier_model()\n",
|
| 831 |
+
"bert_raw_result = classifier_model(tf.constant(trainfeatures[0]))\n",
|
| 832 |
+
"print(tf.keras.activations.softmax(bert_raw_result))"
|
| 833 |
+
]
|
| 834 |
+
},
|
| 835 |
+
{
|
| 836 |
+
"cell_type": "markdown",
|
| 837 |
+
"metadata": {
|
| 838 |
+
"id": "ZTUzNV2JE2G3"
|
| 839 |
+
},
|
| 840 |
+
"source": [
|
| 841 |
+
"The output is meaningless, of course, because the model has not been trained yet.\n",
|
| 842 |
+
"\n",
|
| 843 |
+
"Let's take a look at the model's structure."
|
| 844 |
+
]
|
| 845 |
+
},
|
| 846 |
+
{
|
| 847 |
+
"cell_type": "code",
|
| 848 |
+
"execution_count": 22,
|
| 849 |
+
"metadata": {},
|
| 850 |
+
"outputs": [
|
| 851 |
+
{
|
| 852 |
+
"name": "stdout",
|
| 853 |
+
"output_type": "stream",
|
| 854 |
+
"text": [
|
| 855 |
+
"Model: \"model\"\n",
|
| 856 |
+
"__________________________________________________________________________________________________\n",
|
| 857 |
+
" Layer (type) Output Shape Param # Connected to \n",
|
| 858 |
+
"==================================================================================================\n",
|
| 859 |
+
" text (InputLayer) [(None,)] 0 [] \n",
|
| 860 |
+
" \n",
|
| 861 |
+
" preprocessing (KerasLayer) {'input_type_ids': 0 ['text[0][0]'] \n",
|
| 862 |
+
" (None, 128), \n",
|
| 863 |
+
" 'input_mask': (Non \n",
|
| 864 |
+
" e, 128), \n",
|
| 865 |
+
" 'input_word_ids': \n",
|
| 866 |
+
" (None, 128)} \n",
|
| 867 |
+
" \n",
|
| 868 |
+
" BERT_encoder (KerasLayer) {'encoder_outputs': 41373185 ['preprocessing[0][0]', \n",
|
| 869 |
+
" [(None, 128, 512), 'preprocessing[0][1]', \n",
|
| 870 |
+
" (None, 128, 512), 'preprocessing[0][2]'] \n",
|
| 871 |
+
" (None, 128, 512), \n",
|
| 872 |
+
" (None, 128, 512), \n",
|
| 873 |
+
" (None, 128, 512), \n",
|
| 874 |
+
" (None, 128, 512), \n",
|
| 875 |
+
" (None, 128, 512), \n",
|
| 876 |
+
" (None, 128, 512)], \n",
|
| 877 |
+
" 'default': (None, \n",
|
| 878 |
+
" 512), \n",
|
| 879 |
+
" 'sequence_output': \n",
|
| 880 |
+
" (None, 128, 512), \n",
|
| 881 |
+
" 'pooled_output': ( \n",
|
| 882 |
+
" None, 512)} \n",
|
| 883 |
+
" \n",
|
| 884 |
+
" dropout (Dropout) (None, 512) 0 ['BERT_encoder[0][9]'] \n",
|
| 885 |
+
" \n",
|
| 886 |
+
" classifier (Dense) (None, 7) 3591 ['dropout[0][0]'] \n",
|
| 887 |
+
" \n",
|
| 888 |
+
"==================================================================================================\n",
|
| 889 |
+
"Total params: 41,376,776\n",
|
| 890 |
+
"Trainable params: 41,376,775\n",
|
| 891 |
+
"Non-trainable params: 1\n",
|
| 892 |
+
"__________________________________________________________________________________________________\n"
|
| 893 |
+
]
|
| 894 |
+
}
|
| 895 |
+
],
|
| 896 |
+
"source": [
|
| 897 |
+
"classifier_model.summary()"
|
| 898 |
+
]
|
| 899 |
+
},
|
| 900 |
+
{
|
| 901 |
+
"cell_type": "markdown",
|
| 902 |
+
"metadata": {
|
| 903 |
+
"id": "WbUWoZMwc302"
|
| 904 |
+
},
|
| 905 |
+
"source": [
|
| 906 |
+
"## Model training\n",
|
| 907 |
+
"\n",
|
| 908 |
+
"You now have all the pieces to train a model, including the preprocessing module, BERT encoder, data, and classifier."
|
| 909 |
+
]
|
| 910 |
+
},
|
| 911 |
+
{
|
| 912 |
+
"cell_type": "markdown",
|
| 913 |
+
"metadata": {
|
| 914 |
+
"id": "WpJ3xcwDT56v"
|
| 915 |
+
},
|
| 916 |
+
"source": [
|
| 917 |
+
"Since this is a non-binary classification problem and the model outputs probabilities, you'll use `losses.CategoricalCrossentropy` loss function.\n"
|
| 918 |
+
]
|
| 919 |
+
},
|
| 920 |
+
{
|
| 921 |
+
"cell_type": "code",
|
| 922 |
+
"execution_count": 23,
|
| 923 |
+
"metadata": {
|
| 924 |
+
"execution": {
|
| 925 |
+
"iopub.execute_input": "2021-01-13T03:08:27.596402Z",
|
| 926 |
+
"iopub.status.busy": "2021-01-13T03:08:27.595622Z",
|
| 927 |
+
"iopub.status.idle": "2021-01-13T03:08:27.600436Z",
|
| 928 |
+
"shell.execute_reply": "2021-01-13T03:08:27.600889Z"
|
| 929 |
+
},
|
| 930 |
+
"id": "OWPOZE-L3AgE"
|
| 931 |
+
},
|
| 932 |
+
"outputs": [],
|
| 933 |
+
"source": [
|
| 934 |
+
"loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)\n",
|
| 935 |
+
"metrics = tf.metrics.CategoricalAccuracy()"
|
| 936 |
+
]
|
| 937 |
+
},
|
| 938 |
+
{
|
| 939 |
+
"cell_type": "markdown",
|
| 940 |
+
"metadata": {
|
| 941 |
+
"id": "SqlarlpC_v0g"
|
| 942 |
+
},
|
| 943 |
+
"source": [
|
| 944 |
+
"### Loading the BERT model and training\n",
|
| 945 |
+
"\n",
|
| 946 |
+
"Using the `classifier_model` you created earlier, you can compile the model with the loss, metric and optimizer."
|
| 947 |
+
]
|
| 948 |
+
},
|
| 949 |
+
{
|
| 950 |
+
"cell_type": "code",
|
| 951 |
+
"execution_count": 24,
|
| 952 |
+
"metadata": {
|
| 953 |
+
"execution": {
|
| 954 |
+
"iopub.execute_input": "2021-01-13T03:08:27.621858Z",
|
| 955 |
+
"iopub.status.busy": "2021-01-13T03:08:27.621180Z",
|
| 956 |
+
"iopub.status.idle": "2021-01-13T03:08:27.631397Z",
|
| 957 |
+
"shell.execute_reply": "2021-01-13T03:08:27.631841Z"
|
| 958 |
+
},
|
| 959 |
+
"id": "-7GPDhR98jsD"
|
| 960 |
+
},
|
| 961 |
+
"outputs": [],
|
| 962 |
+
"source": [
|
| 963 |
+
"epochs=5\n",
|
| 964 |
+
"optimizer=tf.keras.optimizers.Adam(1e-5)\n",
|
| 965 |
+
"classifier_model.compile(optimizer=optimizer,\n",
|
| 966 |
+
" loss=loss,\n",
|
| 967 |
+
" metrics=metrics)"
|
| 968 |
+
]
|
| 969 |
+
},
|
| 970 |
+
{
|
| 971 |
+
"cell_type": "markdown",
|
| 972 |
+
"metadata": {
|
| 973 |
+
"id": "CpBuV5j2cS_b"
|
| 974 |
+
},
|
| 975 |
+
"source": [
|
| 976 |
+
"Note: training time will vary depending on the complexity of the BERT model you have selected."
|
| 977 |
+
]
|
| 978 |
+
},
|
| 979 |
+
{
|
| 980 |
+
"cell_type": "code",
|
| 981 |
+
"execution_count": 25,
|
| 982 |
+
"metadata": {
|
| 983 |
+
"execution": {
|
| 984 |
+
"iopub.execute_input": "2021-01-13T03:08:27.636326Z",
|
| 985 |
+
"iopub.status.busy": "2021-01-13T03:08:27.635519Z",
|
| 986 |
+
"iopub.status.idle": "2021-01-13T03:15:50.893395Z",
|
| 987 |
+
"shell.execute_reply": "2021-01-13T03:15:50.893900Z"
|
| 988 |
+
},
|
| 989 |
+
"id": "HtfDFAnN_Neu",
|
| 990 |
+
"scrolled": true
|
| 991 |
+
},
|
| 992 |
+
"outputs": [
|
| 993 |
+
{
|
| 994 |
+
"name": "stdout",
|
| 995 |
+
"output_type": "stream",
|
| 996 |
+
"text": [
|
| 997 |
+
"Training model with https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-512_A-8/1\n",
|
| 998 |
+
"Epoch 1/5\n",
|
| 999 |
+
" 11/409 [..............................] - ETA: 1:40:36 - loss: 1.9467 - categorical_accuracy: 0.2159"
|
| 1000 |
+
]
|
| 1001 |
+
}
|
| 1002 |
+
],
|
| 1003 |
+
"source": [
|
| 1004 |
+
"print(f'Training model with {tfhub_handle_encoder}')\n",
|
| 1005 |
+
"history = classifier_model.fit(x=trainfeatures,y=trainlabels,\n",
|
| 1006 |
+
" validation_data=(validfeatures,validlabels),\n",
|
| 1007 |
+
" batch_size=32,\n",
|
| 1008 |
+
" epochs=epochs)"
|
| 1009 |
+
]
|
| 1010 |
+
},
|
| 1011 |
+
{
|
| 1012 |
+
"cell_type": "markdown",
|
| 1013 |
+
"metadata": {
|
| 1014 |
+
"id": "uBthMlTSV8kn"
|
| 1015 |
+
},
|
| 1016 |
+
"source": [
|
| 1017 |
+
"### Evaluate the model\n",
|
| 1018 |
+
"\n",
|
| 1019 |
+
"Let's see how the model performs. Two values will be returned. Loss (a number which represents the error, lower values are better), and accuracy."
|
| 1020 |
+
]
|
| 1021 |
+
},
|
| 1022 |
+
{
|
| 1023 |
+
"cell_type": "code",
|
| 1024 |
+
"execution_count": null,
|
| 1025 |
+
"metadata": {
|
| 1026 |
+
"execution": {
|
| 1027 |
+
"iopub.execute_input": "2021-01-13T03:15:50.898967Z",
|
| 1028 |
+
"iopub.status.busy": "2021-01-13T03:15:50.898282Z",
|
| 1029 |
+
"iopub.status.idle": "2021-01-13T03:16:53.613847Z",
|
| 1030 |
+
"shell.execute_reply": "2021-01-13T03:16:53.614285Z"
|
| 1031 |
+
},
|
| 1032 |
+
"id": "slqB-urBV9sP"
|
| 1033 |
+
},
|
| 1034 |
+
"outputs": [],
|
| 1035 |
+
"source": [
|
| 1036 |
+
"loss, accuracy = classifier_model.evaluate(testfeatures,testlabels)\n",
|
| 1037 |
+
"\n",
|
| 1038 |
+
"print(f'Loss: {loss}')\n",
|
| 1039 |
+
"print(f'Accuracy: {accuracy}')"
|
| 1040 |
+
]
|
| 1041 |
+
},
|
| 1042 |
+
{
|
| 1043 |
+
"cell_type": "markdown",
|
| 1044 |
+
"metadata": {
|
| 1045 |
+
"id": "uttWpgmSfzq9"
|
| 1046 |
+
},
|
| 1047 |
+
"source": [
|
| 1048 |
+
"### Plot the accuracy and loss over time\n",
|
| 1049 |
+
"\n",
|
| 1050 |
+
"Based on the `History` object returned by `model.fit()`. You can plot the training and validation loss for comparison, as well as the training and validation accuracy:"
|
| 1051 |
+
]
|
| 1052 |
+
},
|
| 1053 |
+
{
|
| 1054 |
+
"cell_type": "code",
|
| 1055 |
+
"execution_count": null,
|
| 1056 |
+
"metadata": {
|
| 1057 |
+
"execution": {
|
| 1058 |
+
"iopub.execute_input": "2021-01-13T03:16:53.641616Z",
|
| 1059 |
+
"iopub.status.busy": "2021-01-13T03:16:53.638634Z",
|
| 1060 |
+
"iopub.status.idle": "2021-01-13T03:16:53.950276Z",
|
| 1061 |
+
"shell.execute_reply": "2021-01-13T03:16:53.950768Z"
|
| 1062 |
+
},
|
| 1063 |
+
"id": "fiythcODf0xo"
|
| 1064 |
+
},
|
| 1065 |
+
"outputs": [],
|
| 1066 |
+
"source": [
|
| 1067 |
+
"history_dict = history.history\n",
|
| 1068 |
+
"print(history_dict.keys())\n",
|
| 1069 |
+
"\n",
|
| 1070 |
+
"acc = history_dict['categorical_accuracy']\n",
|
| 1071 |
+
"val_acc = history_dict['val_categorical_accuracy']\n",
|
| 1072 |
+
"loss = history_dict['loss']\n",
|
| 1073 |
+
"val_loss = history_dict['val_loss']\n",
|
| 1074 |
+
"\n",
|
| 1075 |
+
"epochs = range(1, len(acc) + 1)\n",
|
| 1076 |
+
"fig = plt.figure(figsize=(10, 8))\n",
|
| 1077 |
+
"fig.tight_layout()\n",
|
| 1078 |
+
"\n",
|
| 1079 |
+
"plt.subplot(2, 1, 1)\n",
|
| 1080 |
+
"# \"bo\" is for \"blue dot\"\n",
|
| 1081 |
+
"plt.plot(epochs, loss, 'r', label='Training loss')\n",
|
| 1082 |
+
"# b is for \"solid blue line\"\n",
|
| 1083 |
+
"plt.plot(epochs, val_loss, 'b', label='Validation loss')\n",
|
| 1084 |
+
"plt.title('Training and validation loss')\n",
|
| 1085 |
+
"plt.grid(True)\n",
|
| 1086 |
+
"# plt.xlabel('Epochs')\n",
|
| 1087 |
+
"plt.ylabel('Loss')\n",
|
| 1088 |
+
"plt.legend()\n",
|
| 1089 |
+
"\n",
|
| 1090 |
+
"plt.subplot(2, 1, 2)\n",
|
| 1091 |
+
"plt.plot(epochs, acc, 'r', label='Training acc')\n",
|
| 1092 |
+
"plt.plot(epochs, val_acc, 'b', label='Validation acc')\n",
|
| 1093 |
+
"plt.title('Training and validation accuracy')\n",
|
| 1094 |
+
"plt.grid(True)\n",
|
| 1095 |
+
"plt.xlabel('Epochs')\n",
|
| 1096 |
+
"plt.ylabel('Accuracy')\n",
|
| 1097 |
+
"plt.legend(loc='lower right')"
|
| 1098 |
+
]
|
| 1099 |
+
},
|
| 1100 |
+
{
|
| 1101 |
+
"cell_type": "markdown",
|
| 1102 |
+
"metadata": {
|
| 1103 |
+
"id": "WzJZCo-cf-Jf"
|
| 1104 |
+
},
|
| 1105 |
+
"source": [
|
| 1106 |
+
"In this plot, the red lines represents the training loss and accuracy, and the blue lines are the validation loss and accuracy."
|
| 1107 |
+
]
|
| 1108 |
+
},
|
| 1109 |
+
{
|
| 1110 |
+
"cell_type": "markdown",
|
| 1111 |
+
"metadata": {
|
| 1112 |
+
"id": "oyTappHTvNCz"
|
| 1113 |
+
},
|
| 1114 |
+
"source": [
|
| 1115 |
+
"Classifying arbitrary instructions:"
|
| 1116 |
+
]
|
| 1117 |
+
},
|
| 1118 |
+
{
|
| 1119 |
+
"cell_type": "code",
|
| 1120 |
+
"execution_count": null,
|
| 1121 |
+
"metadata": {
|
| 1122 |
+
"execution": {
|
| 1123 |
+
"iopub.execute_input": "2021-01-13T03:17:08.070832Z",
|
| 1124 |
+
"iopub.status.busy": "2021-01-13T03:17:08.068184Z",
|
| 1125 |
+
"iopub.status.idle": "2021-01-13T03:17:08.879352Z",
|
| 1126 |
+
"shell.execute_reply": "2021-01-13T03:17:08.878815Z"
|
| 1127 |
+
},
|
| 1128 |
+
"id": "VBWzH6exlCPS"
|
| 1129 |
+
},
|
| 1130 |
+
"outputs": [],
|
| 1131 |
+
"source": [
|
| 1132 |
+
"def print_my_examples(inputs, results):\n",
|
| 1133 |
+
" result_for_printing = \\\n",
|
| 1134 |
+
" [f'input: {inputs[i]:<30} : estimated intent: {results[i]}'\n",
|
| 1135 |
+
" for i in range(len(inputs))]\n",
|
| 1136 |
+
" print(*result_for_printing, sep='\\n')\n",
|
| 1137 |
+
" print()\n",
|
| 1138 |
+
"\n",
|
| 1139 |
+
"\n",
|
| 1140 |
+
"examples = [\n",
|
| 1141 |
+
" 'play a song from U2',\n",
|
| 1142 |
+
" 'Will it rain tomorrow',\n",
|
| 1143 |
+
" 'I like to hear greatist hits from beastie boys',\n",
|
| 1144 |
+
" 'I like to book a table for 3 persons',\n",
|
| 1145 |
+
" '5 stars for machines like me'\n",
|
| 1146 |
+
"]\n",
|
| 1147 |
+
"\n",
|
| 1148 |
+
"results = tf.nn.softmax(classifier_model(tf.constant(examples)))"
|
| 1149 |
+
]
|
| 1150 |
+
},
|
| 1151 |
+
{
|
| 1152 |
+
"cell_type": "code",
|
| 1153 |
+
"execution_count": null,
|
| 1154 |
+
"metadata": {},
|
| 1155 |
+
"outputs": [],
|
| 1156 |
+
"source": [
|
| 1157 |
+
"binarizer.classes_"
|
| 1158 |
+
]
|
| 1159 |
+
},
|
| 1160 |
+
{
|
| 1161 |
+
"cell_type": "code",
|
| 1162 |
+
"execution_count": null,
|
| 1163 |
+
"metadata": {},
|
| 1164 |
+
"outputs": [],
|
| 1165 |
+
"source": [
|
| 1166 |
+
"intents=binarizer.inverse_transform(results.numpy())"
|
| 1167 |
+
]
|
| 1168 |
+
},
|
| 1169 |
+
{
|
| 1170 |
+
"cell_type": "code",
|
| 1171 |
+
"execution_count": null,
|
| 1172 |
+
"metadata": {
|
| 1173 |
+
"execution": {
|
| 1174 |
+
"iopub.execute_input": "2021-01-13T03:17:08.070832Z",
|
| 1175 |
+
"iopub.status.busy": "2021-01-13T03:17:08.068184Z",
|
| 1176 |
+
"iopub.status.idle": "2021-01-13T03:17:08.879352Z",
|
| 1177 |
+
"shell.execute_reply": "2021-01-13T03:17:08.878815Z"
|
| 1178 |
+
},
|
| 1179 |
+
"id": "VBWzH6exlCPS"
|
| 1180 |
+
},
|
| 1181 |
+
"outputs": [],
|
| 1182 |
+
"source": [
|
| 1183 |
+
"print_my_examples(examples, intents)"
|
| 1184 |
+
]
|
| 1185 |
+
},
|
| 1186 |
+
{
|
| 1187 |
+
"cell_type": "code",
|
| 1188 |
+
"execution_count": null,
|
| 1189 |
+
"metadata": {},
|
| 1190 |
+
"outputs": [],
|
| 1191 |
+
"source": []
|
| 1192 |
+
}
|
| 1193 |
+
],
|
| 1194 |
+
"metadata": {
|
| 1195 |
+
"accelerator": "GPU",
|
| 1196 |
+
"colab": {
|
| 1197 |
+
"collapsed_sections": [],
|
| 1198 |
+
"name": "classify_text_with_bert.ipynb",
|
| 1199 |
+
"toc_visible": true
|
| 1200 |
+
},
|
| 1201 |
+
"kernelspec": {
|
| 1202 |
+
"display_name": "Python 3.8.10 ('caesarainl': venv)",
|
| 1203 |
+
"language": "python",
|
| 1204 |
+
"name": "python3"
|
| 1205 |
+
},
|
| 1206 |
+
"language_info": {
|
| 1207 |
+
"codemirror_mode": {
|
| 1208 |
+
"name": "ipython",
|
| 1209 |
+
"version": 3
|
| 1210 |
+
},
|
| 1211 |
+
"file_extension": ".py",
|
| 1212 |
+
"mimetype": "text/x-python",
|
| 1213 |
+
"name": "python",
|
| 1214 |
+
"nbconvert_exporter": "python",
|
| 1215 |
+
"pygments_lexer": "ipython3",
|
| 1216 |
+
"version": "3.8.10"
|
| 1217 |
+
},
|
| 1218 |
+
"toc": {
|
| 1219 |
+
"base_numbering": 1,
|
| 1220 |
+
"nav_menu": {},
|
| 1221 |
+
"number_sections": true,
|
| 1222 |
+
"sideBar": true,
|
| 1223 |
+
"skip_h1_title": false,
|
| 1224 |
+
"title_cell": "Table of Contents",
|
| 1225 |
+
"title_sidebar": "Contents",
|
| 1226 |
+
"toc_cell": false,
|
| 1227 |
+
"toc_position": {},
|
| 1228 |
+
"toc_section_display": true,
|
| 1229 |
+
"toc_window_display": false
|
| 1230 |
+
},
|
| 1231 |
+
"vscode": {
|
| 1232 |
+
"interpreter": {
|
| 1233 |
+
"hash": "405f94f3dc67f15d7addffc962b8172d0d4960d3524e541166097ab2d7a328e0"
|
| 1234 |
+
}
|
| 1235 |
+
}
|
| 1236 |
+
},
|
| 1237 |
+
"nbformat": 4,
|
| 1238 |
+
"nbformat_minor": 1
|
| 1239 |
+
}
|
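The notebook's final cells turn raw classifier logits into an intent label via `tf.nn.softmax` followed by `binarizer.inverse_transform`. The same decoding step can be sketched in plain Python, with no TensorFlow dependency; the intent labels below are hypothetical stand-ins for `binarizer.classes_`, which the diff does not show.

```python
import math

# Hypothetical intent labels standing in for binarizer.classes_ in the notebook.
CLASSES = ["AddToPlaylist", "BookRestaurant", "GetWeather", "PlayMusic",
           "RateBook", "SearchCreativeWork", "SearchScreeningEvent"]

def softmax(logits):
    # Numerically stable softmax, as tf.nn.softmax computes it.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def decode_intent(logits, classes=CLASSES):
    # Equivalent of binarizer.inverse_transform on one row:
    # pick the class with the highest probability.
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return classes[best], probs[best]

# Logits shaped like the untrained-model output printed in the notebook.
example_logits = [0.16, 0.21, 0.04, 0.04, 0.03, 0.06, 0.45]
intent, confidence = decode_intent(example_logits)
```

Because the Dense head has `activation=None`, the softmax must be applied at decode time, exactly as the notebook does before calling `inverse_transform`.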
CaesarAINL/caesar_tensorflow_install.md
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Caesar Tensorflow Install
|
| 2 |
+
# Tensorflow-gpu source versions
|
| 3 |
+
# https://www.tensorflow.org/install/source#gpu
|
| 4 |
+
|
| 5 |
+
# Use Youtube video
|
| 6 |
+
# MiniConda
|
| 7 |
+
# https://docs.conda.io/en/main/miniconda.html
|
| 8 |
+
# Cuda Toolkit
|
| 9 |
+
# https://developer.nvidia.com/cuda-11.2.2-download-archive?target_os=Windows&target_arch=x86_64&target_version=10&target_type=exenetwork
|
| 10 |
+
|
| 11 |
+
# Cuda Deep Neural Network cudnn
|
| 12 |
+
# https://developer.nvidia.com/rdp/cudnn-archive
|
CaesarAINL/caesarapis.py
ADDED
@@ -0,0 +1,26 @@
+import os
+import pyttsx3
+engine = pyttsx3.init('sapi5')
+voices = engine.getProperty('voices')
+engine.setProperty('voice', voices[1].id)
+
+def speak(text, whisper_mode=None):
+    if whisper_mode == 0:
+        engine.say(text)
+        engine.runAndWait()
+
+class CaesarAPIs:
+    def __init__(self) -> None:
+        self.whisper_mode = 0
+    def runapis(self, caesarResponse=None, intent=None, userinput=None):
+        if intent == "whisper_mode" and "on" in userinput:
+            self.whisper_mode = 1
+            speak("I will be quiet now sir", 0)
+        elif intent == "whisper_mode" and "off" in userinput:
+            self.whisper_mode = 0
+            speak("Can you hear me now sir", 0)
+        #elif
+
+
+
+
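The whisper-mode toggle in `caesarapis.py` above can be exercised without a TTS backend. A minimal sketch, assuming a hypothetical stub `speak` that records what would be said instead of driving pyttsx3:

```python
# Stub in place of pyttsx3: collect spoken lines so the toggle is testable.
spoken = []

def speak(text, whisper_mode=None):
    # Hypothetical stand-in for the pyttsx3-backed speak() above.
    if whisper_mode == 0:
        spoken.append(text)

class CaesarAPIs:
    def __init__(self) -> None:
        self.whisper_mode = 0

    def runapis(self, caesarResponse=None, intent=None, userinput=None):
        if intent == "whisper_mode" and "on" in userinput:
            self.whisper_mode = 1
            speak("I will be quiet now sir", 0)
        elif intent == "whisper_mode" and "off" in userinput:
            self.whisper_mode = 0
            speak("Can you hear me now sir", 0)

apis = CaesarAPIs()
apis.runapis(intent="whisper_mode", userinput="turn whisper mode on")
apis.runapis(intent="whisper_mode", userinput="turn whisper mode off")
print(apis.whisper_mode, spoken)
```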
CaesarAINL/caesarapis/caesarReminder.py
ADDED
@@ -0,0 +1,52 @@
+import os
+import json
+import time
+import requests
+from datetime import datetime, timedelta
+class CaesarReminder:
+    @staticmethod
+    def reminder():
+        if "CaesarReminders" in os.listdir():
+            if "caesarreminders.json" in os.listdir("CaesarReminders"):
+                with open("CaesarReminders/caesarreminders.json", "r") as f:
+                    reminders = json.load(f)
+                message = ""
+                for reminder in reminders["reminders"]:
+                    message += "{}".format(reminder['subject'])
+                    message += "<br>"
+                    message += "{}".format(reminder['message'])
+                    message += "<br>"
+                    message += "Reminder: {}\n".format(datetime.fromisoformat(reminder['timestep']).strftime('%m/%d/%Y, %H:%M:%S'))
+                    message += "<br>"
+                    message += "<br>"
+                sendjson = {"raspsendemail": {"email": reminders["email"], "message": message, "subject": "Caesar Reminders"}}
+                response = requests.post("https://revisionbank-email.onrender.com/raspsendemail", json=sendjson)
+                print(response.text)
+
+
+            elif "caesarreminders.json" not in os.listdir("CaesarReminders"):
+                sendjson = {"raspsendemail": {"email": "amari.lawal@gmail.com", "message": "No Reminders Scheduled", "subject": "Caesar Reminders"}}
+                response = requests.post("https://revisionbank-email.onrender.com/raspsendemail", json=sendjson)
+                print(response.text)
+
+
+
+        # send email saying reminder
+        elif "CaesarReminders" not in os.listdir():
+            sendjson = {"raspsendemail": {"email": "amari.lawal@gmail.com", "message": "No Reminders Scheduled", "subject": "Caesar Reminders"}}
+            response = requests.post("https://revisionbank-email.onrender.com/raspsendemail", json=sendjson)
+            print(response.text)
+            # Send email saying No reminders
+
+
+
+
+if __name__ == "__main__":
+    constant = 60 * 60
+    duration = 48 * constant  # hours
+
+    #print(datetime.now().isoformat())
+    while True:
+        CaesarReminder.reminder()
+        time.sleep(duration)
+        #pass
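The polling loop in `caesarReminder.py` above wakes every 48 hours and emails everything in the file regardless of due date. A minimal sketch of filtering to only the reminders due inside the next polling window, reusing the same ISO `timestep` field (the `due_reminders` helper is hypothetical, not part of the commit):

```python
from datetime import datetime, timedelta

def due_reminders(reminders, now, window_hours=48):
    """Return reminders whose ISO 'timestep' falls inside [now, now + window)."""
    horizon = now + timedelta(hours=window_hours)
    return [r for r in reminders
            if now <= datetime.fromisoformat(r['timestep']) < horizon]

now = datetime(2023, 1, 1, 12, 0, 0)
reminders = [
    {'subject': 'Gym', 'timestep': '2023-01-02T09:00:00'},      # within 48h
    {'subject': 'Dentist', 'timestep': '2023-01-05T09:00:00'},  # too far out
]
print([r['subject'] for r in due_reminders(reminders, now)])  # ['Gym']
```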
CaesarAINL/caesarbackground.md
ADDED
@@ -0,0 +1,5 @@
+# CaesarBackground
+# start caesarbackground - sudo nohup /home/amari/Desktop/CaesarAI/CaesarAINL/caesarapis/caesarReminder.py > caesarReminder.out &
+
+# kill nohup processes - pkill -f caesarReminder.py
+# list nohup processes - ps ax | grep caesarReminder.py
CaesarAINL/caesarcomplete/berttest.py
ADDED
@@ -0,0 +1,51 @@
+import tensorflow_hub as hub
+
+BERT_URL = 'https://tfhub.dev/google/bert_cased_L-12_H-768_A-12/1'
+module = hub.Module(BERT_URL)
+
+# Look at the descriptor. This would tell you the model name
+# cat $TFHUB_CACHE_DIR/ecd2596ce849110246602e3d4d81e2d9719cb027.descriptor.txt
+
+# Further look at the assets folder, this has the file `vocab.txt`
+# ls $TFHUB_CACHE_DIR/ecd2596ce849110246602e3d4d81e2d9719cb027/assets
+
+
+import tokenization
+
+def create_tokenizer(vocab_file, do_lower_case=False):
+    return tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)
+
+tokenizer = create_tokenizer('vocab.txt', do_lower_case=False)
+
+
+def convert_sentence_to_features(sentence, tokenizer, max_seq_len):
+    tokens = ['[CLS]']
+    tokens.extend(tokenizer.tokenize(sentence))
+    if len(tokens) > max_seq_len-1:
+        tokens = tokens[:max_seq_len-1]
+    tokens.append('[SEP]')
+
+    segment_ids = [0] * len(tokens)
+    input_ids = tokenizer.convert_tokens_to_ids(tokens)
+    input_mask = [1] * len(input_ids)
+
+    # Zero Mask till seq_length
+    zero_mask = [0] * (max_seq_len-len(tokens))
+    input_ids.extend(zero_mask)
+    input_mask.extend(zero_mask)
+    segment_ids.extend(zero_mask)
+
+    return input_ids, input_mask, segment_ids
+
+def convert_sentences_to_features(sentences, tokenizer, max_seq_len=20):
+    all_input_ids = []
+    all_input_mask = []
+    all_segment_ids = []
+
+    for sentence in sentences:
+        input_ids, input_mask, segment_ids = convert_sentence_to_features(sentence, tokenizer, max_seq_len)
+        all_input_ids.append(input_ids)
+        all_input_mask.append(input_mask)
+        all_segment_ids.append(segment_ids)
+
+    return all_input_ids, all_input_mask, all_segment_ids
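The CLS/SEP truncation and zero-padding logic in `convert_sentence_to_features` above can be checked standalone. A minimal sketch, assuming a hypothetical stand-in tokenizer (whitespace split plus a toy vocab) in place of the real BERT `FullTokenizer`:

```python
# DummyTokenizer is a hypothetical stand-in: whitespace split + toy vocab,
# not the real BERT FullTokenizer, but enough to exercise the padding logic.
class DummyTokenizer:
    def __init__(self, vocab):
        self.vocab = vocab

    def tokenize(self, sentence):
        return sentence.split()

    def convert_tokens_to_ids(self, tokens):
        return [self.vocab.get(t, 0) for t in tokens]

def convert_sentence_to_features(sentence, tokenizer, max_seq_len):
    # Same shape as the berttest.py function: prepend [CLS], truncate,
    # append [SEP], then pad ids/mask/segments with zeros to max_seq_len.
    tokens = ['[CLS]']
    tokens.extend(tokenizer.tokenize(sentence))
    if len(tokens) > max_seq_len - 1:
        tokens = tokens[:max_seq_len - 1]
    tokens.append('[SEP]')

    segment_ids = [0] * len(tokens)
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_mask = [1] * len(input_ids)

    zero_mask = [0] * (max_seq_len - len(tokens))
    input_ids.extend(zero_mask)
    input_mask.extend(zero_mask)
    segment_ids.extend(zero_mask)
    return input_ids, input_mask, segment_ids

vocab = {'[CLS]': 101, '[SEP]': 102, 'play': 5, 'a': 6, 'song': 7}
tok = DummyTokenizer(vocab)
ids, mask, segs = convert_sentence_to_features('play a song', tok, 8)
print(ids)   # [101, 5, 6, 7, 102, 0, 0, 0] -> 5 real tokens + 3 of padding
print(mask)  # [1, 1, 1, 1, 1, 0, 0, 0]
```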
CaesarAINL/caesarcomplete/caesar_tensorflow_install.md
ADDED
@@ -0,0 +1,12 @@
+# Caesar Tensorflow Install
+# Tensorflow-gpu source versions
+# https://www.tensorflow.org/install/source#gpu
+
+# Use Youtube video
+# MiniConda
+# https://docs.conda.io/en/main/miniconda.html
+# Cuda Toolkit
+# https://developer.nvidia.com/cuda-11.2.2-download-archive?target_os=Windows&target_arch=x86_64&target_version=10&target_type=exenetwork
+
+# Cuda Deep Neural Network cudnn
+# https://developer.nvidia.com/rdp/cudnn-archive
CaesarAINL/caesarcomplete/caesarapis.py
ADDED
@@ -0,0 +1,17 @@
+import os
+class CaesarAPIs:
+    def __init__(self) -> None:
+        self.whisper_mode = 0
+    def runapis(self, caesarResponse=None, intent=None, userinput=None, voiceengine=None):
+        if intent == "whisper_mode" and "on" in userinput:
+            self.whisper_mode = 1
+            voiceengine.say("I will be quiet now sir")
+            voiceengine.runAndWait()
+        elif intent == "whisper_mode" and "off" in userinput:
+            self.whisper_mode = 0
+            voiceengine.say("Can you hear me now sir")
+            voiceengine.runAndWait()
+
+
+
+
CaesarAINL/caesarcomplete/caesarnlexamples.py
ADDED
@@ -0,0 +1,69 @@
+import sys
+import json
+import spacy
+import pickle
+import random
+import tensorflow as tf
+import tensorflow_hub as hub
+import tensorflow_text as text
+from sklearn.preprocessing import LabelBinarizer
+
+def print_my_examples(inputs, results):
+    result_for_printing = \
+        [f'input: {inputs[i]:<30} : estimated intent: {results[i]}'
+         for i in range(len(inputs))]
+    print(*result_for_printing, sep='\n')
+    print()
+
+
+examples = [
+    'play a song from U2',  # this is the same sentence tried earlier
+    'Will it rain tomorrow',
+    'I like to hear greatist hits from beastie boys',
+    'I like to book a table for 3 persons',
+    '5 stars for machines like me',
+    'play a boogie wit da hoodie',
+    "play Bob's favorite song",
+    "give me a hug",
+    "hello"
+]
+greetings = ["Greeting", "smalltalk_greetings_hello"]
+courtesy_greeting = ["CourtesyGreeting"]
+stored_name = "Amari"
+examples = ["hello"]
+#nlp = spacy.load("en_core_web_sm")
+classifier_model = tf.keras.models.load_model('caesarmodel/caesarnl.h5', custom_objects={'KerasLayer': hub.KerasLayer})
+
+# Show the model architecture
+results = tf.nn.softmax(classifier_model(tf.constant(examples)))
+with open("caesarmodel/labelbinarizer.pkl", "rb") as f:
+    binarizer = pickle.load(f)
+
+
+intents = binarizer.inverse_transform(results.numpy())
+with open("intentdata/responses.json", "r") as f:
+    responses = json.load(f)["responses"]
+
+if intents[0] in greetings:
+    greetresponse = random.choice(responses["Greeting"]).replace("<HUMAN>", stored_name)
+    print(greetresponse)
+
+
+
+#sentence_intents = dict(zip(examples,intents))
+#print(sentence_intents)
+#print_my_examples(examples, intents)
+
+# TODO AIM - Implement Chatbot Gossip to Caesar
+# 1. Add data to datasets train | valid | test
+#    a. then clean labels
+# 2. Augment data to provide more potential possibilities
+# 3. Use BERT to match input with the response
+# Command Labels - AddToPlaylist | GetWeather -> API -> user
+# Conversation Labels - Greeting | Goodbye -> BERTNN: input:"hello" => response:"hi there, I am caesar" -> user
+
+# TODO AIM - Single names of songs artists like "play a boogie" and it will play a boogie's music.
+# 1. Idea one - NER detect the named entities
+# 2. Create new Neural Network that detects that. * Have to determine the relationship between the entities
+
CaesarAINL/caesarcomplete/data_aggregation.ipynb
ADDED
|
@@ -0,0 +1,493 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"import pandas as pd\n",
|
| 10 |
+
"import json\n",
|
| 11 |
+
"smalltalkintent = pd.read_csv(\"intentdata/Small_talk_Intent.csv\").rename(columns={\"Utterances\":\"text\",\"Intent\":\"intent\"})\n",
|
| 12 |
+
"training = pd.read_csv(\"intentdata/train_command.csv\")\n",
|
| 13 |
+
"df3 = pd.concat([training,smalltalkintent], ignore_index=True)\n",
|
| 14 |
+
"df3.to_csv(\"intentdata/train.csv\",mode=\"w\",index=False)\n",
|
| 15 |
+
"# TODO Do NLPAugmentation\n",
|
| 16 |
+
"#df3 = pd.concat([training,smalltalkintent], ignore_index=True)\n",
|
| 17 |
+
"#with open(\"intentdata/Intent.json\",\"r\") as f:\n",
|
| 18 |
+
"# chatbotXresponseintent = json.load(f)\n",
|
| 19 |
+
"#chatbotXresponseintent[\"intents\"][0] \n"
|
| 20 |
+
]
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"cell_type": "code",
|
| 24 |
+
"execution_count": null,
|
| 25 |
+
"metadata": {},
|
| 26 |
+
"outputs": [],
|
| 27 |
+
"source": [
|
| 28 |
+
"# Produce responses\n",
|
| 29 |
+
"import json \n",
|
| 30 |
+
"with open(\"intentdata/Intent.json\",\"r\") as f:\n",
|
| 31 |
+
" data = json.load(f)\n",
|
| 32 |
+
"intents = []\n",
|
| 33 |
+
"for intent in data[\"intents\"]:\n",
|
| 34 |
+
" intents.append({\"intent\":intent[\"intent\"],\"responses\":intent[\"responses\"]})\n",
|
| 35 |
+
"with open(\"intentdata/responses.json\",\"w+\") as f:\n",
|
| 36 |
+
" json.dump({\"response\":intents},f)\n",
|
| 37 |
+
" "
|
| 38 |
+
]
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"cell_type": "code",
|
| 42 |
+
"execution_count": null,
|
| 43 |
+
"metadata": {},
|
| 44 |
+
"outputs": [],
|
| 45 |
+
"source": [
|
| 46 |
+
"with open(\"intentdata/responses.json\",\"r\") as f:\n",
|
| 47 |
+
" responses = json.load(f)[\"responses\"]\n",
|
| 48 |
+
"#print(intents[0] in greetings)\n",
|
| 49 |
+
"#if intents[0] in greetings:\n",
|
| 50 |
+
"print(responses[\"Greeting\"])"
|
| 51 |
+
]
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"cell_type": "code",
|
| 55 |
+
"execution_count": null,
|
| 56 |
+
"metadata": {},
|
| 57 |
+
"outputs": [],
|
| 58 |
+
"source": [
|
| 59 |
+
"import json\n",
|
| 60 |
+
"import pandas as pd \n",
|
| 61 |
+
"from IPython.display import display\n",
|
| 62 |
+
"with open(\"intentdata/aug/intent_aug_text.json\",\"r\") as f:\n",
|
| 63 |
+
" data = json.load(f)[\"intents\"]\n",
|
| 64 |
+
"columns_to_concat = []\n",
|
| 65 |
+
"for i in range(len(data)):\n",
|
| 66 |
+
" column1 = pd.DataFrame.from_dict({\"text\":data[i][\"text\"]})\n",
|
| 67 |
+
" #print(column1)\n",
|
| 68 |
+
" column2 = pd.DataFrame.from_dict({\"intent\":[data[i][\"intent\"] for j in range(len(data[i][\"text\"]))]})\n",
|
| 69 |
+
" #print(column2)\n",
|
| 70 |
+
" tent_concat= pd.concat([column1,column2],axis=1)\n",
|
| 71 |
+
" #print(tent_concat)\n",
|
| 72 |
+
" columns_to_concat.append(tent_concat)\n",
|
| 73 |
+
"intentdf = pd.concat(columns_to_concat,axis=0)\n",
|
| 74 |
+
"intentdf\n",
|
| 75 |
+
"df = pd.read_csv(\"intentdata/train.csv\")\n",
|
| 76 |
+
"newdf = pd.concat([df,intentdf],axis=0)\n",
|
| 77 |
+
"newdf.to_csv(\"intentdata/train.csv\",mode=\"w\",index=False)\n",
|
| 78 |
+
"\n",
|
| 79 |
+
"\n"
|
| 80 |
+
]
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"cell_type": "code",
|
| 84 |
+
"execution_count": null,
|
| 85 |
+
"metadata": {},
|
| 86 |
+
"outputs": [],
|
| 87 |
+
"source": [
|
| 88 |
+
"import pandas as pd\n",
|
| 89 |
+
"data = pd.read_csv(\"intentdata/train.csv\")\n"
|
| 90 |
+
]
|
| 91 |
+
},
|
| 92 |
+
{
|
| 93 |
+
"cell_type": "code",
|
| 94 |
+
"execution_count": null,
|
| 95 |
+
"metadata": {},
|
| 96 |
+
"outputs": [],
|
| 97 |
+
"source": [
|
| 98 |
+
"import pandas as pd\n",
|
| 99 |
+
"data = pd.read_csv(\"intentdata/train_no_half_response.csv\") # Original data is a (2000, 7) DataFrame\n",
|
| 100 |
+
"data = data.replace(\"smalltalk_agent_acquaintance\",\"CourtesyGreeting\")\n",
|
| 101 |
+
"data = data.replace(\"smalltalk_agent_age\",\"CurrentHumanQuery\")\n",
|
| 102 |
+
"data = data.replace(\"smalltalk_agent_annoying\",\"NotTalking2U\")\n",
|
| 103 |
+
"data = data.replace(\"smalltalk_agent_bad\",\"Swearing\")\n",
|
| 104 |
+
"data = data.replace(\"smalltalk_agent_boss\",\"CurrentHumanQuery\")\n",
|
| 105 |
+
"data = data.replace(\"smalltalk_agent_clever\",\"Clever\")\n",
|
| 106 |
+
"\n",
|
| 107 |
+
"data = data.replace(\"smalltalk_agent_beautiful\",\"Clever\")\n",
|
| 108 |
+
"data = data.replace(\"smalltalk_agent_fired\",\"Shutup\")\n",
|
| 109 |
+
"data = data.replace(\"smalltalk_agent_good\",\"Thanks\")\n",
|
| 110 |
+
"data = data.replace(\"smalltalk_agent_chatbot\",\"SelfAware\")\n",
|
| 111 |
+
"data = data.replace(\"smalltalk_agent_real\",\"SelfAware\")\n",
|
| 112 |
+
"\n",
|
| 113 |
+
"\n",
|
| 114 |
+
"data.to_csv(\"intentdata/train.csv\",mode=\"w\",index=False)"
|
| 115 |
+
]
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"cell_type": "code",
|
| 119 |
+
"execution_count": null,
|
| 120 |
+
"metadata": {},
|
| 121 |
+
"outputs": [],
|
| 122 |
+
"source": [
|
| 123 |
+
"# TODO Checks label data balance\n",
|
| 124 |
+
"import pandas as pd\n",
|
| 125 |
+
"data = pd.read_csv(\"intentdata/train.csv\") # Original data is a (2000, 7) DataFrame\n",
|
| 126 |
+
"# data contains 6 feature columns and 1 target column.\n",
|
| 127 |
+
"\n",
|
| 128 |
+
"# Separate the design matrix from the target labels.\n",
|
| 129 |
+
"X = data.iloc[:, :-1]\n",
|
| 130 |
+
"y = data['intent']\n",
|
| 131 |
+
"\n",
|
| 132 |
+
"y.value_counts().sort_index().plot.bar(x='Target Value', y='Number of Occurrences',figsize=(20,20))"
|
| 133 |
+
]
|
| 134 |
+
},
|
| 135 |
+
{
|
| 136 |
+
"cell_type": "code",
|
| 137 |
+
"execution_count": null,
|
| 138 |
+
"metadata": {},
|
| 139 |
+
"outputs": [],
|
| 140 |
+
"source": []
|
| 141 |
+
},
|
| 142 |
+
{
|
| 143 |
+
"cell_type": "code",
|
| 144 |
+
"execution_count": null,
|
| 145 |
+
"metadata": {},
|
| 146 |
+
"outputs": [],
|
| 147 |
+
"source": [
|
| 148 |
+
"import torch\n",
|
| 149 |
+
"torch.cuda.is_available()"
|
| 150 |
+
]
|
| 151 |
+
},
|
| 152 |
+
{
|
| 153 |
+
"cell_type": "code",
|
| 154 |
+
"execution_count": null,
|
| 155 |
+
"metadata": {},
|
| 156 |
+
"outputs": [],
|
| 157 |
+
"source": [
|
| 158 |
+
"import json \n",
|
| 159 |
+
"import nlpaug.augmenter.word as naw\n",
|
| 160 |
+
"import nlpaug.flow as naf\n",
|
| 161 |
+
"print(\"Loading Models...\")\n",
|
| 162 |
+
"import nltk \n",
|
| 163 |
+
"nltk.download('wordnet')\n",
|
| 164 |
+
"nltk.download('omw-1.4')\n",
|
| 165 |
+
"TOPK=20 #default=100\n",
|
| 166 |
+
"ACT = 'insert' #\"substitute\"\n",
|
| 167 |
+
"def tst():\n",
|
| 168 |
+
" aug_w2v= naw.WordEmbsAug(\n",
|
| 169 |
+
" model_type='glove', model_path='glove/glove.6B.300d.txt',\n",
|
| 170 |
+
" action=\"substitute\")\n",
|
| 171 |
+
" aug_bert = naw.ContextualWordEmbsAug(\n",
|
| 172 |
+
" model_path='distilbert-base-uncased', \n",
|
| 173 |
+
" \n",
|
| 174 |
+
" action=ACT, top_k=TOPK)\n",
|
| 175 |
+
" aug = naf.Sequential([\n",
|
| 176 |
+
" aug_bert,aug_w2v\n",
|
| 177 |
+
" ])\n",
|
| 178 |
+
"print(\"Models Loaded.\")\n",
|
| 179 |
+
"with open(\"intentdata/intent.json\") as f:\n",
|
| 180 |
+
" intentwhole = json.load(f)[\"intents\"]\n",
|
| 181 |
+
"#text = intent[0][\"text\"][3]\n",
|
| 182 |
+
"for intent in intentwhole:\n",
|
| 183 |
+
" for text in intent[\"text\"]:\n",
|
| 184 |
+
" augmented_texts = set()\n",
|
| 185 |
+
" for i in range(20):\n",
|
| 186 |
+
" aug = naw.SynonymAug(aug_src='wordnet',aug_min=1, aug_max=10, aug_p=i/10)\n",
|
| 187 |
+
" augmented_text = str(aug.augment(text))\n",
|
| 188 |
+
" print(augmented_text)\n",
|
| 189 |
+
" #print(augmented_text)\n",
|
| 190 |
+
" augmented_texts.add(augmented_text)\n",
|
| 191 |
+
" augmented_texts = augmented_texts.union(set(intent[\"text\"]))\n",
|
| 192 |
+
" intent[\"text\"] = list(augmented_texts)\n",
|
| 193 |
+
" def test():\n",
|
| 194 |
+
" if intent[\"intent\"] != \"Jokes\": \n",
|
| 195 |
+
" for response in intent[\"responses\"]:\n",
|
| 196 |
+
" augmented_responses = set()\n",
|
| 197 |
+
" for i in range(20):\n",
|
| 198 |
+
" #aug = naw.SynonymAug(aug_src='wordnet',aug_min=1, aug_max=10, aug_p=i/50,stopwords=[\"<HUMAN>\",\"<HUMAN>,\",\"<HUMAN>!\"])\n",
|
| 199 |
+
" augmented_response = str(aug.augment(response)[0])\n",
|
| 200 |
+
" print(augmented_response)\n",
|
| 201 |
+
" try:\n",
|
| 202 |
+
" augmented_response = augmented_response[:augmented_response.index(\"<\")] + \"<HUMAN\" + augmented_response[augmented_response.index(\">\"):]\n",
|
| 203 |
+
" except ValueError as vex:\n",
|
| 204 |
+
" pass\n",
|
| 205 |
+
" #print(augmented_text)\n",
|
| 206 |
+
" augmented_responses.add(augmented_response)\n",
|
| 207 |
+
" augmented_responses = augmented_responses.union(set(intent[\"responses\"]))\n",
|
| 208 |
+
" intent[\"responses\"] = list(augmented_responses) \n",
|
| 209 |
+
"with open(\"intentdata/intent_aug_text_test.json\",\"w+\") as f:\n",
|
| 210 |
+
" json.dump({\"intents\":intentwhole},f)\n",
|
| 211 |
+
"\n",
|
| 212 |
+
"\n",
|
| 213 |
+
"print(intentwhole[1][\"responses\"])"
|
| 214 |
+
]
|
| 215 |
+
},
|
| 216 |
+
{
|
| 217 |
+
"cell_type": "code",
|
| 218 |
+
"execution_count": null,
|
| 219 |
+
"metadata": {},
|
| 220 |
+
"outputs": [],
|
| 221 |
+
"source": [
|
| 222 |
+
"df = pd.read_csv(\"intentdata/train.csv\")\n",
|
| 223 |
+
"len(list(pd.unique(df[\"intent\"])))"
|
| 224 |
+
]
|
| 225 |
+
},
|
| 226 |
+
{
|
| 227 |
+
"cell_type": "code",
|
| 228 |
+
"execution_count": null,
|
| 229 |
+
"metadata": {},
|
| 230 |
+
"outputs": [],
|
| 231 |
+
"source": [
|
| 232 |
+
"import os\n",
|
| 233 |
+
"import json \n",
|
| 234 |
+
"import pandas as pd\n",
|
| 235 |
+
"os.listdir(\"new_intent_data\")"
|
| 236 |
+
]
|
| 237 |
+
},
|
| 238 |
+
{
|
| 239 |
+
"cell_type": "code",
|
| 240 |
+
"execution_count": null,
|
| 241 |
+
"metadata": {},
|
| 242 |
+
"outputs": [],
|
| 243 |
+
"source": []
|
| 244 |
+
},
|
| 245 |
+
{
|
| 246 |
+
"cell_type": "code",
|
| 247 |
+
"execution_count": null,
|
| 248 |
+
"metadata": {},
|
| 249 |
+
"outputs": [],
|
| 250 |
+
"source": [
|
| 251 |
+
"import spacy\n",
|
| 252 |
+
"\n",
|
| 253 |
+
"nlp = spacy.load(\"en_core_web_sm\")\n",
|
| 254 |
+
"nlp(\"start the fan\").similarity(nlp(\"turn the fan on\"))"
|
| 255 |
+
]
|
| 256 |
+
},
|
| 257 |
+
{
|
| 258 |
+
"cell_type": "code",
|
| 259 |
+
"execution_count": null,
|
| 260 |
+
"metadata": {},
|
| 261 |
+
"outputs": [],
|
| 262 |
+
"source": [
|
| 263 |
+
"import pandas as pd\n",
|
| 264 |
+
"traindf = pd.read_csv(\"intentdata/train.csv\")\n",
|
| 265 |
+
"traindf"
|
| 266 |
+
]
|
| 267 |
+
},
|
| 268 |
+
{
|
| 269 |
+
"cell_type": "code",
|
| 270 |
+
"execution_count": null,
|
| 271 |
+
"metadata": {},
|
| 272 |
+
"outputs": [],
|
| 273 |
+
"source": [
|
| 274 |
+
"import pandas as pd\n",
|
| 275 |
+
"intentdf = pd.read_csv(\"new_intent_data/intent_classification.csv\")\n",
|
| 276 |
+
"traindf = pd.read_csv(\"intentdata/train.csv\")\n",
|
| 277 |
+
"smart_home_intent = intentdf[[\"text-en\",\"intent-en\"]].rename({\"text-en\":\"text\",\"intent-en\":\"intent\"},axis=1)\n",
|
| 278 |
+
"newdf = pd.concat([traindf,smart_home_intent],axis=0)\n",
|
| 279 |
+
"\n",
|
| 280 |
+
"newdf.to_csv(\"intentdata/train.csv\",mode=\"w\",index=False)"
|
| 281 |
+
]
|
| 282 |
+
},
|
| 283 |
+
{
|
| 284 |
+
"cell_type": "code",
|
| 285 |
+
"execution_count": null,
|
| 286 |
+
"metadata": {},
|
| 287 |
+
"outputs": [],
|
| 288 |
+
"source": [
|
| 289 |
+
"with open(\"new_intent_data/Dataset-train_old.txt\") as f:\n",
|
| 290 |
+
" textdata = f.readlines()\n",
|
| 291 |
+
"sent = \"\"\n",
|
| 292 |
+
"for text in textdata:\n",
|
| 293 |
+
" if text == \"\":\n",
|
| 294 |
+
" pass\n",
|
| 295 |
+
" elif text != \"\":\n",
|
| 296 |
+
" if text.count(\",\") == 2:\n",
|
| 297 |
+
" #print(text.replace(\",\",\";\",1))\n",
|
| 298 |
+
" sent += text.replace(\",\",\";\",1)\n",
|
| 299 |
+
" else:\n",
|
| 300 |
+
" \n",
|
| 301 |
+
" sent += text\n",
|
| 302 |
+
"print(sent)\n",
|
| 303 |
+
"with open(\"new_intent_data/Dataset-train.txt\",\"w+\") as f:\n",
|
| 304 |
+
" f.write(sent)\n",
|
| 305 |
+
" "
|
| 306 |
+
]
|
| 307 |
+
},
|
| 308 |
+
{
|
| 309 |
+
"cell_type": "code",
|
| 310 |
+
"execution_count": null,
|
| 311 |
+
"metadata": {},
|
| 312 |
+
"outputs": [],
|
| 313 |
+
"source": [
|
| 314 |
+
"import pandas as pd\n",
|
| 315 |
+
"data = pd.read_csv(\"new_intent_data/Dataset-train.csv\",delimiter=\",\") \n",
|
| 316 |
+
"data = data.drop(\"Unnamed: 0\",axis=1)\n",
|
| 317 |
+
"data\n",
|
| 318 |
+
"data.to_csv(\"new_intent_data/Dataset-train.csv\",index=False)"
|
| 319 |
+
]
|
| 320 |
+
},
|
| 321 |
+
{
|
| 322 |
+
"cell_type": "code",
|
| 323 |
+
"execution_count": null,
|
| 324 |
+
"metadata": {},
|
| 325 |
+
"outputs": [],
|
| 326 |
+
"source": [
|
| 327 |
+
"import pandas as pd\n",
|
| 328 |
+
"data = pd.read_csv(\"new_intent_data/Dataset-train.csv\",delimiter=\",\") # Original data is a (2000, 7) DataFrame\n",
|
| 329 |
+
"data = data.apply(lambda x: x.replace('\"',''),axis=1)\n",
|
| 330 |
+
"data.to_csv(\"new_intent_data/Dataset-train.csv\")"
|
| 331 |
+
]
|
| 332 |
+
},
|
| 333 |
+
{
|
| 334 |
+
"cell_type": "code",
|
| 335 |
+
"execution_count": 1,
|
| 336 |
+
"metadata": {},
|
| 337 |
+
"outputs": [],
|
| 338 |
+
"source": [
|
| 339 |
+
"import pandas as pd\n",
|
| 340 |
+
"data = pd.read_csv(\"new_intent_data/Dataset-train.csv\",delimiter=\"|\") \n",
|
| 341 |
+
"traindf = pd.read_csv(\"intentdata/train.csv\")\n",
|
| 342 |
+
"new_data = pd.concat([traindf,data],axis=0)\n",
|
| 343 |
+
"new_data.to_csv(\"intentdata/train.csv\",index=False)"
|
| 344 |
+
]
|
| 345 |
+
},
|
| 346 |
+
{
|
| 347 |
+
"cell_type": "code",
|
| 348 |
+
"execution_count": 2,
|
| 349 |
+
"metadata": {},
|
| 350 |
+
"outputs": [
|
| 351 |
+
{
|
| 352 |
+
"data": {
|
| 353 |
+
"text/html": [
|
| 354 |
+
"<div>\n",
|
| 355 |
+
"<style scoped>\n",
|
| 356 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 357 |
+
" vertical-align: middle;\n",
|
| 358 |
+
" }\n",
|
| 359 |
+
"\n",
|
| 360 |
+
" .dataframe tbody tr th {\n",
|
| 361 |
+
" vertical-align: top;\n",
|
| 362 |
+
" }\n",
|
| 363 |
+
"\n",
|
| 364 |
+
" .dataframe thead th {\n",
|
| 365 |
+
" text-align: right;\n",
|
| 366 |
+
" }\n",
|
| 367 |
+
"</style>\n",
|
| 368 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 369 |
+
" <thead>\n",
|
| 370 |
+
" <tr style=\"text-align: right;\">\n",
|
| 371 |
+
" <th></th>\n",
|
| 372 |
+
" <th>text</th>\n",
|
| 373 |
+
" <th>intent</th>\n",
|
| 374 |
+
" </tr>\n",
|
| 375 |
+
" </thead>\n",
|
| 376 |
+
" <tbody>\n",
|
| 377 |
+
" <tr>\n",
|
| 378 |
+
" <th>0</th>\n",
|
| 379 |
+
" <td>listen to westbam alumb allergic on google music</td>\n",
|
| 380 |
+
" <td>PlayMusic</td>\n",
|
| 381 |
+
" </tr>\n",
|
| 382 |
+
" <tr>\n",
|
| 383 |
+
" <th>1</th>\n",
|
| 384 |
+
" <td>add step to me to the 50 clásicos playlist</td>\n",
|
| 385 |
+
" <td>AddToPlaylist</td>\n",
|
| 386 |
+
" </tr>\n",
|
| 387 |
+
" <tr>\n",
|
| 388 |
+
" <th>2</th>\n",
|
| 389 |
+
" <td>i give this current textbook a rating value of...</td>\n",
|
| 390 |
+
" <td>RateBook</td>\n",
|
| 391 |
+
" </tr>\n",
|
| 392 |
+
" <tr>\n",
|
| 393 |
+
" <th>3</th>\n",
|
| 394 |
+
" <td>play the song little robin redbreast</td>\n",
|
| 395 |
+
" <td>PlayMusic</td>\n",
|
| 396 |
+
" </tr>\n",
|
| 397 |
+
" <tr>\n",
|
| 398 |
+
" <th>4</th>\n",
|
| 399 |
+
" <td>please add iris dement to my playlist this is ...</td>\n",
|
| 400 |
+
" <td>AddToPlaylist</td>\n",
|
| 401 |
+
" </tr>\n",
|
| 402 |
+
" <tr>\n",
|
| 403 |
+
" <th>...</th>\n",
|
| 404 |
+
" <td>...</td>\n",
|
| 405 |
+
" <td>...</td>\n",
|
| 406 |
+
" </tr>\n",
|
| 407 |
+
" <tr>\n",
|
| 408 |
+
" <th>37306</th>\n",
|
| 409 |
+
" <td>what is your date of birth</td>\n",
|
| 410 |
+
" <td>how_old_are_you</td>\n",
|
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37307</th>\n",
       "      <td>what year were you born in</td>\n",
       "      <td>how_old_are_you</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37308</th>\n",
       "      <td>what is the year that were you born</td>\n",
       "      <td>how_old_are_you</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37309</th>\n",
       "      <td>how old are you ai</td>\n",
       "      <td>how_old_are_you</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37310</th>\n",
       "      <td>are you 16 years old</td>\n",
       "      <td>how_old_are_you</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>37311 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                   text         intent\n",
       "0      listen to westbam alumb allergic on google music      PlayMusic\n",
       "1           add step to me to the 50 clásicos playlist  AddToPlaylist\n",
       "2      i give this current textbook a rating value of...       RateBook\n",
       "3                 play the song little robin redbreast      PlayMusic\n",
       "4      please add iris dement to my playlist this is ...  AddToPlaylist\n",
       "...                                                 ...            ...\n",
       "37306                        what is your date of birth  how_old_are_you\n",
       "37307                        what year were you born in  how_old_are_you\n",
       "37308               what is the year that were you born  how_old_are_you\n",
       "37309                                how old are you ai  how_old_are_you\n",
       "37310                              are you 16 years old  how_old_are_you\n",
       "\n",
       "[37311 rows x 2 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "traindf = pd.read_csv(\"intentdata/train.csv\")\n",
    "traindf"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.6.13 ('caesarnlradeon')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "bfcbd6a2138b43c88138636b277dc6540d170334c8840148725250eab505a128"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
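The cell above loads the aggregated training set as a two-column `text`/`intent` DataFrame with 37311 rows. A tiny sketch of that structure, with hypothetical rows standing in for `intentdata/train.csv`:

```python
import pandas as pd

# A small stand-in for the (37311, 2) training frame read above.
traindf = pd.DataFrame({
    "text": ["hello", "play the song little robin redbreast", "are you 16 years old"],
    "intent": ["Greeting", "PlayMusic", "how_old_are_you"],
})
print(traindf.shape)                # (3, 2)
print(traindf["intent"].nunique())  # 3 distinct intent labels
```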
CaesarAINL/caesarinfer.py
ADDED
@@ -0,0 +1,69 @@
import sys
import json
#import spacy
import pickle
import random
import tensorflow as tf
import tensorflow_hub as hub
from datetime import datetime
import tensorflow_text as text
from sklearn.preprocessing import LabelBinarizer


# TODO AIM - Implement Chatbot Gossip to Caesar
# 1. Add data to datasets train | valid | test
#    a. then clean labels
# 2. Augment data to provide more potential possibilities
# 3. Use BERT to match input with the response
# Command Labels - AddToPlaylist | GetWeather -> API -> user
# Conversation Labels - Greeting | Goodbye -> BERTNN: input:"hello" => response:"hi there, I am caesar" -> user

# TODO AIM - Single names of songs/artists like "play a boogie", and it will play A Boogie's music.
# 1. Idea one - NER to detect the named entities
# 2. Create a new neural network that detects that. * Have to determine the relationship between the entities


greetings = ["Greeting", "smalltalk_greetings_hello", "greeting"]
courtesy_greeting = ["CourtesyGreeting"]

class CaesarNL:
    @staticmethod
    def run(userinput):
        #if len(sys.argv) == 2:
        #    userinput = [sys.argv[1].lower()]
        stored_name = "Amari"
        classifier_model = tf.keras.models.load_model('caesarmodel/caesarnl', custom_objects={'KerasLayer': hub.KerasLayer})

        # Score the input text and normalize the logits with softmax
        results = tf.nn.softmax(classifier_model(tf.constant(userinput)))
        print(results.shape)
        with open("caesarmodel/labelbinarizer.pkl", "rb") as f:
            binarizer = pickle.load(f)

        intents = binarizer.inverse_transform(results.numpy())
        with open("intentdata/responses.json", "r") as f:
            responses = json.load(f)["responses"]

        if intents[0] in greetings:
            greetresponse = random.choice(responses["Greeting"]).replace("<HUMAN>", stored_name)
            #print(greetresponse)
            return greetresponse, intents[0]
        else:
            response = f"response to be implemented for text:{userinput}, predicted intent:{intents[0]}"
            #print(response)
            return response, intents[0]
        #elif len(sys.argv) < 2:
        #    response = "What is it, sir?"
        #    print(response)
        #    return response


if __name__ == "__main__":
    # Takes 17 seconds
    userinput = ["Hello"]
    greetresponse, intents = CaesarNL.run(userinput)
    print(greetresponse, f"intent:{intents}")
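The decoding step in `run` — mapping a row of softmax scores back to an intent name via `LabelBinarizer.inverse_transform` — amounts to an argmax over the alphabetically sorted class list. A dependency-free sketch of that step (the three labels here are illustrative, not the real label set):

```python
# LabelBinarizer sorts its classes alphabetically at fit time.
classes = sorted(["Greeting", "GetWeather", "PlayMusic"])
# -> ['GetWeather', 'Greeting', 'PlayMusic']

# A fake softmax row: the highest score sits in the "Greeting" column.
scores = [0.05, 0.90, 0.05]

# inverse_transform on multiclass scores reduces to an argmax lookup.
best_index = max(range(len(scores)), key=scores.__getitem__)
intent = classes[best_index]
print(intent)  # Greeting
```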
CaesarAINL/caesarintro.mp3
ADDED
File without changes
CaesarAINL/caesarnl.py
ADDED
@@ -0,0 +1,70 @@
#import library
import os
import time
import warnings
#from gtts import gTTS
#from playsound import playsound
from caesarapis import CaesarAPIs
import pyttsx3
warnings.filterwarnings("ignore")
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import speech_recognition as sr
from caesarinfer import CaesarNL
# Initialize the text-to-speech engine and recognizer (for recognizing the speech)
engine = pyttsx3.init('sapi5')
voices = engine.getProperty('voices')
engine.setProperty('voice', voices[1].id)


def speak(text, whisper_mode=None):
    if whisper_mode == 0:
        engine.say(text)
        engine.runAndWait()

recognizer = sr.Recognizer()

#def caesartalk(caesarspeech,whisper_mode,filename="example.mp3"):
#    if whisper_mode == 0:
#        audio = gTTS(caesarspeech, lang="en", slow=False)
#        audio.save(filename)
#        playsound(filename)
#        time.sleep(1)
#        if filename in os.listdir():
#            os.remove(filename)
# Read the microphone as source,
# listen to the speech, and store it in the audio_text variable
whisper_state = 1
caesarapis = CaesarAPIs()
while True:
    with sr.Microphone() as source:
        caesarintro = "How can I help you sir?"
        print(caesarintro)
        #caesartalk(caesarintro,caesarapis.whisper_mode,filename="caesarintro.mp3")
        speak(caesarintro, caesarapis.whisper_mode)
        #recognizer
        recognizer.adjust_for_ambient_noise(source, duration=1)
        audio_text = recognizer.listen(source)
        understood = "Understood sir, processing..."
        print(understood)
        #caesartalk(understood,caesarapis.whisper_mode,filename="caesar_understood.mp3")
        speak(understood, caesarapis.whisper_mode)
        try:
            # using Google speech recognition
            text = recognizer.recognize_google(audio_text)
            print(text)
            print("Caesar processing...")
            caesarResponse, intent = CaesarNL.run([text])
            caesarapis.runapis(caesarResponse, intent, text)

            print(f"User Input: {text}")
            print(f"Caesar: {caesarResponse}")
            #caesartalk(caesarResponse,caesarapis.whisper_mode,filename="caesarResponse.mp3")
            speak(caesarResponse, caesarapis.whisper_mode)

        except Exception as ex:
            print(type(ex), ex)
            print("Sorry, I did not get that")

# whisper_mode
CaesarAINL/caesarnlexamples.py
ADDED
@@ -0,0 +1,69 @@
import sys
import json
import spacy
import pickle
import random
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
from sklearn.preprocessing import LabelBinarizer

def print_my_examples(inputs, results):
    result_for_printing = \
        [f'input: {inputs[i]:<30} : estimated intent: {results[i]}'
         for i in range(len(inputs))]
    print(*result_for_printing, sep='\n')
    print()


examples = [
    'play a song from U2',  # this is the same sentence tried earlier
    'Will it rain tomorrow',
    'I like to hear greatist hits from beastie boys',
    'I like to book a table for 3 persons',
    '5 stars for machines like me',
    'play a boogie wit da hoodie',
    "play Bob's favorite song",
    "give me a hug",
    "hello"
]
greetings = ["Greeting", "smalltalk_greetings_hello"]
courtesy_greeting = ["CourtesyGreeting"]
stored_name = "Amari"
examples = ["hello"]
#nlp = spacy.load("en_core_web_sm")
classifier_model = tf.keras.models.load_model('caesarmodel/caesarnl.h5', custom_objects={'KerasLayer': hub.KerasLayer})

# Score the examples and normalize the logits with softmax
results = tf.nn.softmax(classifier_model(tf.constant(examples)))
with open("caesarmodel/labelbinarizer.pkl", "rb") as f:
    binarizer = pickle.load(f)

intents = binarizer.inverse_transform(results.numpy())
with open("intentdata/responses.json", "r") as f:
    responses = json.load(f)["responses"]

if intents[0] in greetings:
    greetresponse = random.choice(responses["Greeting"]).replace("<HUMAN>", stored_name)
    print(greetresponse)


#sentence_intents = dict(zip(examples,intents))
#print(sentence_intents)
#print_my_examples(examples, intents)

# TODO AIM - Implement Chatbot Gossip to Caesar
# 1. Add data to datasets train | valid | test
#    a. then clean labels
# 2. Augment data to provide more potential possibilities
# 3. Use BERT to match input with the response
# Command Labels - AddToPlaylist | GetWeather -> API -> user
# Conversation Labels - Greeting | Goodbye -> BERTNN: input:"hello" => response:"hi there, I am caesar" -> user

# TODO AIM - Single names of songs/artists like "play a boogie", and it will play A Boogie's music.
# 1. Idea one - NER to detect the named entities
# 2. Create a new neural network that detects that. * Have to determine the relationship between the entities
CaesarAINL/caesarnlrasp.py
ADDED
@@ -0,0 +1,70 @@
#import library
import os
import time
import warnings
from caesarapis import CaesarAPIs
import speech_recognition as sr
warnings.filterwarnings("ignore")


#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# Initialize recognizer class (for recognizing the speech)

def speak(text, whisper_mode=0):
    if whisper_mode == 0:
        #espeak.synth(text)
        if "output.mp3" in os.listdir():
            os.remove("output.mp3")
        os.system(f'espeak "{text}" --stdout | ffmpeg -i pipe:0 output.mp3')
        os.system('mplayer -volume 300 output.mp3')
        os.system("pkill mplayer")
        #raise KeyboardInterrupt
        #os.system('mplayer -af volume=30:1 output.mp3')

recognizer = sr.Recognizer()
# Read the microphone as source,
# listen to the speech, and store it in the audio_text variable
whisper_state = 1
caesarapis = CaesarAPIs()
while True:
    try:
        caesarintro = "How can I help you sir?"
        print(caesarintro)
        #caesartalk(caesarintro,caesarapis.whisper_mode,filename="caesarintro.mp3")
        speak(caesarintro, caesarapis.whisper_mode)
        print("Listening...")

        with sr.Microphone() as source:
            recognizer.adjust_for_ambient_noise(source, duration=1)
            audio_text = recognizer.listen(source)

        #print("Billyy")
        #recognizer
        understood = "Understood sir, processing..."
        print(understood)
        #caesartalk(understood,caesarapis.whisper_mode,filename="caesar_understood.mp3")
        speak(understood, caesarapis.whisper_mode)
        # using Google speech recognition
        text = recognizer.recognize_google(audio_text)
        print("output:", text)
        print("Caesar processing...")
        # TODO Send to Azure API
        if "hello" in text:
            print("Hola Amari")
            break
        #caesarResponse,intent = ("Hola Amari","greeting")
        #caesarapis.runapis(caesarResponse,intent,text,speak)

        #print(f"User Input: {text}")
        #print(f"Caesar: {caesarResponse}")
        #caesartalk(caesarResponse,caesarapis.whisper_mode,filename="caesarResponse.mp3")
        #speak(caesarResponse,caesarapis.whisper_mode)
    except Exception as uex:
        continue

# whisper_mode
CaesarAINL/caesartrain.py
ADDED
@@ -0,0 +1,224 @@
# -*- coding: utf-8 -*-
import os
import json
#import shutil
import pickle
import warnings
import pandas as pd
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
from pylab import rcParams
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelBinarizer
warnings.filterwarnings("ignore")
tf.get_logger().setLevel('ERROR')

class CaesarNLTrain:
    @staticmethod
    def train(traindf, validdf, testdf, examples, history_filename="history.png"):
        intent_label_output_size = len(pd.unique(traindf["intent"]))
        trainfeatures = traindf.copy()
        trainlabels = trainfeatures.pop("intent")

        trainfeatures = trainfeatures.values

        """One-hot encoding of the class labels:"""

        binarizer = LabelBinarizer()
        trainlabels = binarizer.fit_transform(trainlabels.values)

        """Preprocess the test and validation data in the same way as the training data:"""

        testfeatures = testdf.copy()
        testlabels = testfeatures.pop("intent")
        validfeatures = validdf.copy()
        validlabels = validfeatures.pop("intent")

        testfeatures = testfeatures.values
        validfeatures = validfeatures.values

        testlabels = binarizer.transform(testlabels.values)
        validlabels = binarizer.transform(validlabels.values)
        pickle.dump(binarizer, open('caesarmodel/labelbinarizer.pkl', 'wb'))

        bert_model_name = 'small_bert/bert_en_uncased_L-8_H-512_A-8'
        with open("caesarberthubmodels/bert_to_handle.json") as f:
            map_name_to_handle = json.load(f)
        with open("caesarberthubmodels/bert_to_preprocess.json") as f:
            map_model_to_preprocess = json.load(f)

        tfhub_handle_encoder = map_name_to_handle[bert_model_name]
        tfhub_handle_preprocess = map_model_to_preprocess[bert_model_name]

        print(f'BERT model selected           : {tfhub_handle_encoder}')
        print(f'Preprocess model auto-selected: {tfhub_handle_preprocess}')

        bert_preprocess_model = hub.KerasLayer(tfhub_handle_preprocess)

        text_test = trainfeatures[0]
        text_preprocessed = bert_preprocess_model(text_test)

        bert_model = hub.KerasLayer(tfhub_handle_encoder)

        bert_results = bert_model(text_preprocessed)

        def build_classifier_model():
            text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
            preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')
            encoder_inputs = preprocessing_layer(text_input)
            encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='BERT_encoder')
            outputs = encoder(encoder_inputs)
            net = outputs['pooled_output']
            net = tf.keras.layers.Dropout(0.1)(net)
            net = tf.keras.layers.Dense(intent_label_output_size, activation=None, name='classifier')(net)
            return tf.keras.Model(text_input, net)

        """Let's check that the model runs with the output of the preprocessing model."""

        classifier_model = build_classifier_model()
        bert_raw_result = classifier_model(tf.constant(trainfeatures[0]))
        print(tf.keras.activations.softmax(bert_raw_result))

        """The output is meaningless, of course, because the model has not been trained yet.

        Let's take a look at the model's structure.
        """

        classifier_model.summary()

        """## Model training

        You now have all the pieces to train a model, including the preprocessing module, BERT encoder, data, and classifier.

        Since this is a non-binary classification problem and the model outputs class scores, you'll use the `losses.CategoricalCrossentropy` loss function.
        """

        loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
        metrics = tf.metrics.CategoricalAccuracy()

        """### Loading the BERT model and training

        Using the `classifier_model` you created earlier, you can compile the model with the loss, metric and optimizer.
        """

        epochs = 5
        optimizer = tf.keras.optimizers.Adam(1e-5)
        classifier_model.compile(optimizer=optimizer,
                                 loss=loss,
                                 metrics=metrics)

        """Note: training time will vary depending on the complexity of the BERT model you have selected."""

        print(f'Training model with {tfhub_handle_encoder}')
        history = classifier_model.fit(x=trainfeatures, y=trainlabels,
                                       validation_data=(validfeatures, validlabels),
                                       batch_size=32,
                                       epochs=epochs)
        classifier_model.save("caesarmodel/caesarnl.h5")

        """### Evaluate the model

        Let's see how the model performs. Two values will be returned: loss (a number representing the error; lower values are better) and accuracy.
        """

        loss, accuracy = classifier_model.evaluate(testfeatures, testlabels)

        print(f'Loss: {loss}')
        print(f'Accuracy: {accuracy}')

        """### Plot the accuracy and loss over time

        Based on the `History` object returned by `model.fit()`, you can plot the training and validation loss for comparison, as well as the training and validation accuracy:
        """

        history_dict = history.history
        print(history_dict.keys())

        acc = history_dict['categorical_accuracy']
        val_acc = history_dict['val_categorical_accuracy']
        loss = history_dict['loss']
        val_loss = history_dict['val_loss']

        epochs = range(1, len(acc) + 1)
        fig = plt.figure(figsize=(10, 8))
        fig.tight_layout()

        plt.subplot(2, 1, 1)
        # 'r' is a solid red line, 'b' a solid blue line
        plt.plot(epochs, loss, 'r', label='Training loss')
        plt.plot(epochs, val_loss, 'b', label='Validation loss')
        plt.title('Training and validation loss')
        plt.grid(True)
        # plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()

        plt.subplot(2, 1, 2)
        plt.plot(epochs, acc, 'r', label='Training acc')
        plt.plot(epochs, val_acc, 'b', label='Validation acc')
        plt.title('Training and validation accuracy')
        plt.grid(True)
        plt.xlabel('Epochs')
        plt.ylabel('Accuracy')

        plt.legend(loc='lower right')
        plt.savefig(f"caesartrainperformance/{history_filename}")

        """In this plot, the red lines represent the training loss and accuracy, and the blue lines are the validation loss and accuracy.

        Classifying arbitrary instructions:
        """

        def print_my_examples(inputs, results):
            result_for_printing = \
                [f'input: {inputs[i]:<30} : estimated intent: {results[i]}'
                 for i in range(len(inputs))]
            print(*result_for_printing, sep='\n')
            print()

        results = tf.nn.softmax(classifier_model(tf.constant(examples)))

        intents = binarizer.inverse_transform(results.numpy())

        print_my_examples(examples, intents)

if __name__ == "__main__":
    examples = [
        'play a song from U2',  # this is the same sentence tried earlier
        'Will it rain tomorrow',
        'I like to hear greatist hits from beastie boys',
        'I like to book a table for 3 persons',
        '5 stars for machines like me'
    ]
    datafolder = "intentdata/"
    trainfile = datafolder + "train.csv"
    testfile = datafolder + "test.csv"
    validfile = datafolder + "valid.csv"

    """Next, the downloaded .csv files for training, validation and test are imported into pandas dataframes:"""

    traindf = pd.read_csv(trainfile)
    validdf = pd.read_csv(validfile)
    testdf = pd.read_csv(testfile)

    CaesarNLTrain.train(traindf, validdf, testdf, examples)
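The `from_logits=True` setting in `caesartrain.py` pairs with the final `Dense` layer having `activation=None`: the model emits raw logits, softmax is applied inside the loss during training, and explicitly via `tf.nn.softmax` at inference. A minimal stdlib sketch of that softmax step (not TensorFlow's implementation, just the same math on a 1-D list):

```python
import math

def softmax(logits):
    # Numerically stable softmax: subtract the max before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax([2.0, 1.0, 0.1])
print(probs)       # probabilities summing to 1
print(max(probs))  # the largest logit keeps the largest probability
```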
CaesarAINL/caesartrainperformance/.png
ADDED
Binary file (55.3 kB).

CaesarAINL/caesartrainperformance/history.png
ADDED
CaesarAINL/data_aggregation.ipynb
ADDED
@@ -0,0 +1,277 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import json\n",
    "smalltalkintent = pd.read_csv(\"intentdata/Small_talk_Intent.csv\").rename(columns={\"Utterances\":\"text\",\"Intent\":\"intent\"})\n",
    "training = pd.read_csv(\"intentdata/train_command.csv\")\n",
    "df3 = pd.concat([training,smalltalkintent], ignore_index=True)\n",
    "df3.to_csv(\"intentdata/train.csv\",mode=\"w\",index=False)\n",
    "# TODO Do NLPAugmentation\n",
    "#df3 = pd.concat([training,smalltalkintent], ignore_index=True)\n",
    "#with open(\"intentdata/Intent.json\",\"r\") as f:\n",
    "#    chatbotXresponseintent = json.load(f)\n",
    "#chatbotXresponseintent[\"intents\"][0]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Produce responses\n",
    "import json\n",
    "with open(\"intentdata/Intent.json\",\"r\") as f:\n",
    "    data = json.load(f)\n",
    "intents = []\n",
    "for intent in data[\"intents\"]:\n",
    "    intents.append({\"intent\":intent[\"intent\"],\"responses\":intent[\"responses\"]})\n",
    "with open(\"intentdata/responses.json\",\"w+\") as f:\n",
    "    json.dump({\"response\":intents},f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"intentdata/responses.json\",\"r\") as f:\n",
    "    responses = json.load(f)[\"responses\"]\n",
    "#print(intents[0] in greetings)\n",
    "#if intents[0] in greetings:\n",
    "print(responses[\"Greeting\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import pandas as pd\n",
    "from IPython.display import display\n",
    "with open(\"intentdata/aug/intent_aug_text.json\",\"r\") as f:\n",
    "    data = json.load(f)[\"intents\"]\n",
    "columns_to_concat = []\n",
    "for i in range(len(data)):\n",
    "    column1 = pd.DataFrame.from_dict({\"text\":data[i][\"text\"]})\n",
    "    #print(column1)\n",
    "    column2 = pd.DataFrame.from_dict({\"intent\":[data[i][\"intent\"] for j in range(len(data[i][\"text\"]))]})\n",
    "    #print(column2)\n",
    "    tent_concat = pd.concat([column1,column2],axis=1)\n",
    "    #print(tent_concat)\n",
    "    columns_to_concat.append(tent_concat)\n",
    "intentdf = pd.concat(columns_to_concat,axis=0)\n",
    "intentdf\n",
    "df = pd.read_csv(\"intentdata/train.csv\")\n",
    "newdf = pd.concat([df,intentdf],axis=0)\n",
    "newdf.to_csv(\"intentdata/train.csv\",mode=\"w\",index=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "data = pd.read_csv(\"intentdata/train.csv\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "data = pd.read_csv(\"intentdata/train_no_half_response.csv\") # Original data is a (2000, 7) DataFrame\n",
    "data = data.replace(\"smalltalk_agent_acquaintance\",\"CourtesyGreeting\")\n",
    "data = data.replace(\"smalltalk_agent_age\",\"CurrentHumanQuery\")\n",
    "data = data.replace(\"smalltalk_agent_annoying\",\"NotTalking2U\")\n",
    "data = data.replace(\"smalltalk_agent_bad\",\"Swearing\")\n",
    "data = data.replace(\"smalltalk_agent_boss\",\"CurrentHumanQuery\")\n",
    "data = data.replace(\"smalltalk_agent_clever\",\"Clever\")\n",
    "\n",
    "data = data.replace(\"smalltalk_agent_beautiful\",\"Clever\")\n",
    "data = data.replace(\"smalltalk_agent_fired\",\"Shutup\")\n",
    "data = data.replace(\"smalltalk_agent_good\",\"Thanks\")\n",
    "data = data.replace(\"smalltalk_agent_chatbot\",\"SelfAware\")\n",
    "data = data.replace(\"smalltalk_agent_real\",\"SelfAware\")\n",
    "\n",
    "data.to_csv(\"intentdata/train.csv\",mode=\"w\",index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO Checks label data balance\n",
    "import pandas as pd\n",
    "data = pd.read_csv(\"intentdata/train.csv\") # Original data is a (2000, 7) DataFrame\n",
    "# data contains 6 feature columns and 1 target column.\n",
    "\n",
    "# Separate the design matrix from the target labels.\n",
    "X = data.iloc[:, :-1]\n",
    "y = data['intent']\n",
    "\n",
    "y.value_counts().sort_index().plot.bar(x='Target Value', y='Number of Occurrences',figsize=(20,20))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import nlpaug.augmenter.word as naw\n",
    "import nlpaug.flow as naf\n",
    "print(\"Loading Models...\")\n",
    "TOPK=20 #default=100\n",
    "ACT = 'insert' #\"substitute\"\n",
    "aug_w2v = naw.WordEmbsAug(\n",
    "    model_type='glove', model_path='glove/glove.6B.300d.txt',\n",
    "    action=\"substitute\")\n",
    "aug_bert = naw.ContextualWordEmbsAug(\n",
    "    model_path='distilbert-base-uncased',\n",
    "    action=ACT, top_k=TOPK)\n",
    "aug = naf.Sequential([\n",
    "    aug_bert, aug_w2v\n",
    "])\n",
    "print(\"Models Loaded.\")\n",
    "with open(\"intentdata/intent.json\") as f:\n",
    "    intentwhole = json.load(f)[\"intents\"]\n",
    "#text = intent[0][\"text\"][3]\n",
    "for intent in intentwhole:\n",
    "    for text in intent[\"text\"]:\n",
    "        augmented_texts = set()\n",
    "        for i in range(20):\n",
    "            #aug = naw.SynonymAug(aug_src='wordnet',aug_min=1, aug_max=10, aug_p=i/10)\n",
    "            augmented_text = str(aug.augment(text)[0])\n",
    "            print(augmented_text)\n",
    "            #print(augmented_text)\n",
    "            augmented_texts.add(augmented_text)\n",
|
| 180 |
+
" augmented_texts = augmented_texts.union(set(intent[\"text\"]))\n",
|
| 181 |
+
" intent[\"text\"] = list(augmented_texts)\n",
|
| 182 |
+
" def test():\n",
|
| 183 |
+
" if intent[\"intent\"] != \"Jokes\": \n",
|
| 184 |
+
" for response in intent[\"responses\"]:\n",
|
| 185 |
+
" augmented_responses = set()\n",
|
| 186 |
+
" for i in range(20):\n",
|
| 187 |
+
" #aug = naw.SynonymAug(aug_src='wordnet',aug_min=1, aug_max=10, aug_p=i/50,stopwords=[\"<HUMAN>\",\"<HUMAN>,\",\"<HUMAN>!\"])\n",
|
| 188 |
+
" augmented_response = str(aug.augment(response)[0])\n",
|
| 189 |
+
" print(augmented_response)\n",
|
| 190 |
+
" try:\n",
|
| 191 |
+
" augmented_response = augmented_response[:augmented_response.index(\"<\")] + \"<HUMAN\" + augmented_response[augmented_response.index(\">\"):]\n",
|
| 192 |
+
" except ValueError as vex:\n",
|
| 193 |
+
" pass\n",
|
| 194 |
+
" #print(augmented_text)\n",
|
| 195 |
+
" augmented_responses.add(augmented_response)\n",
|
| 196 |
+
" augmented_responses = augmented_responses.union(set(intent[\"responses\"]))\n",
|
| 197 |
+
" intent[\"responses\"] = list(augmented_responses) \n",
|
| 198 |
+
"with open(\"intentdata/intent_aug_text_test.json\",\"w+\") as f:\n",
|
| 199 |
+
" json.dump({\"intents\":intentwhole},f)\n",
|
| 200 |
+
"\n",
|
| 201 |
+
"\n",
|
| 202 |
+
"print(intentwhole[1][\"responses\"])"
|
| 203 |
+
]
|
| 204 |
+
},
|
| 205 |
+
{
|
| 206 |
+
"cell_type": "code",
|
| 207 |
+
"execution_count": null,
|
| 208 |
+
"metadata": {},
|
| 209 |
+
"outputs": [],
|
| 210 |
+
"source": [
|
| 211 |
+
"df = pd.read_csv(\"intentdata/train.csv\")\n",
|
| 212 |
+
"len(list(pd.unique(df[\"intent\"])))"
|
| 213 |
+
]
|
| 214 |
+
},
|
| 215 |
+
{
|
| 216 |
+
"cell_type": "code",
|
| 217 |
+
"execution_count": 1,
|
| 218 |
+
"metadata": {},
|
| 219 |
+
"outputs": [
|
| 220 |
+
{
|
| 221 |
+
"data": {
|
| 222 |
+
"text/plain": [
|
| 223 |
+
"['20000-Utterances-Training-dataset-for-chatbots-virtual-assistant-Bitext-sample',\n",
|
| 224 |
+
" 'AskUbuntu Corpus.json',\n",
|
| 225 |
+
" 'Bitext_Sample_Customer_Service_Training_Dataset',\n",
|
| 226 |
+
" 'Chatbot Corpus.json',\n",
|
| 227 |
+
" 'Dataset-train.csv',\n",
|
| 228 |
+
" 'intent-corpus-basic.json',\n",
|
| 229 |
+
" 'intent-corpus-enrich-limit-20.json',\n",
|
| 230 |
+
" 'intents.json',\n",
|
| 231 |
+
" 'intent_classification.csv',\n",
|
| 232 |
+
" 'music_intent_entities.json',\n",
|
| 233 |
+
" 'restaurant_intent_entities.json',\n",
|
| 234 |
+
" 'Web App Corpus.json']"
|
| 235 |
+
]
|
| 236 |
+
},
|
| 237 |
+
"execution_count": 1,
|
| 238 |
+
"metadata": {},
|
| 239 |
+
"output_type": "execute_result"
|
| 240 |
+
}
|
| 241 |
+
],
|
| 242 |
+
"source": [
|
| 243 |
+
"import os\n",
|
| 244 |
+
"import json \n",
|
| 245 |
+
"import pandas as pd\n",
|
| 246 |
+
"os.listdir(\"new_intent_data\")"
|
| 247 |
+
]
|
| 248 |
+
}
|
| 249 |
+
],
|
| 250 |
+
"metadata": {
|
| 251 |
+
"kernelspec": {
|
| 252 |
+
"display_name": "Python 3.6.13 ('caesarnlradeon')",
|
| 253 |
+
"language": "python",
|
| 254 |
+
"name": "python3"
|
| 255 |
+
},
|
| 256 |
+
"language_info": {
|
| 257 |
+
"codemirror_mode": {
|
| 258 |
+
"name": "ipython",
|
| 259 |
+
"version": 3
|
| 260 |
+
},
|
| 261 |
+
"file_extension": ".py",
|
| 262 |
+
"mimetype": "text/x-python",
|
| 263 |
+
"name": "python",
|
| 264 |
+
"nbconvert_exporter": "python",
|
| 265 |
+
"pygments_lexer": "ipython3",
|
| 266 |
+
"version": "3.6.13"
|
| 267 |
+
},
|
| 268 |
+
"orig_nbformat": 4,
|
| 269 |
+
"vscode": {
|
| 270 |
+
"interpreter": {
|
| 271 |
+
"hash": "bfcbd6a2138b43c88138636b277dc6540d170334c8840148725250eab505a128"
|
| 272 |
+
}
|
| 273 |
+
}
|
| 274 |
+
},
|
| 275 |
+
"nbformat": 4,
|
| 276 |
+
"nbformat_minor": 2
|
| 277 |
+
}
|
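The relabeling cell in the diff above applies eleven chained `DataFrame.replace` calls, one per renamed intent. The same renames can be expressed as a single mapping; a minimal stdlib sketch (label pairs taken verbatim from the diff — the same dict can also be passed directly to pandas' `DataFrame.replace`, which accepts a mapping):

```python
# Intent renames from the notebook's chained data.replace(...) calls.
LABEL_MAP = {
    "smalltalk_agent_acquaintance": "CourtesyGreeting",
    "smalltalk_agent_age": "CurrentHumanQuery",
    "smalltalk_agent_annoying": "NotTalking2U",
    "smalltalk_agent_bad": "Swearing",
    "smalltalk_agent_boss": "CurrentHumanQuery",
    "smalltalk_agent_clever": "Clever",
    "smalltalk_agent_beautiful": "Clever",
    "smalltalk_agent_fired": "Shutup",
    "smalltalk_agent_good": "Thanks",
    "smalltalk_agent_chatbot": "SelfAware",
    "smalltalk_agent_real": "SelfAware",
}

def remap_intent(label: str) -> str:
    # Labels without an entry pass through unchanged, matching
    # the behaviour of the chained replace calls.
    return LABEL_MAP.get(label, label)
```

With pandas this collapses the whole cell to `data = data.replace(LABEL_MAP)` followed by the same `to_csv` write.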
CaesarAINL/runcaesarnl.bat ADDED
@@ -0,0 +1 @@
+ "C:\Users\amari\.conda\envs\caesarinfer\python.exe" "D:\CaesarAI\CaesarAINL\caesarnl.py"
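The augmentation cell in the notebook diff above runs `aug.augment(...)` twenty times per utterance and unions the results with the originals so only distinct variants are kept. That dedup loop can be sketched independently of nlpaug; here `augment` is any `str -> str` callable standing in for the notebook's `naf.Sequential` pipeline (DistilBERT insertion followed by GloVe substitution), and the `str.upper` stand-in below is purely illustrative:

```python
def augment_unique(texts, augment, n_rounds=20):
    """Collect distinct augmentations of each text, keeping every original.

    `augment` is any callable str -> str; in the notebook it is the
    naf.Sequential pipeline of aug_bert and aug_w2v.
    """
    collected = set(texts)  # originals always survive, as in the notebook
    for text in texts:
        for _ in range(n_rounds):
            # Set membership silently drops repeated augmentations, which
            # is common when the augmenter's top_k is small.
            collected.add(augment(text))
    return sorted(collected)

# Hypothetical stand-in augmenter for illustration only.
variants = augment_unique(["hi there", "bye"], str.upper, n_rounds=3)
```

Returning a sorted list (rather than the notebook's `list(set)`) makes reruns deterministic for a deterministic augmenter, which helps when diffing the written `intent_aug_text_test.json`.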