Hadiil committed on
Commit
0af552a
·
verified ·
1 Parent(s): 892d092

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -91
app.py CHANGED
@@ -1,17 +1,18 @@
 
1
  from fastapi import FastAPI, UploadFile, File, HTTPException, Form
2
  from fastapi.staticfiles import StaticFiles
3
- from fastapi.responses import RedirectResponse, JSONResponse
4
- from transformers import pipeline, MarianMTModel, MarianTokenizer
5
  import logging
6
  from PIL import Image
7
  import io
8
  from docx import Document
9
  import fitz # PyMuPDF
10
  import pandas as pd
11
- from tenacity import retry, stop_after_attempt, wait_exponential
12
- from functools import lru_cache
13
- from fastapi.middleware.cors import CORSMiddleware
14
- import os
15
 
16
  # Configure logging
17
  logging.basicConfig(level=logging.INFO)
@@ -19,76 +20,51 @@ logger = logging.getLogger(__name__)
19
 
20
  app = FastAPI()
21
 
22
- # Add CORS middleware
23
- app.add_middleware(
24
- CORSMiddleware,
25
- allow_origins=["*"], # Allow all origins (replace with your frontend URL in production)
26
- allow_credentials=True,
27
- allow_methods=["*"],
28
- allow_headers=["*"],
29
- )
30
-
31
  # Serve static files (HTML, CSS, JS)
32
  app.mount("/static", StaticFiles(directory="static"), name="static")
33
 
34
- # Translation models
 
 
35
  translation_models = {
36
  "fr": "Helsinki-NLP/opus-mt-en-fr",
37
  "es": "Helsinki-NLP/opus-mt-en-es",
38
  "de": "Helsinki-NLP/opus-mt-en-de"
39
  }
40
 
41
- # Retry logic for model loading
42
- @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
43
- def load_model_with_retry(model_name, task, use_fast=True):
44
- logger.info(f"Loading model: {model_name}")
45
- return pipeline(task, model=model_name, use_fast=use_fast)
46
-
47
- # Lazy-loading pipelines
48
- @lru_cache(maxsize=1)
49
- def get_multimodal_pipeline():
50
- return load_model_with_retry("Salesforce/blip-image-captioning-base", "image-to-text")
51
-
52
- @lru_cache(maxsize=1)
53
- def get_text_pipeline():
54
- return load_model_with_retry("t5-small", "text2text-generation")
55
-
56
- @lru_cache(maxsize=3)
57
- def get_translation_pipeline(target_language):
58
- model_name = translation_models.get(target_language, "Helsinki-NLP/opus-mt-en-de")
59
- tokenizer = MarianTokenizer.from_pretrained(model_name)
60
- model = MarianMTModel.from_pretrained(model_name)
61
- return pipeline("translation_en_to_xx", model=model, tokenizer=tokenizer)
62
-
63
- # Root endpoint
64
  @app.get("/")
65
  def read_root():
66
  return RedirectResponse(url="/static/index.html")
67
 
68
- # Summarize text endpoint
69
  @app.post("/summarize")
70
- async def summarize_text(file: UploadFile = File(None), text: str = Form(None)):
71
- try:
72
- if file:
73
- logger.info(f"Received document for summarization: {file.filename}")
 
 
 
74
  text = await extract_text_from_file(file)
75
- elif text:
76
- logger.info("Received manual text for summarization")
77
- else:
78
- raise HTTPException(status_code=400, detail="No file or text provided")
 
 
 
79
 
80
- text_pipeline = get_text_pipeline()
81
  summary = text_pipeline(f"summarize: {text}", max_length=100)
 
82
  return {"summary": summary[0]['generated_text']}
83
  except Exception as e:
84
  logger.error(f"Error during summarization: {e}")
85
- raise HTTPException(status_code=500, detail="Failed to summarize text. Please try again.")
86
 
87
- # Image captioning endpoint
88
  @app.post("/caption")
89
  async def caption_image(file: UploadFile = File(...)):
 
90
  try:
91
- logger.info(f"Received image for captioning: {file.filename}")
92
  image_data = await file.read()
93
  image = Image.open(io.BytesIO(image_data))
94
 
@@ -96,55 +72,71 @@ async def caption_image(file: UploadFile = File(...)):
96
  if image.format not in ["JPEG", "PNG"]:
