Hadiil commited on
Commit
30fa099
·
verified ·
1 Parent(s): f838108

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -127
app.py CHANGED
@@ -1,18 +1,15 @@
1
- import os
2
  from fastapi import FastAPI, UploadFile, File, HTTPException, Form
3
  from fastapi.staticfiles import StaticFiles
4
  from fastapi.responses import RedirectResponse, JSONResponse
5
- from transformers import pipeline
6
  import logging
7
  from PIL import Image
8
  import io
9
  from docx import Document
10
  import fitz # PyMuPDF
11
  import pandas as pd
12
- import matplotlib.pyplot as plt
13
- import seaborn as sns
14
- import uuid
15
- from transformers import MarianMTModel, MarianTokenizer
16
  from fastapi.middleware.cors import CORSMiddleware
17
 
18
  # Configure logging
@@ -33,51 +30,64 @@ app.add_middleware(
33
  # Serve static files (HTML, CSS, JS)
34
  app.mount("/static", StaticFiles(directory="static"), name="static")
35
 
36
- # Load models
37
- multimodal_pipeline = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base", use_fast=True)
38
- text_pipeline = pipeline("text2text-generation", model="t5-small", use_fast=True)
39
  translation_models = {
40
  "fr": "Helsinki-NLP/opus-mt-en-fr",
41
  "es": "Helsinki-NLP/opus-mt-en-es",
42
  "de": "Helsinki-NLP/opus-mt-en-de"
43
  }
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  @app.get("/")
46
  def read_root():
47
  return RedirectResponse(url="/static/index.html")
48
 
 
49
  @app.post("/summarize")
50
- async def summarize_text(
51
- file: UploadFile = File(None),
52
- text: str = Form(None)
53
- ):
54
- logger.info(f"Received request: file={file}, text={text}") # Debugging
55
-
56
- if file:
57
- logger.info(f"Received document for summarization: {file.filename}")
58
- try:
59
  text = await extract_text_from_file(file)
60
- except Exception as e:
61
- logger.error(f"Error extracting text from file: {e}")
62
- raise HTTPException(status_code=400, detail=str(e))
63
- elif text:
64
- logger.info("Received manual text for summarization")
65
- else:
66
- logger.error("No file or text provided") # Debugging
67
- raise HTTPException(status_code=400, detail="No file or text provided")
68
 
69
- try:
70
  summary = text_pipeline(f"summarize: {text}", max_length=100)
71
- logger.info(f"Generated summary: {summary[0]['generated_text']}")
72
  return {"summary": summary[0]['generated_text']}
73
  except Exception as e:
74
  logger.error(f"Error during summarization: {e}")
75
- raise HTTPException(status_code=500, detail=str(e))
76
 
 
77
  @app.post("/caption")
78
  async def caption_image(file: UploadFile = File(...)):
79
- logger.info(f"Received image for captioning: {file.filename}")
80
  try:
 
81
  image_data = await file.read()
82
  image = Image.open(io.BytesIO(image_data))
83
 
@@ -85,71 +95,55 @@ async def caption_image(file: UploadFile = File(...)):
85
  if image.format not in ["JPEG", "PNG"]:
86
  raise ValueError("Unsupported image format. Please upload a JPEG or PNG file.")
87
 
 
88
  caption = multimodal_pipeline(image)
89
- logger.info(f"Generated caption: {caption[0]['generated_text']}")
90
  return {"caption": caption[0]['generated_text']}
91
  except Exception as e:
92
  logger.error(f"Error during image captioning: {e}")
93
  raise HTTPException(status_code=400, detail=str(e))
94
 
 
95
  @app.post("/translate")
96
- async def translate_document(
97
- file: UploadFile = File(...),
98
- target_language: str = Form(...)
99
- ):
100
- logger.info(f"Received document for translation: {file.filename}")
101
- logger.info(f"Target language: {target_language}")
102
-
103
  try:
104
- text = await extract_text_from_file(file)
105
-
106
- if target_language in translation_models:
107
- model_name = translation_models[target_language]
108
- else:
109
- model_name = "Helsinki-NLP/opus-mt-en-de" # Default to German
110
-
111
- tokenizer = MarianTokenizer.from_pretrained(model_name)
112
- model = MarianMTModel.from_pretrained(model_name)
113
-
114
- translated = model.generate(**tokenizer(text, return_tensors="pt", truncation=True))
115
- translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)
116
 
117
- return {"translated_text": translated_text}
 
 
118
  except Exception as e:
119
- logger.error(f"Error during document translation: {e}")
120
- raise HTTPException(status_code=500, detail=str(e))
121
 
 
122
  @app.post("/answer")
