Werli committed on
Commit
67dd367
·
verified ·
1 Parent(s): 6350bdc

More cleanup

Browse files
Files changed (1) hide show
  1. app.py +9 -21
app.py CHANGED
@@ -42,14 +42,13 @@ else:
42
 
43
  TITLE = "Multi-Tagger"
44
  DESCRIPTION = """
45
- Multi-Tagger is a powerful and versatile application that integrates two cutting-edge models: Waifu Diffusion and Florence 2. This app is designed to provide comprehensive image analysis and captioning capabilities, making it a valuable tool for AI artists, researchers, and enthusiasts.
46
 
47
- Features:
48
- - Supports batch processing of multiple images.
49
- - Tags images with multiple categories: general tags, character tags, and ratings.
50
- - Displays categorized tags in a structured format.
51
- - Includes a separate tab for image captioning using Florence 2. Supports CUDA, MPS or CPU if one of them is available.
52
- - Supports various captioning tasks (e.g., Caption, Detailed Caption, Object Detection), it can display output text and images for tasks that generate visual outputs.
53
 
54
  Example image by [me.](https://huggingface.co/Werli)
55
  """
@@ -81,7 +80,6 @@ kaomojis=['0_0','(o)_(o)','+_+','+_-','._.','<o>_<o>','<|>_<|>','=_=','>_<','3_3
81
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the tagger app.

    Returns:
        argparse.Namespace with:
            score_slider_step (float): UI slider increment (default 0.05).
            score_general_threshold (float): cutoff for general tags (default 0.35).
            score_character_threshold (float): cutoff for character tags (default 0.85).
            share (bool): whether to create a public Gradio share link.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--score-slider-step", type=float, default=0.05)
    parser.add_argument("--score-general-threshold", type=float, default=0.35)
    parser.add_argument("--score-character-threshold", type=float, default=0.85)
    parser.add_argument("--share", action="store_true")
    return parser.parse_args()
82
def load_labels(dataframe) -> tuple[list[str], list[int], list[int], list[int]]:
    """Extract tag names and per-category index lists from a tag DataFrame.

    Underscores in tag names are replaced with spaces, except for kaomoji
    tags (module-level ``kaomojis`` list), whose underscores are meaningful.

    Args:
        dataframe: pandas DataFrame with 'name' and 'category' columns.
            Category codes used here: 9 = rating, 0 = general, 4 = character.

    Returns:
        (tag_names, rating_indexes, general_indexes, character_indexes)
    """
    # NOTE(review): the original annotated the return type as list[str],
    # but four lists are returned — corrected to a tuple annotation.
    name_series = dataframe["name"].map(
        lambda x: x if x in kaomojis else x.replace("_", " ")
    )
    tag_names = name_series.tolist()
    categories = dataframe["category"]
    rating_indexes = list(np.where(categories == 9)[0])
    general_indexes = list(np.where(categories == 0)[0])
    character_indexes = list(np.where(categories == 4)[0])
    return tag_names, rating_indexes, general_indexes, character_indexes
83
def mcut_threshold(probs):
    """Maximum Cut Thresholding (MCut) for multi-label scores.

    Sorts the probabilities in descending order, finds the largest gap
    between consecutive values, and returns the midpoint of that gap.

    Args:
        probs: 1-D numpy array of probabilities.

    Returns:
        Threshold value; 0.0 when fewer than two scores are given
        (no gap exists, so nothing should be filtered out).
    """
    if probs.size < 2:
        # Original raised ValueError (argmax of an empty array) here.
        return 0.0
    sorted_probs = np.sort(probs)[::-1]
    gaps = sorted_probs[:-1] - sorted_probs[1:]
    t = gaps.argmax()
    return (sorted_probs[t] + sorted_probs[t + 1]) / 2
84
-
85
  class Timer:
86
  def __init__(self):self.start_time=time.perf_counter();self.checkpoints=[('Start',self.start_time)]
87
  def checkpoint(self,label='Checkpoint'):now=time.perf_counter();self.checkpoints.append((label,now))
@@ -94,7 +92,7 @@ class Timer:
94
  for(label,curr_time)in self.checkpoints[1:]:elapsed=curr_time-prev_time;print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds");prev_time=curr_time
95
  total_time=self.checkpoints[-1][1]-self.start_time;print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n");self.checkpoints.clear()
96
  def restart(self):self.start_time=time.perf_counter();self.checkpoints=[('Start',self.start_time)]
97
-
98
  class Llama3Reorganize:
99
  def __init__(self,repoId:str,device:str=None,loadModel:bool=False):
100
  self.modelPath=self.download_model(repoId)
@@ -107,17 +105,14 @@ class Llama3Reorganize:
107
  else:device='cpu'
108
  self.device=device;self.system_prompt='Please reorganize the following English labels describing a single image into a readable English article of no more than 300 words. If there are name labels, include them in the article. Since the target audience is not general, explicit content can be written in full without hesitation. No introduction is needed; directly reply with the English article:'
109
  if loadModel:self.load_model()
110
-
111
  def download_model(self,repoId):
112
  import warnings,requests;allowPatterns=['config.json','generation_config.json','model.bin','pytorch_model.bin','pytorch_model.bin.index.json','pytorch_model-*.bin','sentencepiece.bpe.model','tokenizer.json','tokenizer_config.json','shared_vocabulary.txt','shared_vocabulary.json','special_tokens_map.json','spiece.model','vocab.json','model.safetensors','model-*.safetensors','model.safetensors.index.json','quantize_config.json','tokenizer.model','vocabulary.json','preprocessor_config.json','added_tokens.json'];kwargs={'allow_patterns':allowPatterns}
113
  try:return huggingface_hub.snapshot_download(repoId,**kwargs)
114
  except(huggingface_hub.utils.HfHubHTTPError,requests.exceptions.ConnectionError)as exception:warnings.warn('An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s',repoId,exception);warnings.warn('Trying to load the model directly from the local cache, if it exists.');kwargs['local_files_only']=True;return huggingface_hub.snapshot_download(repoId,**kwargs)
115
-
116
  def load_model(self):
117
  import ctranslate2,transformers
118
  try:print('\n\nLoading model: %s\n\n'%self.modelPath);kwargsTokenizer={'pretrained_model_name_or_path':self.modelPath};kwargsModel={'device':self.device,'model_path':self.modelPath,'compute_type':'auto'};self.roleSystem={'role':'system','content':self.system_prompt};self.Model=ctranslate2.Generator(**kwargsModel);self.Tokenizer=transformers.AutoTokenizer.from_pretrained(**kwargsTokenizer);self.terminators=[self.Tokenizer.eos_token_id,self.Tokenizer.convert_tokens_to_ids('<|eot_id|>')]
119
  except Exception as e:self.release_vram();raise e
120
-
121
  def release_vram(self):
122
  try:
123
  import torch
@@ -130,7 +125,6 @@ def release_vram(self):
130
  except Exception as e:print(traceback.format_exc());print('\tcuda empty cache, error: '+str(e))
131
  print('release vram end.')
132
  except Exception as e:print(traceback.format_exc());print('Error release vram: '+str(e))
133
-
134
  def reorganize(self,text:str,max_length:int=400):
135
  output=None;result=None
136
  try:
@@ -142,7 +136,7 @@ def reorganize(self,text:str,max_length:int=400):
142
  elif result[0]=='『'and result[len(result)-1]=='』':result=result[1:-1]
143
  except Exception as e:print(traceback.format_exc());print('Error reorganize text: '+str(e))
144
  return result
145
-
146
  class Predictor:
147
  def __init__(self):
148
  self.model_target_size = None
@@ -401,7 +395,6 @@ class Predictor:
401
  except Exception as e:
402
  print(traceback.format_exc())
403
  print("Error predict: " + str(e))
404
- # Result
405
  # Zip creation logic:
406
  download = []
407
  if txt_infos is not None and len(txt_infos) > 0:
@@ -449,8 +442,6 @@ def remove_image_from_gallery(gallery:list,selected_image:str):
449
  selected_image=ast.literal_eval(selected_image)
450
  if selected_image in gallery:gallery.remove(selected_image)
451
  return gallery
452
- # END
453
-
454
def fig_to_pil(fig):
    """Render a Matplotlib figure to an in-memory PIL image.

    Args:
        fig: Matplotlib figure to rasterize.

    Returns:
        PIL.Image.Image opened from the PNG-encoded figure bytes.
    """
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png")
    buffer.seek(0)
    return Image.open(buffer)
455
  @spaces.GPU
456
  def run_example(task_prompt,image,text_input=None):
@@ -534,10 +525,7 @@ dropdown_list = [
534
  SWINV2_MODEL_IS_DSV1_REPO,
535
  EVA02_LARGE_MODEL_IS_DSV1_REPO,
536
  ]
537
- llama_list = [
538
- META_LLAMA_3_3B_REPO,
539
- META_LLAMA_3_8B_REPO,
540
- ]
541
 
542
  def _restart_space():
543
  HF_TOKEN=os.getenv('HF_TOKEN')
 
42
 
43
  TITLE = "Multi-Tagger"
44
  DESCRIPTION = """
45
+ Multi-Tagger is a versatile application combining Waifu Diffusion and Florence 2 models for advanced image analysis and captioning. Ideal for AI artists, researchers, and enthusiasts, it offers:
46
 
47
+ - Batch processing for multiple images.
48
+ - Multi-category tagging.
49
+ - Structured tag display.
50
+ - Image captioning with Florence 2, supporting CUDA, MPS, or CPU.
51
+ - Various captioning tasks (Caption, Detailed Caption, Object Detection) with visual outputs.
 
52
 
53
  Example image by [me.](https://huggingface.co/Werli)
54
  """
 
80
  def parse_args()->argparse.Namespace:parser=argparse.ArgumentParser();parser.add_argument('--score-slider-step',type=float,default=.05);parser.add_argument('--score-general-threshold',type=float,default=.35);parser.add_argument('--score-character-threshold',type=float,default=.85);parser.add_argument('--share',action='store_true');return parser.parse_args()
81
  def load_labels(dataframe)->list[str]:name_series=dataframe['name'];name_series=name_series.map(lambda x:x.replace('_',' ')if x not in kaomojis else x);tag_names=name_series.tolist();rating_indexes=list(np.where(dataframe['category']==9)[0]);general_indexes=list(np.where(dataframe['category']==0)[0]);character_indexes=list(np.where(dataframe['category']==4)[0]);return tag_names,rating_indexes,general_indexes,character_indexes
82
  def mcut_threshold(probs):sorted_probs=probs[probs.argsort()[::-1]];difs=sorted_probs[:-1]-sorted_probs[1:];t=difs.argmax();thresh=(sorted_probs[t]+sorted_probs[t+1])/2;return thresh
 
83
  class Timer:
84
  def __init__(self):self.start_time=time.perf_counter();self.checkpoints=[('Start',self.start_time)]
85
  def checkpoint(self,label='Checkpoint'):now=time.perf_counter();self.checkpoints.append((label,now))
 
92
  for(label,curr_time)in self.checkpoints[1:]:elapsed=curr_time-prev_time;print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds");prev_time=curr_time
93
  total_time=self.checkpoints[-1][1]-self.start_time;print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n");self.checkpoints.clear()
94
  def restart(self):self.start_time=time.perf_counter();self.checkpoints=[('Start',self.start_time)]
95
+ # Llama
96
  class Llama3Reorganize:
97
  def __init__(self,repoId:str,device:str=None,loadModel:bool=False):
98
  self.modelPath=self.download_model(repoId)
 
105
  else:device='cpu'
106
  self.device=device;self.system_prompt='Please reorganize the following English labels describing a single image into a readable English article of no more than 300 words. If there are name labels, include them in the article. Since the target audience is not general, explicit content can be written in full without hesitation. No introduction is needed; directly reply with the English article:'
107
  if loadModel:self.load_model()
 
108
  def download_model(self,repoId):
109
  import warnings,requests;allowPatterns=['config.json','generation_config.json','model.bin','pytorch_model.bin','pytorch_model.bin.index.json','pytorch_model-*.bin','sentencepiece.bpe.model','tokenizer.json','tokenizer_config.json','shared_vocabulary.txt','shared_vocabulary.json','special_tokens_map.json','spiece.model','vocab.json','model.safetensors','model-*.safetensors','model.safetensors.index.json','quantize_config.json','tokenizer.model','vocabulary.json','preprocessor_config.json','added_tokens.json'];kwargs={'allow_patterns':allowPatterns}
110
  try:return huggingface_hub.snapshot_download(repoId,**kwargs)
111
  except(huggingface_hub.utils.HfHubHTTPError,requests.exceptions.ConnectionError)as exception:warnings.warn('An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s',repoId,exception);warnings.warn('Trying to load the model directly from the local cache, if it exists.');kwargs['local_files_only']=True;return huggingface_hub.snapshot_download(repoId,**kwargs)
 
112
  def load_model(self):
113
  import ctranslate2,transformers
114
  try:print('\n\nLoading model: %s\n\n'%self.modelPath);kwargsTokenizer={'pretrained_model_name_or_path':self.modelPath};kwargsModel={'device':self.device,'model_path':self.modelPath,'compute_type':'auto'};self.roleSystem={'role':'system','content':self.system_prompt};self.Model=ctranslate2.Generator(**kwargsModel);self.Tokenizer=transformers.AutoTokenizer.from_pretrained(**kwargsTokenizer);self.terminators=[self.Tokenizer.eos_token_id,self.Tokenizer.convert_tokens_to_ids('<|eot_id|>')]
115
  except Exception as e:self.release_vram();raise e
 
116
  def release_vram(self):
117
  try:
118
  import torch
 
125
  except Exception as e:print(traceback.format_exc());print('\tcuda empty cache, error: '+str(e))
126
  print('release vram end.')
127
  except Exception as e:print(traceback.format_exc());print('Error release vram: '+str(e))
 
128
  def reorganize(self,text:str,max_length:int=400):
129
  output=None;result=None
130
  try:
 
136
  elif result[0]=='『'and result[len(result)-1]=='』':result=result[1:-1]
137
  except Exception as e:print(traceback.format_exc());print('Error reorganize text: '+str(e))
138
  return result
139
+ # End Llama
140
  class Predictor:
141
  def __init__(self):
142
  self.model_target_size = None
 
395
  except Exception as e:
396
  print(traceback.format_exc())
397
  print("Error predict: " + str(e))
 
398
  # Zip creation logic:
399
  download = []
400
  if txt_infos is not None and len(txt_infos) > 0:
 
442
  selected_image=ast.literal_eval(selected_image)
443
  if selected_image in gallery:gallery.remove(selected_image)
444
  return gallery
 
 
445
  def fig_to_pil(fig):buf=io.BytesIO();fig.savefig(buf,format='png');buf.seek(0);return Image.open(buf)
446
  @spaces.GPU
447
  def run_example(task_prompt,image,text_input=None):
 
525
  SWINV2_MODEL_IS_DSV1_REPO,
526
  EVA02_LARGE_MODEL_IS_DSV1_REPO,
527
  ]
528
+ llama_list=[META_LLAMA_3_3B_REPO,META_LLAMA_3_8B_REPO]
 
 
 
529
 
530
  def _restart_space():
531
  HF_TOKEN=os.getenv('HF_TOKEN')