File size: 26,496 Bytes
e9fe176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
require 'torch'
require 'lfs'
cv = require 'cv'
require 'cv.imgproc'

local utils = {}

--- Reshape `a` to four dimensions.
-- Delegates to utils.toXDTensor with a target of 4 dimensions.
-- @param a tensor to reshape
-- @return whatever utils.toXDTensor(a, 4) returns
function utils.to4DTensor(a)
  local targetDims = 4
  return utils.toXDTensor(a, targetDims)
end

--- Reshape `a` to `x` dimensions.
-- The first a:dim() sizes are copied from `a`; every remaining dimension
-- keeps the fill value 1.
-- @param a object providing :dim(), :size(i) and :reshape(sizes)
-- @param x number of dimensions of the target shape
-- @return result of a:reshape on the computed sizes
function utils.toXDTensor(a, x)
  local shape = torch.LongStorage(x)
  shape:fill(1)
  local sourceDims = a:dim()
  for d = 1, sourceDims do
    shape[d] = a:size(d)
  end
  return a:reshape(shape)
end

--- Convert a single tensor to the type selected by the global `opt.type`.
-- 'CPU' calls a:double(), 'CUDA' calls a:cuda(), anything else returns `a`
-- unchanged. When the global `opt` table exists but has no `type`, the
-- default 'CUDA' is written into it (a side effect on the global table).
-- @param a tensor to convert
-- @return the converted tensor, or `a` itself for unknown types
function utils.toNNSingleTensor(a)
  -- The right-hand side reads the *global* `opt`; a fresh table is used when
  -- it is not set.
  local opt = opt or {}
  if not opt.type then
    opt.type = 'CUDA'
  end

  local target = opt.type
  if target == 'CUDA' then
    return a:cuda()
  end
  if target == 'CPU' then
    return a:double()
  end
  return a
end

--- Recursively convert tensors inside `a` with utils.toNNSingleTensor.
-- Userdata values are converted, tables are rebuilt key by key (a new table
-- is returned, the input table is not modified), and every other value is
-- returned unchanged.
-- @param a userdata, (nested) table, or any other value
-- @return converted value with the same structure as `a`
function utils.toNNTensor(a)
  local kind = type(a)
  if kind == 'table' then
    local converted = {}
    for key, value in pairs(a) do
      converted[key] = utils.toNNTensor(value)
    end
    return converted
  end
  if kind == 'userdata' then
    return utils.toNNSingleTensor(a)
  end
  return a
end

--- Length of a tensor or a table.
-- @param a userdata (uses a:size(1)) or table (uses #a)
-- @return the length; nothing is returned for any other type
function utils.getLength(a)
  local kind = type(a)
  if kind == 'table' then
    return #a
  elseif kind == 'userdata' then
    return a:size(1)
  end
end

--- Size of the first dimension of a tensor, or of the first element of a table.
-- For a table the call is forwarded to utils.getSize on a[1], so nested
-- tables are followed through their first elements.
-- @param a userdata (uses a:size(1)) or table
-- @return the size; nothing is returned for any other type
function utils.getFirstDim(a)
  local kind = type(a)
  if kind == 'table' then
    return utils.getSize(a[1])
  elseif kind == 'userdata' then
    return a:size(1)
  end
end

-- Alias for utils.getFirstDim; the original author marked it "to be replaced".
-- utils.getFirstDim calls back into this function for table inputs, so the two
-- are mutually recursive. Every return value of utils.getFirstDim is forwarded.
function utils.getSize(a)
  return utils.getFirstDim(a)
end


--- Select entries along the first dimension, recursing through tables.
-- Userdata values are indexed with a:index(1, indices); tables are rebuilt
-- key by key with the same selection applied to every value.
-- @param a userdata or (nested) table of userdata
-- @param indices index argument passed through to :index
-- @return selected data with the same structure; nothing for other types
function utils.getIndices(a, indices)
  local kind = type(a)
  if kind == 'table' then
    local picked = {}
    for key, value in pairs(a) do
      picked[key] = utils.getIndices(value, indices)
    end
    return picked
  elseif kind == 'userdata' then
    return a:index(1, indices)
  end
end

--- Load a serialized object with torch.load when `fileName` is a regular file.
-- @param fileName path to check and load
-- @return the loaded object, or nil when lfs does not report mode 'file'
function utils.loadData(fileName)
  local isRegularFile = lfs.attributes(fileName, 'mode') == 'file'
  if not isRegularFile then
    return nil
  end
  return torch.load(fileName)
end

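-- Scalar predicates on plain numbers (isinf / isnan / isfinite) come first.
-- For the bit* helpers further below, all inputs are ByteTensors, where 0 is false and 1 is true.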
--- True when `value` equals positive or negative infinity.
-- @param value value to test
-- @return boolean
function utils.isinf(value)
  local inf = math.huge
  if value == inf then
    return true
  end
  return value == -inf
end

--- True when `value` is NaN, detected by NaN being unequal to itself.
-- @param value value to test
-- @return boolean
function utils.isnan(value)
  local equalsItself = (value == value)
  return not equalsItself
end

--- True when `value` is neither infinite (utils.isinf) nor NaN (utils.isnan).
-- @param value value to test
-- @return boolean
function utils.isfinite(value)
  if utils.isinf(value) then
    return false
  end
  return not utils.isnan(value)
end

--- Element-wise logical AND of two 0/1 masks.
-- An element is set in the result when t1 + t2 is at least 2.
-- @param t1 first mask
-- @param t2 second mask
-- @return result of torch.ge on the element-wise sum
function utils.bitand(t1, t2)
  local sum = t1 + t2
  return torch.ge(sum, 2)
end

--- Element-wise logical OR of two 0/1 masks.
-- An element is set in the result when t1 + t2 is at least 1.
-- @param t1 first mask
-- @param t2 second mask
-- @return result of torch.ge on the element-wise sum
function utils.bitor(t1, t2)
  local sum = t1 + t2
  return torch.ge(sum, 1)
end

--- Element-wise logical NOT of a 0/1 mask.
-- An element is set in the result when the input element is below 1.
-- @param t mask
-- @return result of torch.lt(t, 1)
function utils.bitnot(t)
  local threshold = 1
  return torch.lt(t, threshold)
end

--- Length of a table or of a one-dimensional tensor.
-- @param t table (uses #t) or object providing :nDimension() and :size(1)
-- @return number of elements
-- Raises an error when `t` is not a table and has more than one dimension.
function utils.len(t)
  if type(t) == 'table' then
    return #t
  end
  if t:nDimension() == 1 then
    return t:size(1)
  end
  error(string.format("t is a high-order tensor! t:nDimension() = %d", t:nDimension()))
end

--- First `k` elements of `t` as a new table.
-- Fewer than `k` elements are returned when `t` is shorter.
-- @param t table or 1-D tensor (length taken from utils.len)
-- @param k maximum number of elements to take
-- @return table holding t[1] .. t[min(k, len)]
function utils.fromHead(t, k)
  local total = utils.len(t)
  local limit = math.min(k, total)
  local head = {}
  for i = 1, limit do
    head[#head + 1] = t[i]
  end
  return head
end

--- Last `k` elements of `t` as a new table, in reverse order.
-- The last element of `t` comes first in the result. Fewer than `k` elements
-- are returned when `t` is shorter.
-- @param t table or 1-D tensor (length taken from utils.len)
-- @param k maximum number of elements to take
-- @return table holding t[len], t[len - 1], ...
function utils.fromTail(t, k)
  local total = utils.len(t)
  local stop = math.max(total - k, 0) + 1
  local tail = {}
  local i = total
  while i >= stop do
    tail[#tail + 1] = t[i]
    i = i - 1
  end
  return tail
end

--- Positions whose value equals 1, returned as a Lua table.
-- Only elements equal to 1 are reported; other nonzero values are skipped.
-- @param t table or 1-D tensor (length taken from utils.len)
-- @return table of 1-based positions in ascending order
function utils.nonzero(t)
  local count = utils.len(t)
  local hits = {}
  for pos = 1, count do
    local isSet = (t[pos] == 1)
    if isSet then
      hits[#hits + 1] = pos
    end
  end
  return hits
end

--- Copy the rows of `t` whose mask entry in `tb` equals 1 into a new tensor.
-- The result is created with t.new() and resized to tb:sum() rows, each with
-- the trailing sizes of `t`.
-- @param t source tensor
-- @param tb mask with the same first-dimension size as `t`
-- @return new tensor holding the selected rows in their original order
function utils.selectRows(t, tb)
  assert(tb:size(1) == t:size(1), 'The first dimension of t and tb must be the same')

  -- Shape of one row: the sizes of `t` without the leading dimension.
  local rowShape = t:size():totable()
  table.remove(rowShape, 1)

  local selected = t.new():resize(tb:sum(), unpack(rowShape))
  local nextRow = 1
  for i = 1, t:size(1) do
    if tb[i] == 1 then
      selected[nextRow]:copy(t[i])
      nextRow = nextRow + 1
    end
  end
  return selected
end

--- Compose two index lists: result[i] = indices1[indices2[i]].
-- @param indices1 table or 1-D tensor being looked up
-- @param indices2 table or 1-D tensor of positions into indices1
-- @return table with one entry per element of indices2
function utils.permCompose(indices1, indices2)
  -- The length of indices1 is not needed; the call is kept because utils.len
  -- raises on tensors with more than one dimension.
  utils.len(indices1)
  local count = utils.len(indices2)

  local composed = {}
  for i = 1, count do
    local source = indices2[i]
    composed[i] = indices1[source]
  end
  return composed
end

--- Shallow copy of `t` without the keys flagged in `exclude`.
-- A key is dropped when exclude[key] is truthy. `t` itself is not modified.
-- @param t source table
-- @param exclude table mapping keys to a truthy value when they must be dropped
-- @return new table with the remaining key/value pairs
function utils.removeKeys(t, exclude)
  local kept = {}
  for key, value in pairs(t) do
    local dropped = exclude[key]
    if not dropped then
      kept[key] = value
    end
  end
  return kept
end

--[[

Apply a pairwise function to every combination of rows of two matrices.

t1 : m1 by n
t2 : m2 by n
func : called as func(t1[i], t2[j]); its result is stored at res[i][j]
reduction : optional, called as reduction(res, 2) (for example torch.min or
            torch.max along the second dimension)

Returns the m1 by m2 result matrix, followed by the two values returned by
`reduction` (both nil when no reduction is given).

--]]
function utils.innerprod(t1, t2, func, reduction)
  local n = t1:size(2)
  assert(n == t2:size(2), string.format('The column of t1 [%d] must be the same as the column of t2 [%d]', t1:size(2), t2:size(2)))
  local m1 = t1:size(1)
  local m2 = t2:size(1)

  local res = torch.Tensor(m1, m2)
  for i = 1, m1 do
    local left = t1[i]
    for j = 1, m2 do
      res[i][j] = func(left, t2[j])
    end
  end

  if not reduction then
    return res, nil, nil
  end

  -- Reduce every row of the pairwise matrix (dimension 2).
  local best, bestIndices = reduction(res, 2)
  return res, best, bestIndices
end

-----------------------------------------------
-- Landmarks related.

--- Test whether point `p` lies inside `rect` shrunk by `margin` on every side.
-- Boundaries are inclusive.
-- @param rect {x1, y1, x2, y2}
-- @param p {x, y}
-- @param margin optional inset, defaults to 0
-- @return boolean
function utils.inRect(rect, p, margin)
  margin = margin or 0

  if not (rect[1] + margin <= p[1]) then return false end
  if not (p[1] <= rect[3] - margin) then return false end
  if not (rect[2] + margin <= p[2]) then return false end
  return p[2] <= rect[4] - margin
end

--- Fill a square window of a 2-D tensor with a constant.
-- The window spans [x - r, x + r] horizontally and [y - r, y + r] vertically,
-- with every bound clamped to the tensor extent (1..width, 1..height). Despite
-- the "radius" wording, the filled region is an axis-aligned box, not a disc.
-- @param m 2-D tensor, modified in place
-- @param x column of the window centre
-- @param y row of the window centre
-- @param r half-size of the window
-- @param c value written into the window
function utils.fillKernel(m, x, y, r, c)
  assert(m, "Input image should not be null")
  assert(x and y, "Input coordinates should not be null")
  assert(r and c, "Input radius and color should not be null")
  assert(#m:size() == 2, 'fill_kernel: input m is not 2D!')

  local width = m:size(2)
  local height = m:size(1)

  -- Clamp a coordinate into the valid index range [1, upper].
  local function clamp(v, upper)
    return math.min(math.max(v, 1), upper)
  end

  local left = clamp(x - r, width)
  local right = clamp(x + r, width)
  local top = clamp(y - r, height)
  local bottom = clamp(y + r, height)

  m:sub(top, bottom, left, right):fill(c)
end

--- Locate the minimum of a 2-D tensor.
-- The tensor is reduced along dimension 1 and then along dimension 2.
-- @param m 2-D tensor
-- @return x column index of the minimum
-- @return y row index of the minimum
-- @return minVal value part of the second torch.min call, returned as-is
function utils.imin(m)
  local columnMins, rowOfColumnMin = torch.min(m, 1)
  local minVal, columnOfMin = torch.min(columnMins, 2)

  local x = columnOfMin[1][1]
  return x, rowOfColumnMin[1][x], minVal
end

--- Locate the maximum of a 2-D tensor.
-- The tensor is reduced along dimension 1 and then along dimension 2.
-- @param m 2-D tensor
-- @return x column index of the maximum
-- @return y row index of the maximum
-- @return maxVal value part of the second torch.max call, returned as-is
function utils.imax(m)
  local columnMaxs, rowOfColumnMax = torch.max(m, 1)
  local maxVal, columnOfMax = torch.max(columnMaxs, 2)

  local x = columnOfMax[1][1]
  return x, rowOfColumnMax[1][x], maxVal
end
-------------------------------------Save to json--------------------
--- Mark a table, and the tables in its sequence part, as JSON arrays.
-- Sets `__array = true` on `t` and recurses into the elements visited by
-- ipairs; values stored under non-sequence keys are not visited. Non-table
-- arguments are ignored.
-- @param t table to mark (modified in place)
function utils.set2Array(t)
  if type(t) == 'table' then
    t.__array = true
    for _, element in ipairs(t) do
      utils.set2Array(element)
    end
  end
end

--- Convert numbers, tensors and nested tables into plain Lua tables.
-- Tables are rebuilt key by key, numbers are returned as-is, and tensors are
-- converted with :totable(). A converted tensor is flagged with
-- utils.set2Array so that utils.saveJson writes it as a JSON array.
-- @param t number, Torch tensor, or (nested) table of those
-- @return the converted value
-- Raises an error for any other datatype; the message names the Torch
-- typename when there is one, otherwise the Lua type.
function utils.convert2Table(t)
  local kind = type(t)
  if kind == 'table' then
    -- NOTE(review): `debug` here resolves to Lua's standard debug library
    -- unless a global flag of that name overrides it, so these traces print
    -- by default. The checks are kept for backward compatibility.
    if debug then print("parsing table") end
    local res = {}
    for k, v in pairs(t) do
      res[k] = utils.convert2Table(v)
    end
    return res
  elseif kind == 'number' then
    if debug then print("parsing number") end
    return t
  end

  local typename = torch.typename(t)
  if typename and typename:match('Tensor') then
    if debug then print("parsing tensor") end
    local res = t:totable()
    -- Flag the converted table, not the tensor: set2Array ignores userdata,
    -- so flagging `t` left the result without its __array marker.
    utils.set2Array(res)
    return res
  end

  -- Prefer the Torch class name; type(t) is never nil, so testing it first
  -- would always hide the typename.
  error("Convert_to_table error, unsupported datatype = " .. (typename or kind))
end

-- Quote and escape a value as a JSON string literal.
local function jsonString(s)
  local escaped = tostring(s):gsub('[%c"\\]', function(ch)
    if ch == '"' then return '\\"' end
    if ch == '\\' then return '\\\\' end
    if ch == '\n' then return '\\n' end
    if ch == '\r' then return '\\r' end
    if ch == '\t' then return '\\t' end
    return string.format('\\u%04x', ch:byte())
  end)
  return '"' .. escaped .. '"'
end

--- Write a value as JSON to an open file handle.
-- Tables flagged with `__array` (see utils.set2Array) are written as JSON
-- arrays from their sequence part; every other table is written as a JSON
-- object with quoted keys. Numbers, strings and booleans are written as
-- scalars.
-- @param t number, string, boolean or (nested) table
-- @param f object with a :write(str) method, such as a file handle
-- Raises an error for any other datatype.
function utils.saveJson(t, f)
  local kind = type(t)
  if kind == 'number' or kind == 'boolean' then
    f:write(tostring(t))
    return
  end
  if kind == 'string' then
    f:write(jsonString(t))
    return
  end
  if kind ~= 'table' then
    error("saveJson error, unsupported datatype = " .. kind)
  end

  if t.__array then
    f:write("[\n")
    local last = #t
    for i, v in ipairs(t) do
      utils.saveJson(v, f)
      if i ~= last then f:write(",") end
    end
    f:write("]\n")
  else
    -- Count the entries first so the separator can be skipped after the last.
    local remaining = 0
    for _ in pairs(t) do remaining = remaining + 1 end
    f:write("{\n")
    for k, v in pairs(t) do
      -- JSON object keys must be quoted strings.
      f:write(jsonString(k) .. " : ")
      utils.saveJson(v, f)
      remaining = remaining - 1
      if remaining >= 1 then f:write(",\n") end
    end
    f:write("}\n")
  end
end

-------------------------------------Save to numpy-----------------
-- run it using PATH=/usr/bin/python

--- Pickle a value to disk through the embedded Python bridge (fb.python).
-- The argument order is (file name, value), the reverse of utils.saveJson.
-- The script imports cPickle, a module that exists only in Python 2.
-- @param f path of the output file
-- @param t value handed to Python as `variable`
function utils.savePickle(f, t)
  local py = require 'fb.python'
  local script = [=[
import numpy as np
import cPickle
with open(filename, "wb") as outfile:
    cPickle.dump(variable, outfile, protocol=cPickle.HIGHEST_PROTOCOL)
]=]
  local bindings = {variable = t, filename = f}
  py.exec(script, bindings)
end

--

--- Width of the terminal in columns, falling back to 80.
-- Runs `tput cols`. The fallback is used on Windows, when the command cannot
-- be started, when it does not report exit status 0, or when its output is
-- not a number.
-- NOTE(review): relies on the global `sys` package being loaded elsewhere;
-- this file does not require it.
-- @return number of columns
local function getTermLength()
  local fallback = 80
  if sys.uname() == 'windows' then return fallback end

  local tputf = io.popen('tput cols', 'r')
  if not tputf then return fallback end

  local w = tonumber(tputf:read('*a'))
  -- The third value returned by close() is compared with 0 as the exit status.
  local rc = {tputf:close()}
  if rc[3] == 0 and w then
    return w
  end
  return fallback
end

-- State shared between successive utils.progress calls.
-- True when no bar is being drawn; cleared when a bar starts and set again
-- once the bar is full.
local barDone = true
-- Bar fill position drawn by the previous call; -1 before any bar is drawn.
local previous = -1
-- Most recent time-statistics text written after the bar.
local tm = ''
-- torch.Timer created when a bar starts.
local timer
-- Recent elapsed-time samples, parallel to `indices`.
local times
-- Recent `current` values, parallel to `times`.
local indices
-- Terminal width used for layout, capped at 120 columns; computed once at load.
local termLength = math.min(getTermLength(), 120)

--- Format a duration in seconds as a short human-readable string.
-- The duration is split into days, hours, minutes, seconds and milliseconds.
-- Only the first two non-zero components are printed (for example '1h1m' or
-- '1s500ms'); '0ms' is returned when every component is zero.
-- @param seconds duration in seconds
-- @return formatted string
local function formatTime(seconds)
  local floor = math.floor

  -- decompose, largest unit first
  local left = seconds
  local days = floor(left / 3600/24)
  left = left - days*3600*24
  local hours = floor(left / 3600)
  left = left - hours*3600
  local minutes = floor(left / 60)
  left = left - minutes*60
  local wholeSeconds = floor(left)
  left = left - wholeSeconds
  local millis = floor(left*1000)

  local components = {
    {days, 'D'},
    {hours, 'h'},
    {minutes, 'm'},
    {wholeSeconds, 's'},
    {millis, 'ms'},
  }

  -- keep at most two non-zero components
  local text = ''
  local used = 0
  for _, component in ipairs(components) do
    local amount, suffix = component[1], component[2]
    if amount > 0 and used < 2 then
      text = text .. amount .. suffix
      used = used + 1
    end
  end

  if text == '' then
    return '0ms'
  end
  return text
end

--- Draw or refresh a one-line text progress bar on standard output.
-- The line shows the bar, `current/goal`, an ETA (or the total time once the
-- bar is full), the average time per step and a caller-supplied suffix. State
-- is kept in file-level locals, so only one bar is tracked at a time.
-- @param current current step
-- @param goal final step
-- @param addinfo optional text appended after the timing information
function utils.progress(current, goal, addinfo)
  -- A missing suffix used to crash on `#addinfo`.
  addinfo = addinfo or ''

  -- defaults:
  local barLength = termLength - 37 - #addinfo
  local smoothing = 100

  -- Compute percentage (expressed in bar cells)
  local percent = math.floor(((current) * barLength) / goal)

  -- start new bar
  if (barDone and ((previous == -1) or (percent < previous))) then
    barDone = false
    previous = -1
    tm = ''
    timer = torch.Timer()
    times = {timer:time().real}
    indices = {current}
  else
    io.write('\r')
  end

  if (not barDone) then
    previous = percent
    -- print bar
    io.write(' [')
    for i=1,barLength do
      if (i < percent) then io.write('=')
      elseif (i == percent) then io.write('>')
      else io.write('.') end
    end
    io.write('] ')
    for i=1,termLength-barLength-4 do io.write(' ') end
    for i=1,termLength-barLength-4 do io.write('\b') end
    -- time stats
    local elapsed = timer:time().real
    local step = (elapsed-times[1]) / (current-indices[1])
    if current==indices[1] then step = 0 end
    local remaining = math.max(0,(goal - current)*step)
    table.insert(indices, current)
    table.insert(times, elapsed)
    if #indices > smoothing then
      -- Drop the oldest sample so the step estimate covers a sliding window.
      -- table.remove replaces xlua's non-standard table.splice, which this
      -- file never loads.
      table.remove(indices, 1)
      table.remove(times, 1)
    end
    -- Print remaining time when running or total time when done.
    if (percent < barLength) then
      tm = ' ETA: ' .. formatTime(remaining)
    else
      tm = ' Tot: ' .. formatTime(elapsed)
    end
    tm = tm .. ' | Step: ' .. formatTime(step) .. ' | ' .. addinfo
    io.write(tm)
    -- go back to center of bar, and print progress
    for i=1,5+#tm+barLength/2 do io.write('\b') end
    io.write(' ', current, '/', goal, ' ')
    -- reset for next bar
    if (percent == barLength) then
      barDone = true
      io.write('\n')
    end
    -- flush
    io.write('\r')
    io.flush()
  end
end

--- Blend a heat map over an image.
-- The heat map is upsampled bilinearly by an integer factor, zero-padded (or
-- cropped, when the padding is negative) to the image size, coloured through
-- HSL (hue 2/3 * (1 - value), saturation 1, lightness 0.5) and added to the
-- scaled image. The result is clamped to [0, 1].
-- @param image 3 x H x W tensor (the code adds a 3-channel RGB map to it)
-- @param heatMap 2-D tensor; presumably values in [0, 1] -- TODO confirm
-- @param alpha weight applied to the image
-- @param beta weight applied to the coloured heat map
-- @return blended image tensor
function utils.drawHeatMapImage(image, heatMap, alpha, beta)
   local nn = require('nn')

   local heatMapImage = image * alpha
   local imageHeight = (#image)[2]
   local imageWidth = (#image)[3]
   local scale = math.ceil(math.max(imageHeight / (#heatMap)[1], imageWidth / (#heatMap)[2]))
   -- `local` added: this name used to leak into the global environment.
   local resizedHeatMap = nn.SpatialUpSamplingBilinear(scale):forward(heatMap:repeatTensor(1, 1, 1))[1]

   -- Rows are the vertical axis (top/bottom), columns the horizontal axis
   -- (left/right). The original passed the row padding as left/right and the
   -- column padding as top/bottom, which only lined up when both were equal.
   local paddingRows = imageHeight - (#resizedHeatMap)[1]
   local paddingTop = math.floor(paddingRows / 2)
   local paddingBottom = paddingRows - paddingTop
   local paddingCols = imageWidth - (#resizedHeatMap)[2]
   local paddingLeft = math.floor(paddingCols / 2)
   local paddingRight = paddingCols - paddingLeft
   resizedHeatMap = nn.SpatialZeroPadding(paddingLeft, paddingRight, paddingTop, paddingBottom):forward(resizedHeatMap:repeatTensor(1, 1, 1))[1]

   local heatMapHSL = torch.DoubleTensor(3, (#resizedHeatMap)[1], (#resizedHeatMap)[2])
   heatMapHSL[1] = 2 / 3 * (1 - resizedHeatMap)
   heatMapHSL[2] = 1
   heatMapHSL[3] = 0.5
   -- The `image` parameter shadows the image package, so the module returned
   -- by require is used directly.
   local heatMapRGB = require('image').hsl2rgb(heatMapHSL)
   heatMapImage = heatMapImage + heatMapRGB * beta
   heatMapImage:clamp(0, 1)
   return heatMapImage
end



--- Threshold a floorplan image and label its connected components.
-- A pixel is marked in the binary mask when all three channels are below
-- `binaryThreshold`. The mask is optionally cleaned with morphological
-- operations, and connected components are computed with OpenCV on the
-- complement of the mask (on the mask itself when `reverse` is truthy).
-- @param floorplan 3 x H x W image tensor
-- @param binaryThreshold channel values below this count as marked
-- @param numOpenOperations > 0: that many erode-then-dilate passes;
--        < 0: that many dilate-then-erode passes; 0: no cleaning
-- @param reverse optional; when truthy, label the mask instead of its complement
-- @return component label map (IntTensor, OpenCV labels shifted by +1)
-- @return value returned by cv.connectedComponents
-- @return binary mask after cleaning (never inverted by `reverse`)
function utils.segmentFloorplan(floorplan, binaryThreshold, numOpenOperations, reverse)
   -- The file uses `image` as a global without requiring it; fall back to
   -- loading the package when the global is not set.
   local image = image or require('image')

   local floorplanBinary = torch.ones((#floorplan)[2], (#floorplan)[3])
   for c = 1, 3 do
      local mask = floorplan[c]:lt(binaryThreshold):double()
      floorplanBinary = torch.cmul(floorplanBinary, mask)
   end

   if numOpenOperations > 0 then
      for i = 1, numOpenOperations do
         floorplanBinary = image.erode(floorplanBinary)
         floorplanBinary = image.dilate(floorplanBinary)
      end
   elseif numOpenOperations < 0 then
      for i = 1, -numOpenOperations do
         floorplanBinary = image.dilate(floorplanBinary)
         floorplanBinary = image.erode(floorplanBinary)
      end
   end

   -- `local` added: this name used to leak into the global environment.
   local floorplanBinaryByte = (floorplanBinary * 255):byte()
   if reverse then
      floorplanBinaryByte = 255 - floorplanBinaryByte
   end
   local floorplanComponent = torch.IntTensor(floorplanBinaryByte:size())
   local numComponents = cv.connectedComponents{255 - floorplanBinaryByte, floorplanComponent}
   -- Shift the labels by one so that they start at 1.
   floorplanComponent = floorplanComponent + 1
   return floorplanComponent, numComponents, floorplanBinary
end

--- Render a label map as an RGB image with one random colour per label.
-- Labels 1..numComponents each get a colour from torch.rand(3); label 0 is
-- drawn black.
-- @param floorplanComponent 2-D label map
-- @param numComponents highest label that receives a random colour
-- @return 3 x H x W double tensor of colours
function utils.drawSegmentation(floorplanComponent, numComponents)
   local palette = {}
   for label = 1, numComponents do
      palette[label] = torch.rand(3)
   end
   palette[0] = torch.zeros(3)

   local rendered = floorplanComponent:repeatTensor(3, 1, 1):double()
   for channel = 1, 3 do
      local plane = rendered[channel]
      -- Replace every label, in place, by that label's colour component.
      plane:apply(function(label)
         return palette[label][channel]
      end)
   end
   return rendered
end

--- Segment a floorplan image into regions and text/character masks.
-- Steps:
--   1. Segment the complement of the dark-pixel mask, raising the binary
--      threshold in steps of 0.1 while the corner pixel's segment covers more
--      than 80% of the image.
--   2. Segment the dark pixels themselves (reversed pass); the largest
--      non-background component is taken as the wall/border, every other
--      one as a character.
--   3. Classify the segments of step 1 as small (untouched by the dilated
--      border), thin (discarded) or large.
--   4. Assign small segments and characters to the nearest large segment
--      whose bounding box contains them (label 1 when there is none).
--   5. Merge large segments that fall into the same coarse component when
--      their joint bounding box is almost rectangular.
--   6. Relabel the surviving segments consecutively.
-- @param floorplanInput 3 x H x W image tensor
-- @return refined label map (0 marks border and discarded pixels)
-- @return character mask carrying the label assigned to each character pixel
-- @return number of refined segments
-- Returns nothing when no border component is found.
function utils.parseFloorplan(floorplanInput)
   -- The file uses `image` as a global without requiring it; fall back to
   -- loading the package when the global is not set.
   local image = image or require('image')

   -- Number of nonzero entries of a mask (0 when nonzero() is empty).
   local function countNonzero(mask)
      local found = mask:nonzero()
      if ##found == 0 then
         return 0
      end
      return (#found)[1]
   end

   -- 1. Segmentation with an adaptive threshold.
   local binaryThreshold = 0.7
   local floorplanSegmentation, numSegments = utils.segmentFloorplan(floorplanInput, binaryThreshold, 0)
   local numPixels = (#floorplanSegmentation)[1] * (#floorplanSegmentation)[2]
   while countNonzero(floorplanSegmentation:eq(floorplanSegmentation[1][1])) > numPixels * 0.8 and binaryThreshold < 1 do
      binaryThreshold = binaryThreshold + 0.1
      floorplanSegmentation, numSegments = utils.segmentFloorplan(floorplanInput, binaryThreshold, 0)
   end

   -- 2. Reversed segmentation: find the border (largest non-background component).
   local floorplanSegmentationReversed, numSegmentsReversed = utils.segmentFloorplan(floorplanInput, binaryThreshold, 0, true)
   local backgroundSegmentReversed = floorplanSegmentationReversed[1][1]
   local maxCount = 0
   local borderSegmentReversed = 0
   for segmentReversed = 1, numSegmentsReversed do
      if segmentReversed ~= backgroundSegmentReversed then
         local count = countNonzero(floorplanSegmentationReversed:eq(segmentReversed))
         if count > maxCount then
            borderSegmentReversed = segmentReversed
            maxCount = count
         end
      end
   end

   if borderSegmentReversed == 0 then
      return
   end

   -- Pixels of the reversed pass that are neither background nor border.
   -- NOTE(review): these masks are byte tensors, so labels written into
   -- characterMasks further down presumably have to stay below 256 -- confirm.
   local characterMasks = torch.cmul(1 - floorplanSegmentationReversed:eq(backgroundSegmentReversed), 1 - floorplanSegmentationReversed:eq(borderSegmentReversed))

   local borderMask = floorplanSegmentationReversed:eq(borderSegmentReversed)
   local borderIndices = borderMask:nonzero()
   -- Label that the first border pixel carries in the non-reversed segmentation.
   local borderSegment = floorplanSegmentation[{borderIndices[1][1], borderIndices[1][2]}]

   -- Segments narrower than this along either axis are discarded.
   local borderThickness = 10
   borderMask = image.dilate(borderMask)
   local floorplanSegmentationErodedOnce = torch.cmul(floorplanSegmentation, (1 - borderMask):int())

   local floorplanSegmentationRefined = floorplanSegmentation:clone()
   local backgroundSegment = floorplanSegmentation[1][1]
   floorplanSegmentationRefined[floorplanSegmentation:eq(backgroundSegment)] = 1
   floorplanSegmentationRefined[floorplanSegmentation:eq(borderSegment)] = 0

   -- Per-segment statistics over pixel coordinates (row, column).
   local segmentInfo = {
      mins = {},
      maxs = {},
      means = {},
      nums = {},
      samples = {},
   }
   for segment = 1, numSegments do
      if segment ~= backgroundSegment and segment ~= borderSegment then
         local indices = floorplanSegmentation:eq(segment):nonzero()
         segmentInfo.mins[segment] = torch.min(indices, 1)[1]
         segmentInfo.maxs[segment] = torch.max(indices, 1)[1]
         segmentInfo.means[segment] = torch.mean(indices:double(), 1)[1]
         segmentInfo.nums[segment] = (#indices)[1]
         segmentInfo.samples[segment] = indices[1]
      end
   end

   -- 3. Classify segments.
   local largeSegments = {}
   local smallSegments = {}
   for segment = 1, numSegments do
      if segment ~= backgroundSegment and segment ~= borderSegment then
         local survivingCount = countNonzero(floorplanSegmentationErodedOnce:eq(segment))
         if survivingCount > 0 and survivingCount == segmentInfo.nums[segment] then
            -- No pixel was removed by the dilated border.
            table.insert(smallSegments, segment)
         elseif segmentInfo.maxs[segment][1] - segmentInfo.mins[segment][1] < borderThickness or segmentInfo.maxs[segment][2] - segmentInfo.mins[segment][2] < borderThickness then
            floorplanSegmentationRefined[floorplanSegmentation:eq(segment)] = 0
         else
            table.insert(largeSegments, segment)
         end
      end
   end

   local characterSegmentsReversed = {}
   for segmentReversed = 1, numSegmentsReversed do
      if segmentReversed ~= backgroundSegmentReversed and segmentReversed ~= borderSegmentReversed then
         table.insert(characterSegmentsReversed, segmentReversed)
      end
   end

   local characterSegmentInfo = {
      mins = {},
      maxs = {},
      means = {},
      nums = {},
   }
   for _, characterSegmentReversed in ipairs(characterSegmentsReversed) do
      local indices = floorplanSegmentationReversed:eq(characterSegmentReversed):nonzero()
      characterSegmentInfo.mins[characterSegmentReversed] = torch.min(indices, 1)[1]
      characterSegmentInfo.maxs[characterSegmentReversed] = torch.max(indices, 1)[1]
      characterSegmentInfo.means[characterSegmentReversed] = torch.mean(indices:double(), 1)[1]
      -- The pixel count belongs in `nums`; it used to overwrite `means`, so
      -- the centroid distance below compared a count with a coordinate.
      characterSegmentInfo.nums[characterSegmentReversed] = (#indices)[1]
   end

   -- Large segment whose bounding box contains the given box and whose
   -- centroid is closest to `mean`; label 1 when no large segment contains it.
   local function nearestEnclosingLargeSegment(mins, maxs, mean)
      local bestSegment
      local bestDistance
      for _, largeSegment in ipairs(largeSegments) do
         local largeMins = segmentInfo.mins[largeSegment]
         local largeMaxs = segmentInfo.maxs[largeSegment]
         if mins[1] >= largeMins[1] and mins[2] >= largeMins[2] and maxs[1] <= largeMaxs[1] and maxs[2] <= largeMaxs[2] then
            local distance = torch.norm(mean - segmentInfo.means[largeSegment])
            if bestDistance == nil or distance < bestDistance then
               bestSegment = largeSegment
               bestDistance = distance
            end
         end
      end
      return bestSegment or 1
   end

   -- 4. Absorb small segments and characters into their enclosing large segment.
   for _, smallSegment in ipairs(smallSegments) do
      local target = nearestEnclosingLargeSegment(segmentInfo.mins[smallSegment], segmentInfo.maxs[smallSegment], segmentInfo.means[smallSegment])
      floorplanSegmentationRefined[floorplanSegmentation:eq(smallSegment)] = target
   end

   for _, characterSegmentReversed in ipairs(characterSegmentsReversed) do
      local target = nearestEnclosingLargeSegment(characterSegmentInfo.mins[characterSegmentReversed], characterSegmentInfo.maxs[characterSegmentReversed], characterSegmentInfo.means[characterSegmentReversed])
      local characterPixels = floorplanSegmentationReversed:eq(characterSegmentReversed)
      floorplanSegmentationRefined[characterPixels] = target
      characterMasks[characterPixels] = target
   end

   -- 5. Coarse components: connected regions once border (0) and
   -- background (1) pixels are removed.
   local wallMask = floorplanSegmentationRefined:eq(0)
   wallMask[floorplanSegmentationRefined:eq(1)] = 1
   wallMask = image.erode(wallMask)
   local wallMaskByte = (wallMask * 255):byte()
   local floorplanSegmentationCoarse = torch.IntTensor(wallMaskByte:size())
   local numSegmentCoarse = cv.connectedComponents{255 - wallMaskByte, floorplanSegmentationCoarse}
   floorplanSegmentationCoarse = floorplanSegmentationCoarse + 1

   local coarseToFineSegmentMap = {}
   for segment = 1, numSegmentCoarse do
      coarseToFineSegmentMap[segment] = {}
   end
   for _, largeSegment in ipairs(largeSegments) do
      local sample = segmentInfo.samples[largeSegment]
      table.insert(coarseToFineSegmentMap[floorplanSegmentationCoarse[sample[1]][sample[2]]], largeSegment)
   end

   for segmentCoarse = 1, numSegmentCoarse do
      local mins
      local maxs
      local maxNum
      local maxNumSegment
      local count = 0
      for _, largeSegment in ipairs(coarseToFineSegmentMap[segmentCoarse]) do
         count = count + 1
         if maxNum == nil or segmentInfo.nums[largeSegment] > maxNum then
            maxNumSegment = largeSegment
            maxNum = segmentInfo.nums[largeSegment]
         end
         if mins == nil then
            -- Clone: the union box is grown in place below and must not
            -- overwrite the stored bounding box of the first segment.
            mins = segmentInfo.mins[largeSegment]:clone()
            maxs = segmentInfo.maxs[largeSegment]:clone()
         else
            for i = 1, 2 do
               mins[i] = math.min(mins[i], segmentInfo.mins[largeSegment][i])
               maxs[i] = math.max(maxs[i], segmentInfo.maxs[largeSegment][i])
            end
         end
      end
      if count > 1 then
         -- Extents of the union box along dimension 1 (rows) and dimension 2 (columns).
         local newSegmentWidth = maxs[1] - mins[1] + 1
         local newSegmentHeight = maxs[2] - mins[2] + 1
         -- Pixels of the union box that lie outside both the coarse component
         -- and the bounding box of the biggest member.
         local cornerMask = torch.zeros((#floorplanSegmentation)[1], (#floorplanSegmentation)[2])
         cornerMask:narrow(1, mins[1], newSegmentWidth):narrow(2, mins[2], newSegmentHeight):fill(1)
         cornerMask = torch.cmul(cornerMask:byte(), 1 - floorplanSegmentationCoarse:eq(segmentCoarse))
         cornerMask:narrow(1, segmentInfo.mins[maxNumSegment][1], segmentInfo.maxs[maxNumSegment][1] - segmentInfo.mins[maxNumSegment][1] + 1):narrow(2, segmentInfo.mins[maxNumSegment][2], segmentInfo.maxs[maxNumSegment][2] - segmentInfo.mins[maxNumSegment][2] + 1):fill(0)

         -- Merge when fewer than 10% of the union box is left over.
         local cornerCount = countNonzero(cornerMask)
         if cornerCount == 0 or cornerCount < newSegmentWidth * newSegmentHeight * 0.1 then
            for _, largeSegment in ipairs(coarseToFineSegmentMap[segmentCoarse]) do
               if largeSegment ~= maxNumSegment then
                  floorplanSegmentationRefined[floorplanSegmentationRefined:eq(largeSegment)] = maxNumSegment
                  characterMasks[characterMasks:eq(largeSegment)] = maxNumSegment
               end
            end
         end
      end
   end

   -- 6. Relabel the surviving segments as 1..numSegmentsRefined. New labels
   -- never exceed the label being replaced, so there are no collisions.
   local newSegment = 1
   for segment = 1, numSegments do
      if countNonzero(floorplanSegmentationRefined:eq(segment)) > 0 then
         floorplanSegmentationRefined[floorplanSegmentationRefined:eq(segment)] = newSegment
         characterMasks[characterMasks:eq(segment)] = newSegment
         newSegment = newSegment + 1
      end
   end
   local numSegmentsRefined = newSegment - 1
   return floorplanSegmentationRefined, characterMasks, numSegmentsRefined

end

return utils