Hadiil committed on
Commit
54f3bcb
·
verified ·
1 Parent(s): 8f889f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -34
app.py CHANGED
@@ -1,17 +1,13 @@
1
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks, WebSocket, Request
 
2
  from fastapi.staticfiles import StaticFiles
3
  from fastapi.responses import RedirectResponse
4
- from fastapi_cache import FastAPICache
5
- from fastapi_cache.backends.redis import RedisBackend
6
- from fastapi_cache.decorator import cache
7
  from transformers import pipeline
8
  import logging
9
  from PIL import Image
10
  import io
11
  from docx import Document
12
  import fitz # PyMuPDF
13
- from pydantic import BaseModel
14
- import asyncio
15
 
16
  # Configure logging
17
  logging.basicConfig(level=logging.INFO)
@@ -22,47 +18,80 @@ app = FastAPI()
22
  # Serve static files (HTML, CSS, JS)
23
  app.mount("/static", StaticFiles(directory="static"), name="static")
24
 
25
- # Load AI models
26
  multimodal_pipeline = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
27
- text_pipeline = pipeline("text2text-generation", model="t5-small")
28
 
29
- # Initialize Redis cache
30
- @app.on_event("startup")
31
- async def startup():
32
- FastAPICache.init(RedisBackend("redis://localhost:6379"))
33
 
34
- # Root endpoint
35
  @app.get("/")
36
  def read_root():
 
37
  return RedirectResponse(url="/static/index.html")
38
 
39
- # Summarization endpoint
40
  @app.post("/summarize")
41
- @cache(expire=300)
42
- async def summarize_text(file: UploadFile = File(None), text: str = Form(None)):
43
- if file:
44
- if not file.filename.endswith((".pdf", ".docx")):
45
- raise HTTPException(status_code=400, detail="Unsupported file format. Please upload a PDF or DOCX file.")
46
- text = await extract_text_from_file(file)
47
- elif not text:
48
- raise HTTPException(status_code=400, detail="No file or text provided")
49
-
50
  try:
 
 
 
 
51
  summary = text_pipeline(f"summarize: {text}", max_length=100)
52
- logger.info(f"Generated summary: {summary[0]['summary_text']}")
53
- return {"summary": summary[0]['summary_text']}
54
  except Exception as e:
55
  logger.error(f"Error during summarization: {e}")
56
- raise HTTPException(status_code=500, detail="An error occurred while processing your request. Please try again.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- # WebSocket for real-time updates
59
- @app.websocket("/ws")
60
- async def websocket_endpoint(websocket: WebSocket):
61
- await websocket.accept()
62
- while True:
63
- data = await websocket.receive_text()
64
- summary = text_pipeline(f"summarize: {data}", max_length=100)
65
- await websocket.send_text(summary[0]['summary_text'])
66
 
67
  # Helper function to extract text from files
68
  async def extract_text_from_file(file: UploadFile):
 
1
+ import os
2
+ from fastapi import FastAPI, UploadFile, File, HTTPException
3
  from fastapi.staticfiles import StaticFiles
4
  from fastapi.responses import RedirectResponse
 
 
 
5
  from transformers import pipeline
6
  import logging
7
  from PIL import Image
8
  import io
9
  from docx import Document
10
  import fitz # PyMuPDF
 
 
11
 
12
  # Configure logging
13
  logging.basicConfig(level=logging.INFO)
 
18
  # Serve static files (HTML, CSS, JS)
19
  app.mount("/static", StaticFiles(directory="static"), name="static")
20
 
21
+ # Load a multimodal model for image captioning and visual question answering
22
  multimodal_pipeline = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
 
23
 
24
+ # Load a text-based model for summarization and text question answering
25
+ text_pipeline = pipeline("text2text-generation", model="t5-small")
 
 
26
 
 
27
  @app.get("/")
28
  def read_root():
29
+ # Redirect to the static HTML file
30
  return RedirectResponse(url="/static/index.html")
31
 
 
32
  @app.post("/summarize")
33
+ async def summarize_text(file: UploadFile = File(...)):
34
+ logger.info(f"Received document for summarization: {file.filename}")
 
 
 
 
 
 
 
35
  try:
36
+ # Extract text from the document
37
+ text = await extract_text_from_file(file)
38
+
39
+ # Use the text pipeline to summarize the text
40
  summary = text_pipeline(f"summarize: {text}", max_length=100)
41
+ logger.info(f"Generated summary: {summary[0]['generated_text']}")
42
+ return {"summary": summary[0]['generated_text']}
43
  except Exception as e:
44
  logger.error(f"Error during summarization: {e}")
45
+ raise HTTPException(status_code=500, detail=str(e))
46
+
47
+ @app.post("/caption")
48
+ async def caption_image(file: UploadFile = File(...)):
49
+ logger.info(f"Received image for captioning: {file.filename}")
50
+ try:
51
+ # Read the image file
52
+ image_data = await file.read()
53
+ image = Image.open(io.BytesIO(image_data))
54
+
55
+ # Use the multimodal pipeline to generate a caption for the image
56
+ caption = multimodal_pipeline(image)
57
+ logger.info(f"Generated caption: {caption[0]['generated_text']}")
58
+ return {"caption": caption[0]['generated_text']}
59
+ except Exception as e:
60
+ logger.error(f"Error during image captioning: {e}")
61
+ raise HTTPException(status_code=500, detail=str(e))
62
+
63
+ @app.post("/answer")
64
+ async def answer_question(file: UploadFile = File(...), question: str = ""):
65
+ logger.info(f"Received document for question answering: {file.filename}")
66
+ logger.info(f"Received question: {question}")
67
+ try:
68
+ # Extract text from the document
69
+ text = await extract_text_from_file(file)
70
+
71
+ # Use the text pipeline to answer the question
72
+ answer = text_pipeline(f"question: {question} context: {text}")
73
+ logger.info(f"Generated answer: {answer[0]['generated_text']}")
74
+ return {"answer": answer[0]['generated_text']}
75
+ except Exception as e:
76
+ logger.error(f"Error during question answering: {e}")
77
+ raise HTTPException(status_code=500, detail=str(e))
78
+
79
+ @app.post("/vqa")
80
+ async def visual_question_answering(file: UploadFile = File(...), question: str = ""):
81
+ logger.info(f"Received image for visual question answering: {file.filename}")
82
+ logger.info(f"Received question: {question}")
83
+ try:
84
+ # Read the image file
85
+ image_data = await file.read()
86
+ image = Image.open(io.BytesIO(image_data))
87
 
88
+ # Use the multimodal pipeline to answer the question about the image
89
+ answer = multimodal_pipeline(image, question=question)
90
+ logger.info(f"Generated answer: {answer[0]['generated_text']}")
91
+ return {"answer": answer[0]['generated_text']}
92
+ except Exception as e:
93
+ logger.error(f"Error during visual question answering: {e}")
94
+ raise HTTPException(status_code=500, detail=str(e))
 
95
 
96
  # Helper function to extract text from files
97
  async def extract_text_from_file(file: UploadFile):