Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Depends | |
| from typing import Optional | |
| from fastapi.responses import StreamingResponse | |
| from huggingface_hub import InferenceClient | |
| from pydantic import BaseModel, ConfigDict | |
| import os | |
| from base64 import b64encode | |
| from io import BytesIO | |
| from PIL import Image, ImageEnhance | |
| import logging | |
| import pytesseract | |
| import time | |
# FastAPI application instance; endpoint handlers below are attached to it.
app = FastAPI()

# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Default model used when a request does not supply `model_name`
DEFAULT_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
class TextRequest(BaseModel):
    """Request body for the text-to-text (/t2t) endpoint."""

    # Pydantic v2 reserves the "model_" prefix; disable that protection so
    # the `model_name` field is accepted without warnings.
    model_config = ConfigDict(protected_namespaces=())
    query: str  # the user's prompt
    stream: bool = False  # stream tokens as SSE when True
    model_name: Optional[str] = None  # overrides DEFAULT_MODEL when set
class ImageTextRequest(BaseModel):
    """Multipart-form request for the image+text (/i2t2t) endpoint."""

    # Pydantic v2 reserves the "model_" prefix; disable that protection so
    # the `model_name` field is accepted without warnings.
    model_config = ConfigDict(protected_namespaces=())
    query: str  # the user's prompt about the image
    stream: bool = False  # stream tokens as SSE when True
    model_name: Optional[str] = None  # overrides DEFAULT_MODEL when set

    # Fix: `as_form` must be a classmethod. Without the decorator,
    # `Depends(ImageTextRequest.as_form)` hands FastAPI a plain function and
    # `cls` is misinterpreted as a request parameter.
    @classmethod
    def as_form(
        cls,
        query: str = Form(...),
        stream: bool = Form(False),
        model_name: Optional[str] = Form(None),
        image: UploadFile = File(...),  # Make image required for i2t2t
    ):
        """Parse multipart form fields; return (ImageTextRequest, UploadFile)."""
        return cls(
            query=query,
            stream=stream,
            model_name=model_name
        ), image
def get_client(model_name: Optional[str] = None):
    """Get inference client for specified model or default model.

    Falls back to DEFAULT_MODEL when `model_name` is None or blank.
    Raises HTTPException(400) when client construction fails.
    """
    # Fix: resolve the model BEFORE the try block. Previously it was the
    # first statement inside the try, so if it raised, the except handler's
    # f-string hit an unbound `model_path` (NameError masking the real error).
    model_path = model_name if model_name and model_name.strip() else DEFAULT_MODEL
    try:
        return InferenceClient(
            model=model_path
        )
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Error initializing model {model_path}: {str(e)}"
        )
def generate_text_response(query: str, model_name: Optional[str] = None):
    """Yield answer tokens for `query` from a streaming chat completion.

    On any error, yields a single error-message string instead of raising so
    a StreamingResponse consumer sees a graceful end of stream.
    """
    messages = [{
        "role": "user",
        "content": f"[SYSTEM] You are ASSISTANT who answer question asked by user in short and concise manner. [USER] {query}"
    }]
    try:
        client = get_client(model_name)
        for message in client.chat_completion(
            messages,
            max_tokens=2048,
            stream=True
        ):
            token = message.choices[0].delta.content
            # Fix: the final stream chunk can carry delta.content == None;
            # yielding it would crash `response += chunk` in the
            # non-streaming callers. Skip empty deltas.
            if token is not None:
                yield token
    except Exception as e:
        yield f"Error generating response: {str(e)}"
