Spaces:
Build error
Build error
Ilia Tambovtsev commited on
Commit ·
c1a7f36
1
Parent(s): c233ff9
style: fix typing, reformat
Browse filesstyle: black+isort
style: fix typing
- src/chains/pipelines.py +6 -9
- src/chains/prompts.py +7 -10
- src/processing/image_utlis.py +2 -2
src/chains/pipelines.py
CHANGED
|
@@ -1,14 +1,14 @@
|
|
| 1 |
-
from typing import List, Dict, Any, Optional, Tuple, Union
|
| 2 |
-
from pydantic import BaseModel, ConfigDict, Field
|
| 3 |
-
from pathlib import Path
|
| 4 |
import json
|
| 5 |
import logging
|
| 6 |
-
from tqdm import tqdm
|
| 7 |
from datetime import datetime
|
| 8 |
-
import
|
|
|
|
| 9 |
|
| 10 |
-
|
| 11 |
from langchain.chains.base import Chain
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
from src.chains.chains import (
|
| 14 |
LoadPageChain,
|
|
@@ -16,11 +16,8 @@ from src.chains.chains import (
|
|
| 16 |
ImageEncodeChain,
|
| 17 |
VisionAnalysisChain
|
| 18 |
)
|
| 19 |
-
|
| 20 |
from src.chains.prompts import BasePrompt, JsonH1AndGDPrompt
|
| 21 |
from src.config.navigator import Navigator
|
| 22 |
-
from src.chains.prompts import BasePrompt
|
| 23 |
-
|
| 24 |
|
| 25 |
logger = logging.getLogger(__name__)
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import json
|
| 2 |
import logging
|
|
|
|
| 3 |
from datetime import datetime
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 6 |
|
| 7 |
+
import fitz
|
| 8 |
from langchain.chains.base import Chain
|
| 9 |
+
from langchain_openai.chat_models import ChatOpenAI
|
| 10 |
+
from pydantic import BaseModel, ConfigDict, Field
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
|
| 13 |
from src.chains.chains import (
|
| 14 |
LoadPageChain,
|
|
|
|
| 16 |
ImageEncodeChain,
|
| 17 |
VisionAnalysisChain
|
| 18 |
)
|
|
|
|
| 19 |
from src.chains.prompts import BasePrompt, JsonH1AndGDPrompt
|
| 20 |
from src.config.navigator import Navigator
|
|
|
|
|
|
|
| 21 |
|
| 22 |
logger = logging.getLogger(__name__)
|
| 23 |
|
src/chains/prompts.py
CHANGED
|
@@ -1,18 +1,15 @@
|
|
| 1 |
-
from typing import Dict, Any, Optional, Type, Union, TypeVar
|
| 2 |
-
from abc import ABC, abstractmethod
|
| 3 |
import logging
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
from pydantic import BaseModel
|
| 6 |
-
from langchain.prompts import ChatPromptTemplate
|
| 7 |
from langchain.output_parsers import PydanticOutputParser
|
| 8 |
-
from
|
| 9 |
-
from langchain_core.output_parsers import StrOutputParser
|
| 10 |
-
|
| 11 |
-
from textwrap import dedent
|
| 12 |
|
| 13 |
logger = logging.getLogger(__name__)
|
| 14 |
|
| 15 |
-
T = TypeVar("T")
|
| 16 |
|
| 17 |
|
| 18 |
class BasePrompt(ABC):
|
|
@@ -62,7 +59,7 @@ class BasePrompt(ABC):
|
|
| 62 |
return self._template
|
| 63 |
|
| 64 |
@abstractmethod
|
| 65 |
-
def parse(self, text: str) ->
|
| 66 |
"""Parse LLM output
|
| 67 |
|
| 68 |
Args:
|
|
|
|
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
from textwrap import dedent
|
| 4 |
+
from typing import Any, Dict, Optional, Type, TypeVar, Union
|
| 5 |
|
|
|
|
|
|
|
| 6 |
from langchain.output_parsers import PydanticOutputParser
|
| 7 |
+
from langchain.prompts import ChatPromptTemplate
|
| 8 |
+
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
|
| 9 |
+
from pydantic import BaseModel
|
|
|
|
| 10 |
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
class BasePrompt(ABC):
|
|
|
|
| 59 |
return self._template
|
| 60 |
|
| 61 |
@abstractmethod
|
| 62 |
+
def parse(self, text: str) -> object:
|
| 63 |
"""Parse LLM output
|
| 64 |
|
| 65 |
Args:
|
src/processing/image_utlis.py
CHANGED
|
@@ -2,8 +2,8 @@ from PIL import Image
|
|
| 2 |
import base64
|
| 3 |
import io
|
| 4 |
|
| 5 |
-
def image2base64(pil_image: Image):
|
| 6 |
buffered = io.BytesIO()
|
| 7 |
pil_image.save(buffered, format="png")
|
| 8 |
img_str = base64.b64encode(buffered.getvalue())
|
| 9 |
-
return img_str.decode("utf-8")
|
|
|
|
| 2 |
import base64
|
| 3 |
import io
|
| 4 |
|
| 5 |
+
def image2base64(pil_image: Image.Image):
|
| 6 |
buffered = io.BytesIO()
|
| 7 |
pil_image.save(buffered, format="png")
|
| 8 |
img_str = base64.b64encode(buffered.getvalue())
|
| 9 |
+
return img_str.decode("utf-8")
|