import math

from darknet import *
|
|
def predict_tactic(net, s, max_len=None):
    """Sample one tactic string from a character-level RNN.

    The prompt ``s`` is fed to the network one character at a time as a
    one-hot vector over 256 slots (indexed by ``ord``).  Characters are then
    sampled from the network output until a '.' is produced.

    Args:
        net: network handle, as returned by ``load_net``.
        s: prompt text used to prime the network; an empty prompt is
            treated as a single newline.
        max_len: optional cap on the number of sampled characters.  ``None``
            (the default) keeps sampling until a '.' appears.

    Returns:
        ``(tac, prob)``: ``tac`` is the sampled text (ending in '.' unless
        ``max_len`` cut it short) and ``prob`` is the sum of ``math.log`` of
        the network output value for each character in ``tac``.

    NOTE(review): assumes ``predict`` returns at least 256 strictly positive
    values (presumably a softmax over byte values); a zero entry picked by
    ``sample`` would make ``math.log`` raise ValueError.
    """
    if not s:
        s = '\n'

    # Reusable one-hot input buffer: one slot is set before each forward
    # pass and cleared right after.
    d = c_array(c_float, [0.0] * 256)

    def _feed(ch):
        # Run one forward step with ``ch`` as the one-hot input.
        d[ord(ch)] = 1
        out = predict(net, d)
        d[ord(ch)] = 0
        return out

    # Prime the recurrent state with every prompt character except the last.
    # The outputs are discarded; only the state update matters here.
    for ch in s[:-1]:
        _feed(ch)

    chars = []
    prob = 0.0
    c = s[-1]
    while max_len is None or len(chars) < max_len:
        out = _feed(c)
        pred = [out[i] for i in range(256)]
        ind = sample(pred)
        c = chr(ind)
        prob += math.log(pred[ind])
        chars.append(c)
        # Stop immediately after the terminating '.', so ``prob`` covers
        # exactly the characters that are returned.
        if c == '.':
            break
    return (''.join(chars), prob)
|
|
def predict_tactics(net, s, n):
    """Draw ``n`` tactic samples for prompt ``s`` and rank them by score.

    ``reset_rnn(net)`` is called before every call to ``predict_tactic``
    (presumably clearing the recurrent state so samples are independent;
    confirm against darknet).

    Returns:
        A list of the ``(tac, prob)`` tuples from ``predict_tactic``,
        ordered from highest to lowest ``prob``; ties keep sampling order.
    """
    samples = []
    for _ in range(n):
        reset_rnn(net)
        samples.append(predict_tactic(net, s))
    # Negated key gives descending order while keeping the sort stable.
    samples.sort(key=lambda pair: -pair[1])
    return samples
|
|
# Demo: load the Coq tactic model, draw 10 samples for the prompt "+++++\n"
# and print them, highest score first.
# NOTE(review): the weights path is a hard-coded absolute path; adjust it
# for your machine.
net = load_net("cfg/coq.test.cfg", "/home/pjreddie/backup/coq.backup", 0)
t = predict_tactics(net, "+++++\n", 10)
# print(...) with a single argument works on both Python 2 and Python 3.
print(t)
|
|