Hadiil committed on
Commit
29ec3ba
·
verified ·
1 Parent(s): 069103b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +147 -83
app.py CHANGED
@@ -1,15 +1,18 @@
 
1
  from fastapi import FastAPI, UploadFile, File, HTTPException, Form
2
  from fastapi.staticfiles import StaticFiles
3
  from fastapi.responses import RedirectResponse, JSONResponse
4
- from transformers import pipeline, MarianMTModel, MarianTokenizer
5
  import logging
6
  from PIL import Image
7
  import io
8
  from docx import Document
9
  import fitz # PyMuPDF
10
  import pandas as pd
11
- from tenacity import retry, stop_after_attempt, wait_exponential
12
- from functools import lru_cache
 
 
13
  from fastapi.middleware.cors import CORSMiddleware
14
 
15
  # Configure logging
@@ -30,119 +33,148 @@ app.add_middleware(
30
  # Serve static files (HTML, CSS, JS)
31
  app.mount("/static", StaticFiles(directory="static"), name="static")
32
 
33
- # Translation models
 
 
34
  translation_models = {
35
  "fr": "Helsinki-NLP/opus-mt-en-fr",
36
  "es": "Helsinki-NLP/opus-mt-en-es",
37
  "de": "Helsinki-NLP/opus-mt-en-de"
38
  }
39
 
40
- # Retry logic for model loading
41
- @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
42
- def load_model_with_retry(model_name, task, use_fast=True):
43
- logger.info(f"Loading model: {model_name}")
44
- return pipeline(task, model=model_name, use_fast=use_fast)
45
-
46
- # Lazy-loading pipelines
47
- @lru_cache(maxsize=1)
48
- def get_multimodal_pipeline():
49
- return load_model_with_retry("Salesforce/blip-image-captioning-base", "image-to-text")
50
-
51
- @lru_cache(maxsize=1)
52
- def get_text_pipeline():
53
- return load_model_with_retry("t5-small", "text2text-generation")
54
-
55
- @lru_cache(maxsize=3)
56
- def get_translation_pipeline(target_language):
57
- model_name = translation_models.get(target_language, "Helsinki-NLP/opus-mt-en-de")
58
- tokenizer = MarianTokenizer.from_pretrained(model_name)
59
- model = MarianMTModel.from_pretrained(model_name)
60
- return pipeline("translation_en_to_xx", model=model, tokenizer=tokenizer)
61
-
62
- # Root endpoint
63
  @app.get("/")
64
  def read_root():
65
  return RedirectResponse(url="/static/index.html")
66
 
67
- # Summarize text endpoint
68
  @app.post("/summarize")
69
- async def summarize_text(file: UploadFile = File(None), text: str = Form(None)):
70
- try:
71
- if file:
 
 
 
 
 
 
72
  text = await extract_text_from_file(file)
73
- elif not text:
74
- raise HTTPException(status_code=400, detail="No text or file provided")
 
 
 
 
 
 
75
 
76
- text_pipeline = get_text_pipeline()
77
  summary = text_pipeline(f"summarize: {text}", max_length=100)
 
78
  return {"summary": summary[0]['generated_text']}
79
  except Exception as e:
80
- logger.error(f"Error in summarization: {e}")
81
- raise HTTPException(status_code=500, detail="Failed to summarize text. Please try again.")
82
 
83
- # Image captioning endpoint
84
  @app.post("/caption")
85
  async def caption_image(file: UploadFile = File(...)):
 
86
  try:
87
  image_data = await file.read()
88
  image = Image.open(io.BytesIO(image_data))
89
- multimodal_pipeline = get_multimodal_pipeline()
 
 
 
 
90
  caption = multimodal_pipeline(image)
 
91
  return {"caption": caption[0]['generated_text']}
92
  except Exception as e:
93
- logger.error(f"Error in image captioning: {e}")
94
- raise HTTPException(status_code=500, detail="Failed to generate caption. Please try again.")
95
 
96
- # Translation endpoint
97
  @app.post("/translate")
98
- async def translate_document(file: UploadFile = File(None), text: str = Form(None), target_language: str = Form(...)):
 
 
 
 
 
 
99
  try:
100
- if file:
101
- text = await extract_text_from_file(file)
102
- elif not text:
103
- raise HTTPException(status_code=400, detail="No text or file provided")
 
 
 
 
 
 
 
 
104
 
105
- translation_pipeline = get_translation_pipeline(target_language)
106
- translated = translation_pipeline(text)
107
- return {"translated_text": translated[0]['translation_text']}
108
  except Exception as e:
109
- logger.error(f"Error in translation: {e}")
110
- raise HTTPException(status_code=500, detail="Failed to translate text. Please try again.")
111
 
112
- # Question answering endpoint
113
  @app.post("/answer")
114
- async def answer_question(file: UploadFile = File(None), text: str = Form(None), question: str = Form(...)):
115
- try:
116
- if file:
 
 
 
 
 
117
  text = await extract_text_from_file(file)
118
- elif not text:
119
- raise HTTPException(status_code=400, detail="No text or file provided")
 
 
 
 
 
120
 
121
- text_pipeline = get_text_pipeline()
122
  answer = text_pipeline(f"question: {question} context: {text}")
 
123
  return {"answer": answer[0]['generated_text']}
124
  except Exception as e:
125
- logger.error(f"Error in question answering: {e}")
126
- raise HTTPException(status_code=500, detail="Failed to answer the question. Please try again.")
127
 
128
- # Visual question answering endpoint
129
  @app.post("/vqa")
130
  async def visual_question_answering(file: UploadFile = File(...), question: str = Form(...)):
 
 
131
  try:
132
  image_data = await file.read()
133
  image = Image.open(io.BytesIO(image_data))
134
- multimodal_pipeline = get_multimodal_pipeline()
 
 
 
 
135
  answer = multimodal_pipeline(image, question=question)
 
136
  return {"answer": answer[0]['generated_text']}
137
  except Exception as e:
138
- logger.error(f"Error in visual question answering: {e}")
139
- raise HTTPException(status_code=500, detail="Failed to answer the question. Please try again.")
140
 
141
- # Data visualization endpoint
142
  @app.post("/visualize")
143
- async def visualize_data(file: UploadFile = File(...), request: str = Form(...)):
 
 
 
 
 
 
144
  try:
145
  df = pd.read_excel(io.BytesIO(await file.read()))
 
146
  if "bar" in request.lower():
147
  code = f"""
148
  import matplotlib.pyplot as plt
@@ -167,33 +199,65 @@ import seaborn as sns
167
  sns.pairplot(df)
168
  plt.show()
169
  """
170
- return {"code": code}
 
 
 
 
 
171
  except Exception as e:
172
- logger.error(f"Error in data visualization: {e}")
173
- raise HTTPException(status_code=500, detail="Failed to generate visualization code. Please try again.")
174
 
175
- # Helper function to extract text from files
176
  async def extract_text_from_file(file: UploadFile):
177
  try:
178
  file_content = await file.read()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  if file.filename.endswith(".pdf"):
180
- doc = fitz.open(stream=file_content, filetype="pdf")
181
- text = ""
182
- for page in doc:
183
- text += page.get_text()
184
- return text
 
 
 
 
 
 
 
 
185
  elif file.filename.endswith(".docx"):
186
- doc = Document(io.BytesIO(file_content))
187
- return "\n".join([para.text for para in doc.paragraphs])
 
 
 
 
 
188
  elif file.filename.endswith(".txt"):
189
- return file_content.decode("utf-8")
190
- else:
191
- raise HTTPException(status_code=400, detail="Unsupported file format")
 
 
192
  except Exception as e:
193
  logger.error(f"Error extracting text from file: {e}")
194
- raise HTTPException(status_code=500, detail="Failed to extract text from file. Please try again.")
195
 
196
- # Run the application
197
  if __name__ == "__main__":
198
  import uvicorn
199
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ import os
2
  from fastapi import FastAPI, UploadFile, File, HTTPException, Form
3
  from fastapi.staticfiles import StaticFiles
4
  from fastapi.responses import RedirectResponse, JSONResponse
5
+ from transformers import pipeline
6
  import logging
7
  from PIL import Image
8
  import io
9
  from docx import Document
10
  import fitz # PyMuPDF
11
  import pandas as pd
12
+ import matplotlib.pyplot as plt
13
+ import seaborn as sns
14
+ import uuid
15
+ from transformers import MarianMTModel, MarianTokenizer
16
  from fastapi.middleware.cors import CORSMiddleware
17
 
18
  # Configure logging
 
33
  # Serve static files (HTML, CSS, JS)
34
  app.mount("/static", StaticFiles(directory="static"), name="static")
35
 
# Load models
# Both pipelines are built at import time, so they exist before any request
# handler below references them.
multimodal_pipeline = pipeline(
    "image-to-text",
    model="Salesforce/blip-image-captioning-base",
    use_fast=True,
)
text_pipeline = pipeline(
    "text2text-generation",
    model="t5-small",
    use_fast=True,
)

# Maps a target-language code to the Marian checkpoint used by /translate.
translation_models = dict(
    fr="Helsinki-NLP/opus-mt-en-fr",
    es="Helsinki-NLP/opus-mt-en-es",
    de="Helsinki-NLP/opus-mt-en-de",
)
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
@app.get("/")
def read_root():
    """Redirect requests for the site root to the static index page."""
    landing_page = "/static/index.html"
    return RedirectResponse(url=landing_page)
48
 
 
@app.post("/summarize")
async def summarize_text(
    file: UploadFile = File(None),
    text: str = Form(None)
):
    """Summarize an uploaded document or a block of submitted text.

    Args:
        file: Optional upload; when present its extracted text replaces ``text``.
        text: Optional raw text, used only when no file is supplied.

    Returns:
        A dict of the form ``{"summary": <generated text>}``.

    Raises:
        HTTPException: 400 when neither input is given, when text extraction
            fails, or when the file yields no text; 500 when the
            summarization pipeline itself fails.
    """
    # Log only the shape of the request: the submitted text may be large or
    # sensitive and should not be copied verbatim into the log.
    logger.info(
        f"Received summarization request: file={'yes' if file else 'no'}, "
        f"text_chars={len(text) if text else 0}"
    )

    if file:
        logger.info(f"Received document for summarization: {file.filename}")
        try:
            text = await extract_text_from_file(file)
        except HTTPException:
            # The helper already chose a status code and detail message;
            # re-wrapping it through str(e) would mangle that detail.
            raise
        except Exception as e:
            logger.error(f"Error extracting text from file: {e}")
            raise HTTPException(status_code=400, detail=str(e))
        if not text:
            logger.error("No text could be extracted from the uploaded file")
            raise HTTPException(
                status_code=400,
                detail="No text could be extracted from the uploaded file",
            )
    elif text:
        logger.info("Received manual text for summarization")
    else:
        logger.error("No file or text provided")
        raise HTTPException(status_code=400, detail="No file or text provided")

    try:
        summary = text_pipeline(f"summarize: {text}", max_length=100)
        summary_text = summary[0]['generated_text']
        logger.info(f"Generated summary: {summary_text}")
        return {"summary": summary_text}
    except Exception as e:
        logger.error(f"Error during summarization: {e}")
        raise HTTPException(status_code=500, detail=str(e))
76
 
 
@app.post("/caption")
async def caption_image(file: UploadFile = File(...)):
    """Generate a caption for an uploaded JPEG or PNG image.

    Args:
        file: The uploaded image.

    Returns:
        A dict of the form ``{"caption": <generated text>}``.

    Raises:
        HTTPException: 400 when the upload cannot be read as a JPEG/PNG image;
            500 when the captioning pipeline fails.
    """
    logger.info(f"Received image for captioning: {file.filename}")

    # Stage 1: problems with the upload itself are the client's fault (400).
    try:
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data))

        # Validate image format
        if image.format not in ["JPEG", "PNG"]:
            raise ValueError("Unsupported image format. Please upload a JPEG or PNG file.")
    except Exception as e:
        logger.error(f"Error during image captioning: {e}")
        raise HTTPException(status_code=400, detail=str(e))

    # Stage 2: a failure inside the model is a server error (500), matching
    # how the text endpoints in this file report inference failures.
    try:
        caption = multimodal_pipeline(image)
        caption_text = caption[0]['generated_text']
        logger.info(f"Generated caption: {caption_text}")
        return {"caption": caption_text}
    except Exception as e:
        logger.error(f"Error during image captioning: {e}")
        raise HTTPException(status_code=500, detail=str(e))
