Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Depends | |
| from typing import Optional | |
| from fastapi.responses import StreamingResponse, JSONResponse, HTMLResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| from huggingface_hub import InferenceClient | |
| from pydantic import BaseModel, ConfigDict | |
| import os | |
| from base64 import b64encode | |
| from io import BytesIO | |
| from PIL import Image, ImageEnhance | |
| import logging | |
| import pytesseract | |
| import time | |
| import subprocess | |
# Start the companion script as a background process when this module is
# imported. The process handle is intentionally not kept.
# NOTE(review): this runs on every import of this module -- confirm that
# "main.py" is a different file from this one.
subprocess.Popen(["python", "main.py"])

# Set Tesseract CMD path for Windows (disabled; enable and adjust locally).
# pytesseract.pytesseract.tesseract_cmd = r"F:\Python-files\tesseract\tesseract.exe"

# ASGI application object; route handlers in this module attach to it.
app = FastAPI()

# Root logger at DEBUG level plus a module-scoped logger for the handlers.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Model id used whenever a request does not name one.
DEFAULT_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"

# Jinja2 environment rooted at the local "templates" directory.
templates = Jinja2Templates(directory="templates")
class TextRequest(BaseModel):
    """JSON request body for a text-only query.

    Attributes are declared below:
      * ``query``      -- the user's question (required).
      * ``stream``     -- whether the caller wants a streamed reply; off by default.
      * ``model_name`` -- optional model id; ``None`` means "use the default model".
    """

    # An empty protected_namespaces tuple lets the ``model_name`` field exist
    # without clashing with pydantic's reserved "model_" prefix.
    model_config = ConfigDict(protected_namespaces=())

    query: str
    stream: bool = False
    model_name: Optional[str] = None
class ImageTextRequest(BaseModel):
    """Form fields accompanying an image query (multipart/form-data).

    ``query`` is required; ``stream`` defaults to ``False`` and
    ``model_name`` defaults to ``None`` (use the default model).
    """

    # An empty protected_namespaces tuple lets the ``model_name`` field exist
    # without clashing with pydantic's reserved "model_" prefix.
    model_config = ConfigDict(protected_namespaces=())

    query: str
    stream: bool = False
    model_name: Optional[str] = None

    # Must be a classmethod: it is handed to ``Depends(ImageTextRequest.as_form)``
    # and builds the model through ``cls``. Without the decorator ``cls`` would be
    # treated as an ordinary request parameter.
    @classmethod
    def as_form(
        cls,
        query: str = Form(...),
        stream: bool = Form(False),
        model_name: Optional[str] = Form(None),
        image: UploadFile = File(...),  # Make image required for i2t2t
    ):
        """Build the request model from multipart form fields.

        Returns:
            A ``(ImageTextRequest, UploadFile)`` tuple: the validated form
            fields and the uploaded image.
        """
        form = cls(
            query=query,
            stream=stream,
            model_name=model_name,
        )
        return form, image
def get_client(model_name: Optional[str] = None):
    """Get inference client for specified model or default model.

    Args:
        model_name: Model id to use. ``None``, empty, or whitespace-only
            values fall back to ``DEFAULT_MODEL``.

    Returns:
        An ``InferenceClient`` bound to the resolved model.

    Raises:
        HTTPException: status 400 if the client cannot be constructed.
    """
    # Resolve the model outside the try block so the error message below can
    # always reference it (previously it could be unbound in the handler).
    model_path = model_name if model_name and model_name.strip() else DEFAULT_MODEL
    try:
        return InferenceClient(model=model_path)
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Error initializing model {model_path}: {str(e)}",
        ) from e
def generate_text_response(query: str, model_name: Optional[str] = None):
    """Stream the model's answer to a text query, one token string at a time.

    Args:
        query: The user's question; embedded into a single user message.
        model_name: Optional model id, forwarded to ``get_client``.

    Yields:
        Non-empty text fragments from the streamed chat completion. If
        anything fails, a single ``"Error generating response: ..."`` string
        is yielded instead of raising.
    """
    messages = [{
        "role": "user",
        "content": f"[SYSTEM] You are ASSISTANT who answer question asked by user in short and concise manner. [USER] {query}"
    }]
    try:
        client = get_client(model_name)
        for message in client.chat_completion(
            messages,
            max_tokens=2048,
            stream=True,
        ):
            # Some stream chunks carry no choices or a ``None``/empty delta
            # (e.g. role-only or final chunks). Skip them so consumers that
            # concatenate or encode the chunks never receive ``None``.
            if not message.choices:
                continue
            token = message.choices[0].delta.content
            if token:
                yield token
    except Exception as e:
        yield f"Error generating response: {str(e)}"