123
- async def answer_question(
124
- file: UploadFile = File(None),
125
- text: str = Form(None),
126
- question: str = Form(...)
127
- ):
128
- if file:
129
- logger.info(f"Received document for question answering: {file.filename}")
130
- try:
131
  text = await extract_text_from_file(file)
132
- except Exception as e:
133
- logger.error(f"Error extracting text from file: {e}")
134
- raise HTTPException(status_code=400, detail=str(e))
135
- elif text:
136
- logger.info("Received manual text for question answering")
137
- else:
138
- raise HTTPException(status_code=400, detail="No file or text provided")
139
 
140
- try:
141
  answer = text_pipeline(f"question: {question} context: {text}")
142
- logger.info(f"Generated answer: {answer[0]['generated_text']}")
143
  return {"answer": answer[0]['generated_text']}
144
  except Exception as e:
145
  logger.error(f"Error during question answering: {e}")
146
- raise HTTPException(status_code=500, detail=str(e))
147
 
 
148
  @app.post("/vqa")
149
  async def visual_question_answering(file: UploadFile = File(...), question: str = Form(...)):
150
- logger.info(f"Received image for visual question answering: {file.filename}")
151
- logger.info(f"Received question: {question}")
152
  try:
 
 
153
  image_data = await file.read()
154
  image = Image.open(io.BytesIO(image_data))
155
 
