# google-fastapi / app.py
# Author: Hadiil
# Last commit: 3318917 "Update app.py" (verified)
import os
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, JSONResponse
from transformers import pipeline
from PIL import Image
import io
import fitz
from docx import Document
import pandas as pd
import logging
from typing import Optional
# Module-level logger used by every handler below; the root logger is
# configured to emit INFO and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# Initialize FastAPI app
app = FastAPI(
title="AI Web Services",
description="API for various AI services including text summarization, image captioning, and document analysis",
version="1.0.0"
)
# Mount static files
os.makedirs("static", exist_ok=True)
app.mount("/static", StaticFiles(directory="static"), name="static")
# Load AI models once at import time; every endpoint below shares these two
# pipeline objects. Any failure aborts the import, so the app never starts
# serving requests without its models.
try:
    logger.info("Loading AI models...")
    image_pipeline = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device="cpu",
        use_fast=False,  # Explicitly set to avoid warning
    )
    text_pipeline = pipeline(
        "text2text-generation",
        model="t5-small",
        device="cpu",
    )
    logger.info("Models loaded successfully")
except Exception as e:
    # logger.exception records the full traceback, and `from e` keeps the
    # original cause attached to the RuntimeError the process dies with.
    logger.exception("Model loading failed: %s", e)
    raise RuntimeError("Failed to initialize AI models") from e
# Helper function for text extraction
async def extract_text(file: UploadFile) -> str:
    """Extract plain text from an uploaded PDF or DOCX file.

    The format is chosen from the upload's filename extension, compared
    case-insensitively (so ``REPORT.PDF`` is accepted).

    Args:
        file: The uploaded document. Its entire contents are read into memory
            (NOTE(review): there is no size limit on the upload).

    Returns:
        For PDFs, the text of every page joined with single spaces; for DOCX,
        the text of every paragraph joined with newlines.

    Raises:
        HTTPException: 400 if the extension is unsupported or the document
            cannot be parsed.
    """
    # A missing (None) filename is treated as an unsupported format instead of
    # crashing with AttributeError.
    filename = (file.filename or "").lower()
    try:
        content = await file.read()
        if filename.endswith(".pdf"):
            with fitz.open(stream=content, filetype="pdf") as pdf:
                return " ".join(page.get_text() for page in pdf)
        if filename.endswith(".docx"):
            document = Document(io.BytesIO(content))
            return "\n".join(p.text for p in document.paragraphs)
        raise ValueError("Unsupported file format")
    except Exception as e:
        # Any parsing failure is reported to the client as a bad request.
        logger.error("Text extraction failed: %s", e)
        raise HTTPException(400, f"Could not extract text: {e}") from e