97
  raise ValueError("Unsupported image format. Please upload a JPEG or PNG file.")
98
 
99
- multimodal_pipeline = get_multimodal_pipeline()
100
  caption = multimodal_pipeline(image)
 
101
  return {"caption": caption[0]['generated_text']}
102
  except Exception as e:
103
  logger.error(f"Error during image captioning: {e}")
104
  raise HTTPException(status_code=400, detail=str(e))
105
 
106
- # Translation endpoint
107
  @app.post("/translate")
108
- async def translate_document(file: UploadFile = File(None), text: str = Form(None), target_language: str = Form(...)):
 
 
 
 
 
 
109
  try:
110
- if file:
111
- logger.info(f"Received document for translation: {file.filename}")
112
- text = await extract_text_from_file(file)
113
- elif not text:
114
- raise HTTPException(status_code=400, detail="No text or file provided")
115
 
116
- translation_pipeline = get_translation_pipeline(target_language)
117
- translated = translation_pipeline(text)
118
- return {"translated_text": translated[0]['translation_text']}
 
 
 
 
 
 
 
 
 
119
  except Exception as e:
120
- logger.error(f"Error during translation: {e}")
121
- raise HTTPException(status_code=500, detail="Failed to translate text. Please try again.")
122
 
123
- # Question answering endpoint
124
  @app.post("/answer")
125
- async def answer_question(file: UploadFile = File(None), text: str = Form(None), question: str = Form(...)):
126
- try:
127
- if file:
128
- logger.info(f"Received document for question answering: {file.filename}")
 
 
 
 
129
  text = await extract_text_from_file(file)
130
- elif text:
131
- logger.info("Received manual text for question answering")
132
- else:
133
- raise HTTPException(status_code=400, detail="No file or text provided")
 
 
 
134
 
135
- text_pipeline = get_text_pipeline()
136
  answer = text_pipeline(f"question: {question} context: {text}")
 
137
  return {"answer": answer[0]['generated_text']}
138
  except Exception as e:
139
  logger.error(f"Error during question answering: {e}")
140
- raise HTTPException(status_code=500, detail="Failed to answer the question. Please try again.")
141
 
142
- # Visual question answering endpoint
143
  @app.post("/vqa")
144
  async def visual_question_answering(file: UploadFile = File(...), question: str = Form(...)):
 
 
145
  try:
146
- logger.info(f"Received image for visual question answering: {file.filename}")
147
- logger.info(f"Received question: {question}")
148
  image_data = await file.read()
149
  image = Image.open(io.BytesIO(image_data))
150
 