@@ -157,22 +151,19 @@ async def visual_question_answering(file: UploadFile = File(...), question: str
157
  if image.format not in ["JPEG", "PNG"]:
158
  raise ValueError("Unsupported image format. Please upload a JPEG or PNG file.")
159
 
 
160
  answer = multimodal_pipeline(image, question=question)
161
- logger.info(f"Generated answer: {answer[0]['generated_text']}")
162
  return {"answer": answer[0]['generated_text']}
163
  except Exception as e:
164
  logger.error(f"Error during visual question answering: {e}")
165
  raise HTTPException(status_code=400, detail=str(e))
166
 
 
167
  @app.post("/visualize")
168
- async def visualize_data(
169
- file: UploadFile = File(...),
170
- request: str = Form(...)
171
- ):
172
- logger.info(f"Received Excel file for visualization: {file.filename}")
173
- logger.info(f"Received visualization request: {request}")
174
-
175
  try:
 
 
176
  df = pd.read_excel(io.BytesIO(await file.read()))
177
 
178
  if "bar" in request.lower():
@@ -200,64 +191,33 @@ sns.pairplot(df)
200
  plt.show()
201
  """
202
 
203
- code_filename = f"visualization_{uuid.uuid4()}.py"
204
- with open(code_filename, "w") as f:
205
- f.write(code)
206
-
207
- return {"code": code, "filename": code_filename}
208
  except Exception as e:
209
  logger.error(f"Error during visualization code generation: {e}")
210
- raise HTTPException(status_code=500, detail=str(e))
211
 
 
212
  async def extract_text_from_file(file: UploadFile):
213
  try:
214
  file_content = await file.read()
215
- if not file_content:
216
- logger.error("Uploaded file is empty.")
217
- raise ValueError("Uploaded file is empty.")
218
-
219
- # Check file size (e.g., limit to 10MB)
220
- if len(file_content) > 10 * 1024 * 1024: # 10MB
221
- logger.error("File size exceeds the limit (10MB).")
222
- raise ValueError("File size exceeds the limit (10MB).")
223
-
224
- # Check file type
225
- if not file.filename.lower().endswith((".pdf", ".docx", ".txt")):
226
- logger.error(f"Unsupported files format: {file.filename}")
227
- raise ValueError("Unsupported file format. Please upload a PDF, DOCX, or TXT file.")
228
-
229
  if file.filename.endswith(".pdf"):
230
- try:
231
- # Log the first few bytes of the file for debugging
232
- logger.info(f"First 100 bytes of the file: {file_content[:100]}")
233
-
234
- # Attempt to open the PDF
235
- doc = fitz.open(stream=file_content, filetype="pdf")
236
- text = ""
237
- for page in doc:
238
- text += page.get_text()
239
- return text
240
- except Exception as e:
241
- logger.error(f"Error reading PDF file: {e}")
242
- raise ValueError("Failed to read PDF file. It might be corrupted or not a valid PDF.")
243
  elif file.filename.endswith(".docx"):
244
- try:
245
- doc = Document(io.BytesIO(file_content))
246
- text = "\n".join([para.text for para in doc.paragraphs])
247
- return text
248
- except Exception as e:
249
- logger.error(f"Error reading DOCX file: {e}")
250
- raise ValueError("Failed to read DOCX file. It might be corrupted or not a valid DOCX.")
251
  elif file.filename.endswith(".txt"):
252
- try:
253
- return file_content.decode("utf-8")
254
- except Exception as e:
255
- logger.error(f"Error reading TXT file: {e}")
256
- raise ValueError("Failed to read TXT file. It might be corrupted or not a valid TXT.")
257
  except Exception as e:
258
  logger.error(f"Error extracting text from file: {e}")
259
- raise HTTPException(status_code=400, detail=str(e))
260
 
 
261
  if __name__ == "__main__":
262
  import uvicorn
263
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
1
  from fastapi import FastAPI, UploadFile, File, HTTPException, Form
2
  from fastapi.staticfiles import StaticFiles
3
  from fastapi.responses import RedirectResponse, JSONResponse
4
+ from transformers import pipeline, MarianMTModel, MarianTokenizer
5
  import logging
6
  from PIL import Image
7
  import io
8
  from docx import Document
9
  import fitz # PyMuPDF
10
  import pandas as pd
11
+ from tenacity import retry, stop_after_attempt, wait_exponential
12
+ from functools import lru_cache
 
 
13
  from fastapi.middleware.cors import CORSMiddleware
14
 
15
  # Configure logging
 
30
  # Serve static files (HTML, CSS, JS)
31
  app.mount("/static", StaticFiles(directory="static"), name="static")
32
 
33
+ # Translation models
 
 
34
  translation_models = {
35
  "fr": "Helsinki-NLP/opus-mt-en-fr",
36
  "es": "Helsinki-NLP/opus-mt-en-es",
37
  "de": "Helsinki-NLP/opus-mt-en-de"
38
  }
39
 
40
+ # Retry logic for model loading
41
+ @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
42
+ def load_model_with_retry(model_name, task, use_fast=True):
43
+ logger.info(f"Loading model: {model_name}")
44
+ return pipeline(task, model=model_name, use_fast=use_fast)
45
+
46
+ # Lazy-loading pipelines
47
+ @lru_cache(maxsize=1)
48
+ def get_multimodal_pipeline():
49
+ return load_model_with_retry("Salesforce/blip-image-captioning-base", "image-to-text")
50
+
51
+ @lru_cache(maxsize=1)
52
+ def get_text_pipeline():
53
+ return load_model_with_retry("t5-small", "text2text-generation")
54
+
55
+ @lru_cache(maxsize=3)
56
+ def get_translation_pipeline(target_language):
57
+ model_name = translation_models.get(target_language, "Helsinki-NLP/opus-mt-en-de")
58
+ tokenizer = MarianTokenizer.from_pretrained(model_name)
59
+ model = MarianMTModel.from_pretrained(model_name)
60
+ return pipeline("translation_en_to_xx", model=model, tokenizer=tokenizer)
61
+
62
+ # Root endpoint
63
  @app.get("/")
64
  def read_root():
65
  return RedirectResponse(url="/static/index.html")
66
 
67
+ # Summarize text endpoint
68
  @app.post("/summarize")
69
+ async def summarize_text(file: UploadFile = File(None), text: str = Form(None)):
70
+ try:
71
+ if file:
72
+ logger.info(f"Received document for summarization: {file.filename}")
 
 
 
 
 
73
  text = await extract_text_from_file(file)
74
+ elif text:
75
+ logger.info("Received manual text for summarization")
76
+ else:
77
+ raise HTTPException(status_code=400, detail="No file or text provided")
 
 
 
 
78
 
79
+ text_pipeline = get_text_pipeline()
80
  summary = text_pipeline(f"summarize: {text}", max_length=100)
 
81
  return {"summary": summary[0]['generated_text']}
82
  except Exception as e:
83
  logger.error(f"Error during summarization: {e}")
84
+ raise HTTPException(status_code=500, detail="Failed to summarize text. Please try again.")
85
 
86
+ # Image captioning endpoint
87
  @app.post("/caption")
88
  async def caption_image(file: UploadFile = File(...)):
 
89
  try:
90
+ logger.info(f"Received image for captioning: {file.filename}")
91
  image_data = await file.read()
92
  image = Image.open(io.BytesIO(image_data))
93
 
 
95
  if image.format not in ["JPEG", "PNG"]:
96
  raise ValueError("Unsupported image format. Please upload a JPEG or PNG file.")
97
 
98
+ multimodal_pipeline = get_multimodal_pipeline()
99
  caption = multimodal_pipeline(image)
 
100
  return {"caption": caption[0]['generated_text']}
101
  except Exception as e:
102
  logger.error(f"Error during image captioning: {e}")
103
  raise HTTPException(status_code=400, detail=str(e))
104
 
105
+ # Translation endpoint
106
  @app.post("/translate")
107
+ async def translate_document(file: UploadFile = File(None), text: str = Form(None), target_language: str = Form(...)):
 
 
 
 
 
 
108
  try:
109
+ if file:
110
+ logger.info(f"Received document for translation: {file.filename}")
111
+ text = await extract_text_from_file(file)
112
+ elif not text:
113
+ raise HTTPException(status_code=400, detail="No text or file provided")
 
 
 
 
 
 
 
114
 
115
+ translation_pipeline = get_translation_pipeline(target_language)
116
+ translated = translation_pipeline(text)
117
+ return {"translated_text": translated[0]['translation_text']}
118
  except Exception as e:
119
+ logger.error(f"Error during translation: {e}")
120
+ raise HTTPException(status_code=500, detail="Failed to translate text. Please try again.")
121
 
122
+ # Question answering endpoint
123
  @app.post("/answer")
124
+ async def answer_question(file: UploadFile = File(None), text: str = Form(None), question: str = Form(...)):
125
+ try:
126
+ if file:
127
+ logger.info(f"Received document for question answering: {file.filename}")
 
 
 
 
128
  text = await extract_text_from_file(file)
129
+ elif text:
130
+ logger.info("Received manual text for question answering")
131
+ else:
132
+ raise HTTPException(status_code=400, detail="No file or text provided")
 
 
 
133
 
134
+ text_pipeline = get_text_pipeline()
135
  answer = text_pipeline(f"question: {question} context: {text}")
 
136
  return {"answer": answer[0]['generated_text']}
137
  except Exception as e:
138
  logger.error(f"Error during question answering: {e}")
139
+ raise HTTPException(status_code=500, detail="Failed to answer the question. Please try again.")
140
 
141
+ # Visual question answering endpoint
142
  @app.post("/vqa")
143
  async def visual_question_answering(file: UploadFile = File(...), question: str = Form(...)):
 
 
144
  try:
145
+ logger.info(f"Received image for visual question answering: {file.filename}")
146
+ logger.info(f"Received question: {question}")
147
  image_data = await file.read()
148
  image = Image.open(io.BytesIO(image_data))
149
 
 
151
  if image.format not in ["JPEG", "PNG"]:
152
  raise ValueError("Unsupported image format. Please upload a JPEG or PNG file.")
153
 
154
+ multimodal_pipeline = get_multimodal_pipeline()
155
  answer = multimodal_pipeline(image, question=question)
 
156
  return {"answer": answer[0]['generated_text']}
157
  except Exception as e:
158
  logger.error(f"Error during visual question answering: {e}")
159
  raise HTTPException(status_code=400, detail=str(e))
160
 
161
+ # Data visualization endpoint
162
  @app.post("/visualize")
163
+ async def visualize_data(file: UploadFile = File(...), request: str = Form(...)):
 
 
 
 
 
 
164
  try:
165
+ logger.info(f"Received Excel file for visualization: {file.filename}")
166
+ logger.info(f"Received visualization request: {request}")
167
  df = pd.read_excel(io.BytesIO(await file.read()))
168
 
169
  if "bar" in request.lower():
 
191
  plt.show()
192
  """
193
 
194
+ return {"code": code}
 
 
 
 
195
  except Exception as e:
196
  logger.error(f"Error during visualization code generation: {e}")
197
+ raise HTTPException(status_code=500, detail="Failed to generate visualization code. Please try again.")
198
 
199
+ # Helper function to extract text from files
200
  async def extract_text_from_file(file: UploadFile):
201
  try:
202
  file_content = await file.read()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  if file.filename.endswith(".pdf"):
204
+ doc = fitz.open(stream=file_content, filetype="pdf")
205
+ text = ""
206
+ for page in doc:
207
+ text += page.get_text()
208
+ return text
 
 
 
 
 
 
 
 
209
  elif file.filename.endswith(".docx"):
210
+ doc = Document(io.BytesIO(file_content))
211
+ return "\n".join([para.text for para in doc.paragraphs])
 
 
 
 
 
212
  elif file.filename.endswith(".txt"):
213
+ return file_content.decode("utf-8")
214
+ else:
215
+ raise HTTPException(status_code=400, detail="Unsupported file format")
 
 
216
  except Exception as e:
217
  logger.error(f"Error extracting text from file: {e}")
218
+ raise HTTPException(status_code=500, detail="Failed to extract text from file. Please try again.")
219
 
220
+ # Run the application
221
  if __name__ == "__main__":
222
  import uvicorn
223
  uvicorn.run(app, host="0.0.0.0", port=7860)