def generate_image_text_response(query: str, image_data: str, model_name: Optional[str] = None):
    """Stream the model's answer to a question about an image.

    Args:
        query: The user's question about the image.
        image_data: Base64-encoded image bytes, embedded as a data URL.
        model_name: Optional model id, forwarded to ``get_client``.

    Yields:
        Non-empty text fragments from the streamed chat completion. If
        anything fails, the error is logged and a single
        ``"Error generating response: ..."`` string is yielded.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": f"[SYSTEM] You are ASSISTANT who answer question asked by user in short and concise manner. [USER] {query}"},
                {"type": "image_url", "image_url": {"url": f"data:image/*;base64,{image_data}"}}
            ]
        }
    ]
    # Log a summary only: dumping ``messages`` would write the entire base64
    # image into the debug log on every request.
    logger.debug(
        "Sending image query to API (query=%r, image payload=%d base64 chars)",
        query,
        len(image_data),
    )
    try:
        client = get_client(model_name)
        for message in client.chat_completion(messages, max_tokens=2048, stream=True):
            logger.debug("Received message chunk: %s", message)
            # Skip chunks without choices or with a ``None``/empty delta so
            # consumers never receive ``None``.
            if not message.choices:
                continue
            token = message.choices[0].delta.content
            if token:
                yield token
    except Exception as e:
        logger.error(f"Error in generate_image_text_response: {str(e)}")
        yield f"Error generating response: {str(e)}"
def preprocess_image(img):
    """Enhance image for better OCR results.

    The image is converted to 8-bit grayscale, its contrast is boosted by a
    factor of 2.0, and its sharpness by a factor of 1.5. The processed image
    is returned.
    """
    grayscale = img.convert('L')
    high_contrast = ImageEnhance.Contrast(grayscale).enhance(2.0)
    return ImageEnhance.Sharpness(high_contrast).enhance(1.5)
# Register the handler on the application; without a route decorator the
# function is never reachable over HTTP.
@app.get("/")
async def root():
    """Return a static welcome message."""
    return {"message": "Welcome to FastAPI server!"}
# Path taken from this handler's own log message and the API guide ("POST /t2t").
@app.post("/t2t")
async def text_to_text(request: TextRequest):
    """Answer a text query, either streamed or as a single JSON payload.

    Args:
        request: Parsed JSON body (``query``, ``stream``, ``model_name``).

    Returns:
        A ``StreamingResponse`` of text chunks when ``request.stream`` is
        true, otherwise ``{"response": <full text>}``.

    Raises:
        HTTPException: status 500 if building the response fails.
    """
    # Local import keeps this handler self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    try:
        # Drop empty/None chunks so neither the streaming encoder nor the
        # join below can fail on a non-string value.
        chunks = (
            chunk
            for chunk in generate_text_response(request.query, request.model_name)
            if chunk
        )
        if request.stream:
            return StreamingResponse(chunks, media_type="text/event-stream")
        # The generator performs blocking network I/O; drain it in a worker
        # thread so the event loop stays responsive.
        response = await run_in_threadpool("".join, chunks)
        return {"response": response}
    except Exception as e:
        logger.error(f"Error in /t2t endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# Path taken from this handler's own log message and the API guide ("POST /i2t2t").
@app.post("/i2t2t")
async def image_text_to_text(form_data: tuple[ImageTextRequest, UploadFile] = Depends(ImageTextRequest.as_form)):
    """Answer a question about an uploaded image.

    The upload is decoded, converted to RGB, re-encoded as PNG and sent to the
    model as base64 together with the query.

    Args:
        form_data: ``(form fields, uploaded image)`` produced by
            ``ImageTextRequest.as_form``.

    Returns:
        A ``StreamingResponse`` of text chunks when ``stream`` is true,
        otherwise ``{"response": <full text>}``.

    Raises:
        HTTPException: status 422 if the upload cannot be processed as an
            image; status 500 for any other failure.
    """
    # Local import keeps this handler self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    form, image = form_data
    try:
        # Process image
        contents = await image.read()
        try:
            logger.debug("Attempting to open image")
            img = Image.open(BytesIO(contents))
            if img.mode != 'RGB':
                img = img.convert('RGB')
            buffer = BytesIO()
            img.save(buffer, format="PNG")
            image_data = b64encode(buffer.getvalue()).decode('utf-8')
            logger.debug("Image processed and encoded to base64")
        except Exception as img_error:
            logger.error(f"Error processing image: {str(img_error)}")
            raise HTTPException(
                status_code=422,
                detail=f"Error processing image: {str(img_error)}",
            ) from img_error

        # Drop empty/None chunks so neither the streaming encoder nor the
        # join below can fail on a non-string value.
        chunks = (
            chunk
            for chunk in generate_image_text_response(form.query, image_data, form.model_name)
            if chunk
        )
        if form.stream:
            return StreamingResponse(chunks, media_type="text/event-stream")
        # The generator performs blocking network I/O; drain it in a worker
        # thread so the event loop stays responsive.
        response = await run_in_threadpool("".join, chunks)
        return {"response": response}
    except HTTPException:
        # Let the deliberate 422 above through unchanged; the generic handler
        # below would otherwise turn it into a 500.
        raise
    except Exception as e:
        logger.error(f"Error in /i2t2t endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# Path taken from the API guide in this file ("POST /tes").
@app.post("/tes")
async def ocr_endpoint(image: UploadFile = File(...)):
    """Extract text from an uploaded image with Tesseract OCR.

    The image is preprocessed with ``preprocess_image`` and OCR is attempted
    up to three times, each with a 30 second timeout and a one second pause
    between attempts.

    Args:
        image: The uploaded image file.

    Returns:
        ``{"text": <extracted text>}``.

    Raises:
        HTTPException: status 500 with "Error extracting text: ..." when all
            OCR attempts fail, or "Error processing image: ..." for any other
            failure (e.g. the upload cannot be opened as an image).
    """
    # Local imports keep this handler self-contained.
    import asyncio
    from fastapi.concurrency import run_in_threadpool

    try:
        # Read and process the image
        contents = await image.read()
        img = Image.open(BytesIO(contents))
        # Preprocess the image
        img = preprocess_image(img)
        # Perform OCR with timeout and retries
        max_retries = 3
        text = ""
        for attempt in range(max_retries):
            try:
                # Tesseract runs as a blocking subprocess; execute it in a
                # worker thread so the event loop is not stalled.
                text = await run_in_threadpool(
                    pytesseract.image_to_string,
                    img,
                    timeout=30,  # 30 second timeout
                    config='--oem 3 --psm 6',
                )
                break
            except Exception as e:
                if attempt == max_retries - 1:
                    raise HTTPException(
                        status_code=500,
                        detail=f"Error extracting text: {str(e)}",
                    ) from e
                # Non-blocking pause before the next attempt.
                await asyncio.sleep(1)
        return {"text": text}
    except HTTPException:
        # Keep the specific "Error extracting text" detail instead of
        # re-wrapping it in the generic message below.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error processing image: {str(e)}",
        ) from e
# NOTE(review): the route path is not recoverable from this file; "/api-guide"
# is derived from the function name -- confirm against the deployed app.
@app.get("/api-guide", response_class=HTMLResponse)
async def api_guide():
    """Serve a static HTML page documenting the /t2t, /i2t2t and /tes endpoints.

    The page shows request examples per endpoint with tabs for cURL, Python,
    browser JavaScript and Node.js, plus a copy-to-clipboard button.

    Returns:
        An ``HTMLResponse`` containing the documentation page.
    """
    # Raw string on purpose: the cURL examples end lines with a backslash.
    # In a regular triple-quoted string Python swallows a backslash-newline
    # pair (and the JS template literals would swallow what is left), so the
    # examples rendered as one long line without continuation characters.
    html_content = r'''
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>API Documentation</title>
        <link href="https://cdn.jsdelivr.net/npm/tailwindcss@2.2.19/dist/tailwind.min.css" rel="stylesheet">
        <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/prism/1.24.1/themes/prism-tomorrow.min.css">
        <style>
            .copy-button {
                position: absolute;
                top: 8px;
                right: 8px;
                padding: 4px 8px;
                background: #2d3748;
                border: 1px solid #4a5568;
                border-radius: 4px;
                color: #cbd5e0;
                font-size: 12px;
                cursor: pointer;
                transition: all 0.2s;
            }
            .copy-button:hover {
                background: #4a5568;
            }
            .code-block {
                position: relative;
                margin: 1rem 0;
            }
            .endpoint-card {
                background: #1a202c;
                border-radius: 8px;
                margin-bottom: 2rem;
                padding: 1.5rem;
            }
            .language-tab {
                cursor: pointer;
                padding: 0.5rem 1rem;
                border-radius: 4px 4px 0 0;
            }
            .language-tab.active {
                background: #2d3748;
                color: #fff;
            }
        </style>
    </head>
    <body class="bg-gray-900 text-gray-100 min-h-screen p-8">
        <div class="max-w-6xl mx-auto">
            <h1 class="text-4xl font-bold mb-8">API Documentation</h1>
            <!-- T2T Endpoint -->
            <div class="endpoint-card">
                <h2 class="text-2xl font-semibold mb-4">Text-to-Text Endpoint</h2>
                <p class="mb-4 text-gray-400">Endpoint for general text queries</p>
                <p class="mb-2 text-gray-300"><span class="font-mono bg-gray-800 px-2 py-1 rounded">POST /t2t</span></p>
                <div class="code-block">
                    <div class="flex mb-2">
                        <div class="language-tab active" data-lang="curl">cURL</div>
                        <div class="language-tab" data-lang="python">Python</div>
                        <div class="language-tab" data-lang="javascript">JavaScript</div>
                        <div class="language-tab" data-lang="node">Node.js</div>
                    </div>
                    <pre><code class="language-bash">curl -X POST "http://localhost:8000/t2t" \
  -H "Content-Type: application/json" \
  -d '{"query": "What is FastAPI?", "stream": false}'</code></pre>
                    <button class="copy-button">Copy</button>
                </div>
            </div>
            <!-- I2T2T Endpoint -->
            <div class="endpoint-card">
                <h2 class="text-2xl font-semibold mb-4">Image and Text to Text Endpoint</h2>
                <p class="mb-4 text-gray-400">Endpoint for queries about images</p>
                <p class="mb-2 text-gray-300"><span class="font-mono bg-gray-800 px-2 py-1 rounded">POST /i2t2t</span></p>
                <div class="code-block">
                    <div class="flex mb-2">
                        <div class="language-tab active" data-lang="curl">cURL</div>
                        <div class="language-tab" data-lang="python">Python</div>
                        <div class="language-tab" data-lang="javascript">JavaScript</div>
                        <div class="language-tab" data-lang="node">Node.js</div>
                    </div>
                    <pre><code class="language-bash">curl -X POST "http://localhost:8000/i2t2t" \
  -F "query=Describe this image" \
  -F "stream=false" \
  -F "image=@/path/to/your/image.jpg"</code></pre>
                    <button class="copy-button">Copy</button>
                </div>
            </div>
            <!-- TES Endpoint -->
            <div class="endpoint-card">
                <h2 class="text-2xl font-semibold mb-4">OCR Endpoint</h2>
                <p class="mb-4 text-gray-400">Extract text from images using OCR</p>
                <p class="mb-2 text-gray-300"><span class="font-mono bg-gray-800 px-2 py-1 rounded">POST /tes</span></p>
                <div class="code-block">
                    <div class="flex mb-2">
                        <div class="language-tab active" data-lang="curl">cURL</div>
                        <div class="language-tab" data-lang="python">Python</div>
                        <div class="language-tab" data-lang="javascript">JavaScript</div>
                        <div class="language-tab" data-lang="node">Node.js</div>
                    </div>
                    <pre><code class="language-bash">curl -X POST "http://localhost:8000/tes" \
  -F "image=@/path/to/your/image.jpg"</code></pre>
                    <button class="copy-button">Copy</button>
                </div>
            </div>
        </div>
        <script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.24.1/prism.min.js"></script>
        <script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.24.1/components/prism-python.min.js"></script>
        <script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.24.1/components/prism-javascript.min.js"></script>
        <script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.24.1/components/prism-bash.min.js"></script>
        <script>
            const codeExamples = {
                't2t': {
                    'curl': `curl -X POST "http://localhost:8000/t2t" \\
  -H "Content-Type: application/json" \\
  -d '{"query": "What is FastAPI?", "stream": false}'`,
                    'python': `import requests
url = "http://localhost:8000/t2t"
payload = {
    "query": "What is FastAPI?",
    "stream": False
}
response = requests.post(url, json=payload)
print(response.json())`,
                    'javascript': `// Using fetch
fetch("http://localhost:8000/t2t", {
    method: "POST",
    headers: {
        "Content-Type": "application/json",
    },
    body: JSON.stringify({
        query: "What is FastAPI?",
        stream: false
    })
})
.then(response => response.json())
.then(data => console.log(data));`,
                    'node': `const axios = require('axios');
async function makeRequest() {
    try {
        const response = await axios.post('http://localhost:8000/t2t', {
            query: "What is FastAPI?",
            stream: false
        });
        console.log(response.data);
    } catch (error) {
        console.error(error);
    }
}
makeRequest();`
                },
                'i2t2t': {
                    'curl': `curl -X POST "http://localhost:8000/i2t2t" \\
  -F "query=Describe this image" \\
  -F "stream=false" \\
  -F "image=@/path/to/your/image.jpg"`,
                    'python': `import requests
url = "http://localhost:8000/i2t2t"
files = {
    'image': ('image.jpg', open('path/to/image.jpg', 'rb')),
}
data = {
    'query': 'Describe this image',
    'stream': 'false'
}
response = requests.post(url, files=files, data=data)
print(response.json())`,
                    'javascript': `const formData = new FormData();
formData.append('image', imageFile);
formData.append('query', 'Describe this image');
formData.append('stream', 'false');
fetch("http://localhost:8000/i2t2t", {
    method: "POST",
    body: formData
})
.then(response => response.json())
.then(data => console.log(data));`,
                    'node': `const axios = require('axios');
const FormData = require('form-data');
const fs = require('fs');
async function makeRequest() {
    try {
        const formData = new FormData();
        formData.append('image', fs.createReadStream('path/to/image.jpg'));
        formData.append('query', 'Describe this image');
        formData.append('stream', 'false');
        const response = await axios.post('http://localhost:8000/i2t2t', formData, {
            headers: formData.getHeaders()
        });
        console.log(response.data);
    } catch (error) {
        console.error(error);
    }
}
makeRequest();`
                },
                'tes': {
                    'curl': `curl -X POST "http://localhost:8000/tes" \\
  -F "image=@/path/to/your/image.jpg"`,
                    'python': `import requests
url = "http://localhost:8000/tes"
files = {
    'image': ('image.jpg', open('path/to/image.jpg', 'rb'))
}
response = requests.post(url, files=files)
print(response.json())`,
                    'javascript': `const formData = new FormData();
formData.append('image', imageFile);
fetch("http://localhost:8000/tes", {
    method: "POST",
    body: formData
})
.then(response => response.json())
.then(data => console.log(data));`,
                    'node': `const axios = require('axios');
const FormData = require('form-data');
const fs = require('fs');
async function makeRequest() {
    try {
        const formData = new FormData();
        formData.append('image', fs.createReadStream('path/to/image.jpg'));
        const response = await axios.post('http://localhost:8000/tes', formData, {
            headers: formData.getHeaders()
        });
        console.log(response.data);
    } catch (error) {
        console.error(error);
    }
}
makeRequest();`
                }
            };
            // Map tab ids to Prism grammar names; Prism has no "curl" or "node" grammar.
            const prismLanguages = { curl: 'bash', node: 'javascript' };
            // Handle language tab switching
            document.querySelectorAll('.language-tab').forEach(tab => {
                tab.addEventListener('click', () => {
                    const lang = tab.dataset.lang;
                    const codeBlock = tab.closest('.endpoint-card');
                    const heading = codeBlock.querySelector('h2').textContent.toLowerCase();
                    const endpoint = heading.includes('ocr') ? 'tes' :
                        heading.includes('image') ? 'i2t2t' : 't2t';
                    // Update active tab
                    codeBlock.querySelectorAll('.language-tab').forEach(t => t.classList.remove('active'));
                    tab.classList.add('active');
                    // Update code content
                    const code = codeBlock.querySelector('code');
                    code.textContent = codeExamples[endpoint][lang];
                    code.className = `language-${prismLanguages[lang] || lang}`;
                    Prism.highlightElement(code);
                });
            });
            // Handle copy buttons
            document.querySelectorAll('.copy-button').forEach(button => {
                button.addEventListener('click', () => {
                    const code = button.previousElementSibling.textContent;
                    navigator.clipboard.writeText(code);
                    // Show feedback
                    const originalText = button.textContent;
                    button.textContent = 'Copied!';
                    setTimeout(() => {
                        button.textContent = originalText;
                    }, 2000);
                });
            });
        </script>
    </body>
    </html>
    '''
    return HTMLResponse(content=html_content)