94
 
 
# Loaded (tokenizer, model) pairs keyed by checkpoint name, so a checkpoint is
# loaded once per process instead of once per request.
_translation_cache = {}


def _get_translation_components(model_name):
    """Return the (tokenizer, model) pair for *model_name*, loading it on first use."""
    if model_name not in _translation_cache:
        logger.info(f"Loading translation model: {model_name}")
        _translation_cache[model_name] = (
            MarianTokenizer.from_pretrained(model_name),
            MarianMTModel.from_pretrained(model_name),
        )
    return _translation_cache[model_name]


@app.post("/translate")
async def translate_document(
    file: UploadFile = File(...),
    target_language: str = Form(...)
):
    """Translate the text of an uploaded document into the target language.

    Args:
        file: The uploaded document; its text is obtained through
            ``extract_text_from_file``.
        target_language: A key of ``translation_models``; any other value
            falls back to the German checkpoint.

    Returns:
        A dict of the form ``{"translated_text": <text>}``.

    Raises:
        HTTPException: whatever ``extract_text_from_file`` raised, passed
            through unchanged; 500 for any other failure.
    """
    logger.info(f"Received document for translation: {file.filename}")
    logger.info(f"Target language: {target_language}")

    try:
        text = await extract_text_from_file(file)

        # Default to German
        model_name = translation_models.get(target_language, "Helsinki-NLP/opus-mt-en-de")
        tokenizer, model = _get_translation_components(model_name)

        # truncation=True caps the input at the tokenizer's maximum length,
        # and only the first generated sequence is decoded below.
        encoded = tokenizer(text, return_tensors="pt", truncation=True)
        translated = model.generate(**encoded)
        translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)

        return {"translated_text": translated_text}
    except HTTPException:
        # Keep the status code chosen by the extraction helper instead of
        # letting the broad handler below turn it into a 500.
        raise
    except Exception as e:
        logger.error(f"Error during document translation: {e}")
        raise HTTPException(status_code=500, detail=str(e))