@@ -152,19 +144,22 @@ async def visual_question_answering(file: UploadFile = File(...), question: str
152
  if image.format not in ["JPEG", "PNG"]:
153
  raise ValueError("Unsupported image format. Please upload a JPEG or PNG file.")
154
 
155
- multimodal_pipeline = get_multimodal_pipeline()
156
  answer = multimodal_pipeline(image, question=question)
 
157
  return {"answer": answer[0]['generated_text']}
158
  except Exception as e:
159
  logger.error(f"Error during visual question answering: {e}")
160
  raise HTTPException(status_code=400, detail=str(e))
161
 
162
- # Data visualization endpoint
163
  @app.post("/visualize")
164
- async def visualize_data(file: UploadFile = File(...), request: str = Form(...)):
 
 
 
 
 
 
165
  try:
166
- logger.info(f"Received Excel file for visualization: {file.filename}")
167
- logger.info(f"Received visualization request: {request}")
168
  df = pd.read_excel(io.BytesIO(await file.read()))
169
 
170
  if "bar" in request.lower():
@@ -192,12 +187,15 @@ sns.pairplot(df)
192
  plt.show()
193
  """
194
 
195
- return {"code": code}
 
 
 
 
196
  except Exception as e:
197
  logger.error(f"Error during visualization code generation: {e}")
198
- raise HTTPException(status_code=500, detail="Failed to generate visualization code. Please try again.")
199
 
200
- # Helper function to extract text from files
201
  async def extract_text_from_file(file: UploadFile):
202
  try:
203
  file_content = await file.read()
@@ -209,16 +207,16 @@ async def extract_text_from_file(file: UploadFile):
209
  return text
210
  elif file.filename.endswith(".docx"):
211
  doc = Document(io.BytesIO(file_content))
212
- return "\n".join([para.text for para in doc.paragraphs])
 
213
  elif file.filename.endswith(".txt"):
214
  return file_content.decode("utf-8")
215
  else:
216
- raise HTTPException(status_code=400, detail="Unsupported file format")
217
  except Exception as e:
218
  logger.error(f"Error extracting text from file: {e}")
219
- raise HTTPException(status_code=500, detail="Failed to extract text from file. Please try again.")
220
 
221
- # Run the application
222
  if __name__ == "__main__":
223
  import uvicorn
224
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ import os
2
  from fastapi import FastAPI, UploadFile, File, HTTPException, Form
3
  from fastapi.staticfiles import StaticFiles
4
+ from fastapi.responses import RedirectResponse
5
+ from transformers import pipeline
6
  import logging
7
  from PIL import Image
8
  import io
9
  from docx import Document
10
  import fitz # PyMuPDF
11
  import pandas as pd
12
+ import matplotlib.pyplot as plt
13
+ import seaborn as sns
14
+ import uuid
15
+ from transformers import MarianMTModel, MarianTokenizer
16
 
17
  # Configure logging
18
  logging.basicConfig(level=logging.INFO)
 
20
 
21
  app = FastAPI()
22
 
 
 
 
 
 
 
 
 
 
23
  # Serve static files (HTML, CSS, JS)
24
  app.mount("/static", StaticFiles(directory="static"), name="static")
25
 
26
+ # Load models
27
+ multimodal_pipeline = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base", use_fast=True)
28
+ text_pipeline = pipeline("text2text-generation", model="t5-small", use_fast=True)
29
  translation_models = {
30
  "fr": "Helsinki-NLP/opus-mt-en-fr",
31
  "es": "Helsinki-NLP/opus-mt-en-es",
32
  "de": "Helsinki-NLP/opus-mt-en-de"
33
  }
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  @app.get("/")
36
  def read_root():
37
  return RedirectResponse(url="/static/index.html")
38
 
 
39
  @app.post("/summarize")
40
+ async def summarize_text(
41
+ file: UploadFile = File(None),
42
+ text: str = Form(None)
43
+ ):
44
+ if file:
45
+ logger.info(f"Received document for summarization: {file.filename}")
46
+ try:
47
  text = await extract_text_from_file(file)
48
+ except Exception as e:
49
+ logger.error(f"Error extracting text from file: {e}")
50
+ raise HTTPException(status_code=400, detail=str(e))
51
+ elif text:
52
+ logger.info("Received manual text for summarization")
53
+ else:
54
+ raise HTTPException(status_code=400, detail="No file or text provided")
55
 
56
+ try:
57
  summary = text_pipeline(f"summarize: {text}", max_length=100)
58
+ logger.info(f"Generated summary: {summary[0]['generated_text']}")
59
  return {"summary": summary[0]['generated_text']}
60
  except Exception as e:
61
  logger.error(f"Error during summarization: {e}")
62
+ raise HTTPException(status_code=500, detail=str(e))
63
 
 
64
  @app.post("/caption")
65
  async def caption_image(file: UploadFile = File(...)):
66
+ logger.info(f"Received image for captioning: {file.filename}")
67
  try:
 
68
  image_data = await file.read()
69
  image = Image.open(io.BytesIO(image_data))
70
 
 
72
  if image.format not in ["JPEG", "PNG"]:
73
  raise ValueError("Unsupported image format. Please upload a JPEG or PNG file.")
74
 
 
75
  caption = multimodal_pipeline(image)
76
+ logger.info(f"Generated caption: {caption[0]['generated_text']}")
77
  return {"caption": caption[0]['generated_text']}
78
  except Exception as e:
79
  logger.error(f"Error during image captioning: {e}")
80
  raise HTTPException(status_code=400, detail=str(e))
81
 
 
82
  @app.post("/translate")
83
+ async def translate_document(
84
+ file: UploadFile = File(...),
85
+ target_language: str = Form(...)
86
+ ):
87
+ logger.info(f"Received document for translation: {file.filename}")
88
+ logger.info(f"Target language: {target_language}")
89
+
90
  try:
91
+ text = await extract_text_from_file(file)
 
 
 
 
92
 
93
+ if target_language in translation_models:
94
+ model_name = translation_models[target_language]
95
+ else:
96
+ model_name = "Helsinki-NLP/opus-mt-en-de" # Default to German
97
+
98
+ tokenizer = MarianTokenizer.from_pretrained(model_name)
99
+ model = MarianMTModel.from_pretrained(model_name)
100
+
101
+ translated = model.generate(**tokenizer(text, return_tensors="pt", truncation=True))
102
+ translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)
103
+
104
+ return {"translated_text": translated_text}
105
  except Exception as e:
106
+ logger.error(f"Error during document translation: {e}")
107
+ raise HTTPException(status_code=500, detail=str(e))
108
 
 
109
  @app.post("/answer")
110
+ async def answer_question(
111
+ file: UploadFile = File(None),
112
+ text: str = Form(None),
113
+ question: str = Form(...)
114
+ ):
115
+ if file:
116
+ logger.info(f"Received document for question answering: {file.filename}")
117
+ try:
118
  text = await extract_text_from_file(file)
119
+ except Exception as e:
120
+ logger.error(f"Error extracting text from file: {e}")
121
+ raise HTTPException(status_code=400, detail=str(e))
122
+ elif text:
123
+ logger.info("Received manual text for question answering")
124
+ else:
125
+ raise HTTPException(status_code=400, detail="No file or text provided")
126
 
127
+ try:
128
  answer = text_pipeline(f"question: {question} context: {text}")
129
+ logger.info(f"Generated answer: {answer[0]['generated_text']}")
130
  return {"answer": answer[0]['generated_text']}
131
  except Exception as e:
132
  logger.error(f"Error during question answering: {e}")
133
+ raise HTTPException(status_code=500, detail=str(e))
134
 
 
135
  @app.post("/vqa")
136
  async def visual_question_answering(file: UploadFile = File(...), question: str = Form(...)):
137
+ logger.info(f"Received image for visual question answering: {file.filename}")
138
+ logger.info(f"Received question: {question}")
139
  try:
 
 
140
  image_data = await file.read()
141
  image = Image.open(io.BytesIO(image_data))
142
 
 
144
  if image.format not in ["JPEG", "PNG"]:
145
  raise ValueError("Unsupported image format. Please upload a JPEG or PNG file.")
146
 
 
147
  answer = multimodal_pipeline(image, question=question)
148
+ logger.info(f"Generated answer: {answer[0]['generated_text']}")
149
  return {"answer": answer[0]['generated_text']}
150
  except Exception as e:
151
  logger.error(f"Error during visual question answering: {e}")
152
  raise HTTPException(status_code=400, detail=str(e))
153
 
 
154
  @app.post("/visualize")
155
+ async def visualize_data(
156
+ file: UploadFile = File(...),
157
+ request: str = Form(...)
158
+ ):
159
+ logger.info(f"Received Excel file for visualization: {file.filename}")
160
+ logger.info(f"Received visualization request: {request}")
161
+
162
  try:
 
 
163
  df = pd.read_excel(io.BytesIO(await file.read()))
164
 
165
  if "bar" in request.lower():
 
187
  plt.show()
188
  """
189
 
190
+ code_filename = f"visualization_{uuid.uuid4()}.py"
191
+ with open(code_filename, "w") as f:
192
+ f.write(code)
193
+
194
+ return {"code": code, "filename": code_filename}
195
  except Exception as e:
196
  logger.error(f"Error during visualization code generation: {e}")
197
+ raise HTTPException(status_code=500, detail=str(e))
198
 
 
199
  async def extract_text_from_file(file: UploadFile):
200
  try:
201
  file_content = await file.read()
 
207
  return text
208
  elif file.filename.endswith(".docx"):
209
  doc = Document(io.BytesIO(file_content))
210
+ text = "\n".join([para.text for para in doc.paragraphs])
211
+ return text
212
  elif file.filename.endswith(".txt"):
213
  return file_content.decode("utf-8")
214
  else:
215
+ raise ValueError("Unsupported file format. Please upload a PDF, DOCX, or TXT file.")
216
  except Exception as e:
217
  logger.error(f"Error extracting text from file: {e}")
218
+ raise HTTPException(status_code=400, detail=str(e))
219
 
 
220
  if __name__ == "__main__":
221
  import uvicorn
222
  uvicorn.run(app, host="0.0.0.0", port=7860)