def generate_image_text_response(query: str, image_data: str, model_name: Optional[str] = None):
    """Yield answer tokens for `query` about a base64-encoded image.

    `image_data` is the base64 string of the (PNG-encoded) image; it is sent
    as a data URL. On any error, yields a single error-message string.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": f"[SYSTEM] You are ASSISTANT who answer question asked by user in short and concise manner. [USER] {query}"},
                {"type": "image_url", "image_url": {"url": f"data:image/*;base64,{image_data}"}}
            ]
        }
    ]
    logger.debug(f"Messages sent to API: {messages}")
    try:
        client = get_client(model_name)
        for message in client.chat_completion(messages, max_tokens=2048, stream=True):
            logger.debug(f"Received message chunk: {message}")
            token = message.choices[0].delta.content
            # Fix: the final stream chunk can carry delta.content == None;
            # yielding it would crash `response += chunk` in the
            # non-streaming callers. Skip empty deltas.
            if token is not None:
                yield token
    except Exception as e:
        logger.error(f"Error in generate_image_text_response: {str(e)}")
        yield f"Error generating response: {str(e)}"
def preprocess_image(img):
    """Enhance image for better OCR results.

    Pipeline: grayscale -> 2.0x contrast -> 1.5x sharpness.
    Returns a new PIL image; the input is not modified in place.
    """
    grayscale = img.convert('L')
    contrast_boosted = ImageEnhance.Contrast(grayscale).enhance(2.0)
    sharpened = ImageEnhance.Sharpness(contrast_boosted).enhance(1.5)
    return sharpened
async def root():
    """Return a static welcome message for the API root."""
    # NOTE(review): a route decorator (e.g. @app.get("/")) appears to have
    # been lost from this and the other handlers in this paste — confirm
    # against the original source.
    return {"message": "Welcome to FastAPI server!"}
async def text_to_text(request: TextRequest):
    """Answer `request.query`; stream SSE tokens or return one JSON body."""
    try:
        token_stream = generate_text_response(request.query, request.model_name)
        if request.stream:
            return StreamingResponse(
                token_stream,
                media_type="text/event-stream"
            )
        # Non-streaming: drain the generator into a single string.
        return {"response": "".join(token_stream)}
    except Exception as e:
        logger.error(f"Error in /t2t endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def image_text_to_text(form_data: tuple[ImageTextRequest, UploadFile] = Depends(ImageTextRequest.as_form)):
    """Answer a query about an uploaded image.

    The image is normalized to RGB, re-encoded as PNG, and sent base64-encoded
    to the model. Returns a streaming SSE response or a single JSON body.

    Raises:
        HTTPException(422) if the upload cannot be decoded as an image.
        HTTPException(500) for any other failure.
    """
    form, image = form_data
    try:
        # Process image
        contents = await image.read()
        try:
            logger.debug("Attempting to open image")
            img = Image.open(BytesIO(contents))
            if img.mode != 'RGB':
                img = img.convert('RGB')
            buffer = BytesIO()
            img.save(buffer, format="PNG")
            image_data = b64encode(buffer.getvalue()).decode('utf-8')
            logger.debug("Image processed and encoded to base64")
        except Exception as img_error:
            logger.error(f"Error processing image: {str(img_error)}")
            raise HTTPException(
                status_code=422,
                detail=f"Error processing image: {str(img_error)}"
            )
        if form.stream:
            return StreamingResponse(
                generate_image_text_response(form.query, image_data, form.model_name),
                media_type="text/event-stream"
            )
        else:
            response = ""
            for chunk in generate_image_text_response(form.query, image_data, form.model_name):
                response += chunk
            return {"response": response}
    except HTTPException:
        # Fix: previously the broad handler below caught the 422 raised above
        # and rewrapped it as a generic 500, losing the status code.
        raise
    except Exception as e:
        logger.error(f"Error in /i2t2t endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def ocr_endpoint(image: UploadFile = File(...)):
    """Extract text from an uploaded image via Tesseract OCR.

    The image is contrast/sharpness-enhanced first (preprocess_image), then
    OCR runs with a 30s timeout and up to 3 attempts, waiting 1s between
    retries.

    Raises:
        HTTPException(500) if OCR fails on every attempt or the image cannot
        be processed.
    """
    try:
        # Read and process the image
        contents = await image.read()
        img = Image.open(BytesIO(contents))
        # Preprocess the image
        img = preprocess_image(img)
        # Perform OCR with timeout and retries
        max_retries = 3
        text = ""
        for attempt in range(max_retries):
            try:
                text = pytesseract.image_to_string(
                    img,
                    timeout=30,  # 30 second timeout
                    config='--oem 3 --psm 6'
                )
                break
            except Exception as e:
                if attempt == max_retries - 1:
                    raise HTTPException(
                        status_code=500,
                        detail=f"Error extracting text: {str(e)}"
                    )
                time.sleep(1)  # Wait before retry
        return {"text": text}
    except HTTPException:
        # Fix: previously the broad handler below caught the deliberate
        # "Error extracting text" HTTPException and rewrapped it with a
        # different, misleading detail message.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error processing image: {str(e)}"
        )