121
 
 
@app.post("/answer")
async def answer_question(
    file: UploadFile = File(None),
    text: str = Form(None),
    question: str = Form(...)
):
    """Answer a question using an uploaded document or submitted text as context.

    Args:
        file: Optional upload; when present its extracted text replaces ``text``.
        text: Optional raw context, used only when no file is supplied.
        question: The question to answer.

    Returns:
        A dict of the form ``{"answer": <generated text>}``.

    Raises:
        HTTPException: 400 when neither input is given, when text extraction
            fails, or when the file yields no text; 500 when the text
            pipeline fails.
    """
    if file:
        logger.info(f"Received document for question answering: {file.filename}")
        try:
            text = await extract_text_from_file(file)
        except HTTPException:
            # The helper already chose a status code and detail message;
            # re-wrapping it through str(e) would mangle that detail.
            raise
        except Exception as e:
            logger.error(f"Error extracting text from file: {e}")
            raise HTTPException(status_code=400, detail=str(e))
        if not text:
            logger.error("No text could be extracted from the uploaded file")
            raise HTTPException(
                status_code=400,
                detail="No text could be extracted from the uploaded file",
            )
    elif text:
        logger.info("Received manual text for question answering")
    else:
        raise HTTPException(status_code=400, detail="No file or text provided")

    try:
        answer = text_pipeline(f"question: {question} context: {text}")
        answer_text = answer[0]['generated_text']
        logger.info(f"Generated answer: {answer_text}")
        return {"answer": answer_text}
    except Exception as e:
        logger.error(f"Error during question answering: {e}")
        raise HTTPException(status_code=500, detail=str(e))