# API Endpoints
@app.get("/", response_class=HTMLResponse)
async def home():
    """Serve the frontend page stored at ``static/index.html``.

    Returns a minimal placeholder page when the file does not exist.

    Raises:
        HTTPException: 500 if the file exists but cannot be read.
    """
    try:
        # Decode explicitly as UTF-8 so non-ASCII markup does not depend on
        # the platform's default text encoding.
        with open("static/index.html", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        return "<h1>Welcome to AI Web Services</h1><p>Frontend not found</p>"
    except Exception as e:
        logger.exception("Failed to load frontend: %s", e)
        raise HTTPException(500, "Frontend loading failed") from e
@app.post("/api/summarize")
async def summarize(
    file: Optional[UploadFile] = File(None),
    text: Optional[str] = Form(None)
):
    """Summarize an uploaded document or a block of raw text.

    If ``file`` is supplied it takes precedence: ``text`` is replaced by the
    text extracted from the document.

    Returns:
        ``{"summary": <generated text>}``, produced by the T5 pipeline with
        ``max_length=150``.

    Raises:
        HTTPException: 400 if the document cannot be read or there is no
            non-blank text to summarize; 500 if the model call fails.

    NOTE(review): ``text_pipeline`` is called synchronously inside this
    ``async`` handler, so inference blocks the event loop while it runs.
    """
    try:
        if file:
            text = await extract_text(file)
        # Reject whitespace-only input as well as missing/empty input.
        if not text or not text.strip():
            raise HTTPException(400, "No text provided")
        result = text_pipeline(f"summarize: {text}", max_length=150)
        return {"summary": result[0]['generated_text']}
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Summarization error: %s", e)
        raise HTTPException(500, "Summarization failed") from e
@app.post("/api/caption")
async def caption_image(file: UploadFile = File(...)):
    """Generate a caption for an uploaded image with the BLIP pipeline.

    Returns:
        ``{"caption": <generated text>}``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image;
            500 if captioning itself fails.
    """
    try:
        data = await file.read()
        try:
            image = Image.open(io.BytesIO(data))
            # Image.open is lazy; force a full decode now so corrupt or
            # truncated uploads are reported as client errors here rather
            # than failing later inside the model call.
            image.load()
        except (OSError, Image.DecompressionBombError) as e:
            # OSError also covers PIL's UnidentifiedImageError.
            logger.warning("Rejected unreadable image upload: %s", e)
            raise HTTPException(400, f"Could not read image: {e}") from e
        result = image_pipeline(image)
        return {"caption": result[0]['generated_text']}
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Captioning error: %s", e)
        raise HTTPException(500, "Image captioning failed") from e
@app.post("/api/answer")
async def answer_question(
    file: Optional[UploadFile] = File(None),
    text: Optional[str] = Form(None),
    question: str = Form(...)
):
    """Answer a question about an uploaded document or a block of raw text.

    If ``file`` is supplied it takes precedence: ``text`` is replaced by the
    text extracted from the document. The T5 pipeline is prompted in the
    form ``"question: ... context: ..."``.

    Returns:
        ``{"answer": <generated text>}``.

    Raises:
        HTTPException: 400 if the document cannot be read, or if the context
            text or the question is blank; 500 if the model call fails.

    NOTE(review): ``text_pipeline`` is called synchronously inside this
    ``async`` handler, so inference blocks the event loop while it runs.
    """
    try:
        if file:
            text = await extract_text(file)
        # Reject whitespace-only input as well as missing/empty input.
        if not text or not text.strip():
            raise HTTPException(400, "No text provided")
        if not question.strip():
            raise HTTPException(400, "No question provided")
        result = text_pipeline(f"question: {question} context: {text}")
        return {"answer": result[0]['generated_text']}
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("QA error: %s", e)
        raise HTTPException(500, "Question answering failed") from e
@app.post("/api/visualize")
async def generate_visualization(
    file: UploadFile = File(...),
    chart_type: str = Form("bar")
):
    """Generate example plotting code for an uploaded Excel sheet.

    The sheet is parsed with pandas only to discover its columns. The
    returned snippet refers to a variable named ``df`` that it does not
    define; the caller is expected to have the same data loaded under
    that name.

    Args:
        file: The uploaded Excel workbook (first sheet is read).
        chart_type: ``"bar"`` (case-insensitive) for a bar chart of the first
            column against the second; any other value yields a seaborn
            pairplot of the whole frame.

    Returns:
        ``{"code": <python snippet>, "columns": <list of column labels>}``.

    Raises:
        HTTPException: 400 if the spreadsheet cannot be read or a bar chart
            is requested for a sheet with fewer than two columns; 500 for
            any other failure.
    """
    try:
        content = await file.read()
        try:
            df = pd.read_excel(io.BytesIO(content))
        except Exception as e:
            # An unparseable upload is a client error, matching extract_text.
            logger.warning("Spreadsheet parsing failed: %s", e)
            raise HTTPException(400, f"Could not read spreadsheet: {e}") from e

        # tolist() yields plain Python scalars, so repr() below produces
        # valid literals (e.g. 2023 rather than np.int64(2023)).
        columns = df.columns.tolist()

        if chart_type.lower() == "bar":
            if len(columns) < 2:
                raise HTTPException(400, "Bar chart requires at least two columns")
            x_col, y_col = columns[0], columns[1]
            # repr() quotes labels safely, including ones that contain
            # quote characters or are not strings.
            code = "\n".join([
                "import matplotlib.pyplot as plt",
                f"plt.bar(df[{x_col!r}], df[{y_col!r}])",
                "plt.title('Bar Chart')",
                "plt.show()",
            ])
        else:
            # The snippet calls plt.title/plt.show, so it must import pyplot.
            code = "\n".join([
                "import matplotlib.pyplot as plt",
                "import seaborn as sns",
                "sns.pairplot(df)",
                "plt.title('Data Distribution')",
                "plt.show()",
            ])
        return {
            "code": code,
            "columns": columns
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Visualization error: %s", e)
        raise HTTPException(500, "Visualization code generation failed") from e
@app.get("/health")
async def health_check():
    """Report service liveness together with the status of both AI models.

    The models are reported as "loaded" unconditionally: the module raises
    during import if either pipeline fails to load, so a running app
    implies both are present.
    """
    model_status = dict.fromkeys(("image_captioning", "text_generation"), "loaded")
    return {"status": "healthy", "models": model_status}
# Server initialization
if __name__ == "__main__":
    import uvicorn

    # Pass the already-built app object, not the "app:app" import string.
    # Running `python app.py` executes this module as __main__; an import
    # string would make uvicorn import the file a second time as module
    # "app", loading both AI models twice. The import-string form is only
    # needed when reload is enabled, and reload is disabled here.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        log_level="info",
        reload=False,
    )