# Hugging Face Spaces page residue ("Spaces: Paused") — scrape artifact, not program code.
import io
import json
import os
import tempfile
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, File, HTTPException, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import InferenceClient
from langchain.chains import LLMChain
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate
from paddleocr import PaddleOCR
from passporteye import read_mrz
from pydantic import BaseModel
from pydantic.v1 import BaseModel as v1BaseModel
from pydantic.v1 import Field
| HF_token = os.getenv("apiToken") | |
| model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1" | |
| hf_token = HF_token | |
| kwargs = {"max_new_tokens":500, "temperature":0.1, "top_p":0.95, "repetition_penalty":1.0, "do_sample":True} | |
| class KwArgsModel(v1BaseModel): | |
| kwargs: Dict[str, Any] = Field(default_factory=dict) | |
| class CustomInferenceClient(LLM, KwArgsModel): | |
| model_name: str | |
| inference_client: InferenceClient | |
| def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None): | |
| inference_client = InferenceClient(model=model_name, token=hf_token) | |
| super().__init__( | |
| model_name=model_name, | |
| hf_token=hf_token, | |
| kwargs=kwargs, | |
| inference_client=inference_client | |
| ) | |
| def _call( | |
| self, | |
| prompt: str, | |
| stop: Optional[List[str]] = None | |
| ) -> str: | |
| if stop is not None: | |
| raise ValueError("stop kwargs are not permitted.") | |
| response_gen = self.inference_client.text_generation(prompt, **self.kwargs, stream=True, return_full_text=False) | |
| response = ''.join(response_gen) | |
| return response | |
| def _llm_type(self) -> str: | |
| return "custom" | |
| def _identifying_params(self) -> dict: | |
| return {"model_name": self.model_name} | |
| app = FastAPI(title="Passport Recognition API") | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| ocr = PaddleOCR(use_angle_cls=True, lang='en') | |
| template = """below is poorly read ocr result of a passport. | |
| OCR Result: | |
| {ocr_result} | |
| Fill the below catergories using the OCR Results. you can correct spellings and make other adujustments. Dates should be in 01-JAN-2000 format. | |
| "countryName": "", | |
| "dateOfBirth": "", | |
| "dateOfExpiry": "", | |
| "dateOfIssue": "", | |
| "documentNumber": "", | |
| "givenNames": "", | |
| "name": "", | |
| "surname": "", | |
| "mrz": "" | |
| json output: | |
| """ | |
| prompt = PromptTemplate(template=template, input_variables=["ocr_result"]) | |
| class MRZData(BaseModel): | |
| date_of_birth: str | |
| expiration_date: str | |
| type: str | |
| number: str | |
| names: str | |
| country: str | |
| check_number: str | |
| check_date_of_birth: str | |
| check_expiration_date: str | |
| check_composite: str | |
| check_personal_number: str | |
| valid_number: bool | |
| valid_date_of_birth: bool | |
| valid_expiration_date: bool | |
| valid_composite: bool | |
| valid_personal_number: bool | |
| method: str | |
| class OCRData(BaseModel): | |
| countryName: str | |
| dateOfBirth: str | |
| dateOfExpiry: str | |
| dateOfIssue: str | |
| documentNumber: str | |
| givenNames: str | |
| name: str | |
| surname: str | |
| mrz: str | |
| class ResponseData(BaseModel): | |
| documentName: str | |
| errorCode: int | |
| mrz: Optional[MRZData] | |
| ocr: Optional[OCRData] | |
| status: str | |
| def create_response_data(mrz, ocr_data): | |
| if not mrz and not ocr_data: | |
| return ResponseData( | |
| documentName="Passport", | |
| errorCode=1, | |
| mrz=None, | |
| ocr=None, | |
| status="No MRZ or OCR data available" | |
| ) | |
| elif not mrz: | |
| return ResponseData( | |
| documentName="Passport", | |
| errorCode=2, | |
| mrz=None, | |
| ocr=OCRData(**ocr_data), | |
| status="PassportEYE did not find an MRZ" | |
| ) | |
| elif not ocr_data: | |
| return ResponseData( | |
| documentName="Passport", | |
| errorCode=3, | |
| mrz=MRZData(**mrz), | |
| ocr=None, | |
| status="OCR result not available" | |
| ) | |
| else: | |
| return ResponseData( | |
| documentName="Passport", | |
| errorCode=0, | |
| mrz=MRZData(**mrz), | |
| ocr=OCRData(**ocr_data), | |
| status="ok" | |
| ) | |
| async def recognize_passport(image: UploadFile = File(...)): | |
| """Passport information extraction from a provided image file.""" | |
| try: | |
| image_bytes = await image.read() | |
| mrz = read_mrz(image_bytes) | |
| img_path = 'image.jpg' | |
| with open(img_path, 'wb') as f: | |
| f.write(image_bytes) | |
| result = ocr.ocr(img_path, cls=True) | |
| json_result = [] | |
| for idx in range(len(result)): | |
| res = result[idx] | |
| for line in res: | |
| coordinates, text_with_confidence = line | |
| text, confidence = text_with_confidence | |
| json_result.append({ | |
| 'coordinates': coordinates, | |
| 'text': text, | |
| 'confidence': confidence | |
| }) | |
| llm = CustomInferenceClient(model_name=model_name, hf_token=hf_token, kwargs=kwargs) | |
| llm_chain = LLMChain(prompt=prompt, llm=llm) | |
| response_str = llm_chain.run(ocr_result=json_result) | |
| response_str = response_str.rstrip("</s>") | |
| #print(response_str) | |
| ocr_data = json.loads(response_str) | |
| return create_response_data(mrz.to_dict() if mrz else None, ocr_data) | |
| except HTTPException as e: | |
| raise e | |
| except Exception as e: | |
| raise HTTPException( | |
| status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | |
| detail=f"Internal server error: {str(e)}" | |
| ) from e |