147
 
 
@app.post("/vqa")
async def visual_question_answering(file: UploadFile = File(...), question: str = Form(...)):
    """Answer a question about an uploaded JPEG or PNG image.

    Args:
        file: The uploaded image.
        question: The question to ask about the image.

    Returns:
        A dict of the form ``{"answer": <generated text>}``.

    Raises:
        HTTPException: 400 when the upload cannot be read as a JPEG/PNG image;
            500 when the multimodal pipeline fails.
    """
    logger.info(f"Received image for visual question answering: {file.filename}")
    logger.info(f"Received question: {question}")

    # Stage 1: problems with the upload itself are the client's fault (400).
    try:
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data))

        # Validate image format
        if image.format not in ["JPEG", "PNG"]:
            raise ValueError("Unsupported image format. Please upload a JPEG or PNG file.")
    except Exception as e:
        logger.error(f"Error during visual question answering: {e}")
        raise HTTPException(status_code=400, detail=str(e))

    # Stage 2: a failure inside the model is a server error (500).
    try:
        # NOTE(review): multimodal_pipeline is built with the "image-to-text"
        # task; confirm that it accepts a `question` keyword in the installed
        # transformers version - TODO verify.
        answer = multimodal_pipeline(image, question=question)
        answer_text = answer[0]['generated_text']
        logger.info(f"Generated answer: {answer_text}")
        return {"answer": answer_text}
    except Exception as e:
        logger.error(f"Error during visual question answering: {e}")
        raise HTTPException(status_code=500, detail=str(e))
