Spaces:
Configuration error
Configuration error
try fix gen
Browse files
infer.py
CHANGED
|
@@ -203,13 +203,14 @@ class TikzGenerator:
|
|
| 203 |
top_p=top_p,
|
| 204 |
top_k=top_k,
|
| 205 |
num_return_sequences=1,
|
| 206 |
-
max_length=self.pipeline.tokenizer.model_max_length, # type: ignore
|
| 207 |
do_sample=True,
|
| 208 |
return_full_text=False,
|
| 209 |
streamer=TextStreamer(self.pipeline.tokenizer, # type: ignore
|
| 210 |
skip_prompt=True,
|
| 211 |
skip_special_tokens=True
|
| 212 |
),
|
|
|
|
| 213 |
)
|
| 214 |
|
| 215 |
if not stream:
|
|
@@ -218,8 +219,11 @@ class TikzGenerator:
|
|
| 218 |
def generate(self, image: Image.Image, **generate_kwargs):
|
| 219 |
prompt = "Assistant helps to write down the TikZ code for the user's image. USER: <image>\nWrite down the TikZ code to draw the diagram shown in the lol. ASSISTANT:"
|
| 220 |
tokenizer = self.pipeline.tokenizer
|
|
|
|
| 221 |
text = self.pipeline(image, prompt=prompt, generate_kwargs=(self.default_kwargs | generate_kwargs))[0]["generated_text"] # type: ignore
|
| 222 |
|
|
|
|
|
|
|
| 223 |
if self.clean_up_output:
|
| 224 |
for token in reversed(tokenizer.tokenize(prompt)): # type: ignore
|
| 225 |
# remove leading characters because skip_special_tokens in pipeline
|
|
@@ -236,7 +240,9 @@ class TikzGenerator:
|
|
| 236 |
for artifact, replacement in artifacts.items():
|
| 237 |
text = sub(artifact, replacement, text) # type: ignore
|
| 238 |
|
| 239 |
-
|
|
|
|
|
|
|
| 240 |
|
| 241 |
|
| 242 |
def __call__(self, *args, **kwargs):
|
|
|
|
| 203 |
top_p=top_p,
|
| 204 |
top_k=top_k,
|
| 205 |
num_return_sequences=1,
|
| 206 |
+
# max_length=self.pipeline.tokenizer.model_max_length, # type: ignore
|
| 207 |
do_sample=True,
|
| 208 |
return_full_text=False,
|
| 209 |
streamer=TextStreamer(self.pipeline.tokenizer, # type: ignore
|
| 210 |
skip_prompt=True,
|
| 211 |
skip_special_tokens=True
|
| 212 |
),
|
| 213 |
+
max_new_tokens=1024,
|
| 214 |
)
|
| 215 |
|
| 216 |
if not stream:
|
|
|
|
| 219 |
def generate(self, image: Image.Image, **generate_kwargs):
|
| 220 |
prompt = "Assistant helps to write down the TikZ code for the user's image. USER: <image>\nWrite down the TikZ code to draw the diagram shown in the lol. ASSISTANT:"
|
| 221 |
tokenizer = self.pipeline.tokenizer
|
| 222 |
+
print('starting generation')
|
| 223 |
text = self.pipeline(image, prompt=prompt, generate_kwargs=(self.default_kwargs | generate_kwargs))[0]["generated_text"] # type: ignore
|
| 224 |
|
| 225 |
+
print('text generated: ', text) # TODO: remove
|
| 226 |
+
|
| 227 |
if self.clean_up_output:
|
| 228 |
for token in reversed(tokenizer.tokenize(prompt)): # type: ignore
|
| 229 |
# remove leading characters because skip_special_tokens in pipeline
|
|
|
|
| 240 |
for artifact, replacement in artifacts.items():
|
| 241 |
text = sub(artifact, replacement, text) # type: ignore
|
| 242 |
|
| 243 |
+
print('cleaned text: ', text)
|
| 244 |
+
|
| 245 |
+
return TikzDocument(text)
|
| 246 |
|
| 247 |
|
| 248 |
def __call__(self, *args, **kwargs):
|