Werli committed on
Commit
2d36110
·
verified ·
1 Parent(s): 57baee5

Delete modules/reorganizer_model.py

Browse files
Files changed (1) hide show
  1. modules/reorganizer_model.py +0 -101
modules/reorganizer_model.py DELETED
@@ -1,101 +0,0 @@
1
- import os
2
- import io,copy,requests,spaces,gradio as gr,numpy as np
3
- from transformers import T5ForConditionalGeneration, T5Tokenizer
4
-
5
- # Experimental #
6
-
7
# Hugging Face Hub repository id handed to the reorganizer as its `repoId`.
# NOTE(review): the "Enchance" spelling presumably mirrors the upstream repo
# name -- verify against the Hub before "correcting" it.
LAMINI_PROMPT_LONG = "gokaygokay/Lamini-Prompt-Enchance-Long"
8
-
9
class reorganizer_class:
    """Wrap a T5 checkpoint that turns image labels into a readable article.

    The checkpoint is fetched from the Hugging Face Hub when the instance is
    created; the tokenizer and model are only instantiated by `load_model`
    (or immediately, when `loadModel=True`).

    Attributes:
        modelPath: Local snapshot directory returned by `download_model`.
        totalVram: Total memory of the current CUDA device in GiB, or 0 when
            it was not probed (explicit `device`) or could not be read.
        device: Device string the model is moved to ("cuda", "cpu", ...).
        system_prompt: Instruction text prepended to every input.
        Tokenizer / Model: Set by `load_model`; absent until then.
    """

    def __init__(self, repoId: str, device: str = None, loadModel: bool = False):
        """Download the checkpoint and decide which device to run on.

        Args:
            repoId: Hugging Face Hub repository id of the model.
            device: Device to use. When None, "cuda" is chosen only if a CUDA
                device with enough total memory is present, else "cpu".
            loadModel: When True, load tokenizer and model right away.
        """
        self.modelPath = self.download_model(repoId)
        # Defined unconditionally so the attribute exists even when the
        # caller supplies `device` and no probing takes place.
        self.totalVram = 0
        if device is None:
            device = self._pick_device(repoId)
        self.device = device
        self.system_prompt = "Reorganize and enhance the following English labels describing a single image into a readable English article:\n\n"
        if loadModel:
            self.load_model()

    def _pick_device(self, repoId: str) -> str:
        """Return "cuda" when the current GPU has enough memory, else "cpu"."""
        import torch

        if not torch.cuda.is_available():
            return "cpu"
        try:
            deviceId = torch.cuda.current_device()
            self.totalVram = torch.cuda.get_device_properties(deviceId).total_memory / (1024 * 1024 * 1024)
        except Exception as e:
            # Best effort: a failed probe leaves totalVram at 0, which
            # selects the CPU below.
            import traceback
            print(traceback.format_exc())
            print("Error detect vram: " + str(e))
        # Repositories with "8B" in their id need more than 8 GiB,
        # everything else more than 4 GiB.
        requiredVram = 8 if "8B" in repoId else 4
        return "cuda" if self.totalVram > requiredVram else "cpu"

    def download_model(self, repoId):
        """Fetch the model snapshot and return its local directory.

        If the Hub cannot be reached, a warning is issued and the snapshot is
        resolved from the local cache instead; that second attempt raises
        whatever `snapshot_download` raises when nothing is cached.
        """
        import huggingface_hub

        # Only safetensors weights plus config/tokenizer files are fetched;
        # the other weight formats (tf_model.h5, model.ckpt.index,
        # flax_model.msgpack, pytorch_model.bin) are deliberately left out.
        allowPatterns = [
            "config.json",
            "generation_config.json",
            "model.safetensors",
            "tokenizer.json",
            "tokenizer_config.json",
            "special_tokens_map.json",
            "vocab.json",
            "added_tokens.json",
            "spiece.model",
        ]
        kwargs = {"allow_patterns": allowPatterns}
        try:
            return huggingface_hub.snapshot_download(repoId, **kwargs)
        except (huggingface_hub.utils.HfHubHTTPError, requests.exceptions.ConnectionError) as exception:
            import warnings

            # warnings.warn() is not logging-style: its second positional
            # argument is the warning category, so the message has to be
            # formatted before the call.
            warnings.warn(
                "An error occurred while synchronizing the model %s from the Hugging Face Hub:\n%s"
                % (repoId, exception)
            )
            warnings.warn(
                "Trying to load the model directly from the local cache, if it exists."
            )
            kwargs["local_files_only"] = True
            return huggingface_hub.snapshot_download(repoId, **kwargs)

    def load_model(self):
        """Instantiate tokenizer and model from `modelPath` on `device`.

        On any failure the partially loaded state is released through
        `release_vram` and the original exception is re-raised.
        """
        try:
            print('\n\nLoading model: %s\n\n' % self.modelPath)
            self.Tokenizer = T5Tokenizer.from_pretrained(self.modelPath)
            self.Model = T5ForConditionalGeneration.from_pretrained(self.modelPath).to(self.device)
        except Exception:
            self.release_vram()
            raise  # bare raise keeps the original traceback intact

    def release_vram(self):
        """Drop model and tokenizer and empty the CUDA cache.

        Does nothing beyond printing when CUDA is unavailable. Errors are
        printed, never raised.
        """
        try:
            import torch
            if torch.cuda.is_available():
                if getattr(self, "Model", None) is not None:
                    # Move the weights off the GPU before dropping the reference.
                    self.Model.to('cpu')
                    del self.Model
                if getattr(self, "Tokenizer", None) is not None:
                    del self.Tokenizer
                import gc
                gc.collect()
                torch.cuda.empty_cache()
            print("release vram end.")
        except Exception as e:
            import traceback
            print(traceback.format_exc())
            print("Error release vram: " + str(e))

    def reorganize(self, text: str, max_length: int = 400):
        """Generate the rewritten text for `text`.

        Args:
            text: Labels to rewrite; `system_prompt` is prepended to it.
            max_length: `max_length` passed to `Model.generate`.

        Returns:
            The decoded first generated sequence, or None when anything
            fails (including the model not having been loaded); the error
            is printed in that case.
        """
        try:
            input_ids = self.Tokenizer(self.system_prompt + text, return_tensors="pt").input_ids.to(self.device)
            output = self.Model.generate(input_ids, max_length=max_length, no_repeat_ngram_size=3, num_beams=2, early_stopping=True)
            return self.Tokenizer.decode(output[0], skip_special_tokens=True)
        except Exception as e:
            import traceback
            print(traceback.format_exc())
            print("Error reorganize text: " + str(e))
            return None
100
-
101
# Repository ids that can be passed to reorganizer_class as `repoId`.
reorganizer_list = [
    LAMINI_PROMPT_LONG,
]