166
 
 
167
  @app.post("/visualize")
168
+ async def visualize_data(
169
+ file: UploadFile = File(...),
170
+ request: str = Form(...)
171
+ ):
172
+ logger.info(f"Received Excel file for visualization: {file.filename}")
173
+ logger.info(f"Received visualization request: {request}")
174
+
175
  try:
176
  df = pd.read_excel(io.BytesIO(await file.read()))
177
+
178
  if "bar" in request.lower():
179
  code = f"""
180
  import matplotlib.pyplot as plt
 
199
  sns.pairplot(df)
200
  plt.show()
201
  """
202
+
203
+ code_filename = f"visualization_{uuid.uuid4()}.py"
204
+ with open(code_filename, "w") as f:
205
+ f.write(code)
206
+
207
+ return {"code": code, "filename": code_filename}
208
  except Exception as e:
209
+ logger.error(f"Error during visualization code generation: {e}")
210
+ raise HTTPException(status_code=500, detail=str(e))
211
 
 
async def extract_text_from_file(file: UploadFile):
    """Read an uploaded PDF, DOCX, or TXT file and return its text.

    The file type is chosen from the filename extension, compared
    case-insensitively for both validation and dispatch.

    Args:
        file: The uploaded file.

    Returns:
        The extracted text as a string.

    Raises:
        HTTPException: 400 when the upload is empty, larger than 10MB, has an
            unsupported extension, or cannot be parsed.
    """
    try:
        file_content = await file.read()
        if not file_content:
            logger.error("Uploaded file is empty.")
            raise ValueError("Uploaded file is empty.")

        # Check file size (e.g., limit to 10MB)
        if len(file_content) > 10 * 1024 * 1024:  # 10MB
            logger.error("File size exceeds the limit (10MB).")
            raise ValueError("File size exceeds the limit (10MB).")

        # Normalize once so validation and dispatch agree: previously the
        # check lower-cased the name but the branches did not, so "REPORT.PDF"
        # passed validation and then matched no branch, returning None.
        filename = (file.filename or "").lower()

        # Check file type
        if not filename.endswith((".pdf", ".docx", ".txt")):
            logger.error(f"Unsupported files format: {file.filename}")
            raise ValueError("Unsupported file format. Please upload a PDF, DOCX, or TXT file.")

        if filename.endswith(".pdf"):
            try:
                # Log the size only; raw upload bytes do not belong in the log.
                logger.info(f"Reading PDF file ({len(file_content)} bytes)")

                doc = fitz.open(stream=file_content, filetype="pdf")
                try:
                    return "".join(page.get_text() for page in doc)
                finally:
                    # Release the document even if a page fails to extract.
                    doc.close()
            except Exception as e:
                logger.error(f"Error reading PDF file: {e}")
                raise ValueError("Failed to read PDF file. It might be corrupted or not a valid PDF.")
        elif filename.endswith(".docx"):
            try:
                doc = Document(io.BytesIO(file_content))
                return "\n".join(para.text for para in doc.paragraphs)
            except Exception as e:
                logger.error(f"Error reading DOCX file: {e}")
                raise ValueError("Failed to read DOCX file. It might be corrupted or not a valid DOCX.")
        else:
            # Only ".txt" remains after the extension check above.
            try:
                return file_content.decode("utf-8")
            except Exception as e:
                logger.error(f"Error reading TXT file: {e}")
                raise ValueError("Failed to read TXT file. It might be corrupted or not a valid TXT.")
    except Exception as e:
        logger.error(f"Error extracting text from file: {e}")
        raise HTTPException(status_code=400, detail=str(e))
260
 
 
if __name__ == "__main__":
    import uvicorn

    # Serve the app on all interfaces at port 7860 when run as a script.
    # The closing parenthesis of this call was missing, which is a syntax error.
    uvicorn.run(app, host="0.0.0.0", port=7860)