Hadiil committed on
Commit
ebb73a9
·
verified ·
1 Parent(s): 9f1ee9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -67
app.py CHANGED
@@ -1,15 +1,15 @@
1
  from fastapi import FastAPI, UploadFile, File, HTTPException, Form
2
  from fastapi.staticfiles import StaticFiles
3
  from fastapi.responses import RedirectResponse, JSONResponse
4
- from transformers import pipeline
5
  import logging
6
  from PIL import Image
7
  import io
8
  from docx import Document
9
  import fitz # PyMuPDF
10
  import pandas as pd
11
- import uuid
12
- from transformers import MarianMTModel, MarianTokenizer
13
  from fastapi.middleware.cors import CORSMiddleware
14
 
15
  # Configure logging
@@ -30,72 +30,121 @@ app.add_middleware(
30
  # Serve static files (HTML, CSS, JS)
31
  app.mount("/static", StaticFiles(directory="static"), name="static")
32
 
33
- # Load models
34
- multimodal_pipeline = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base", use_fast=True)
35
- text_pipeline = pipeline("text2text-generation", model="t5-small", use_fast=True)
36
  translation_models = {
37
  "fr": "Helsinki-NLP/opus-mt-en-fr",
38
  "es": "Helsinki-NLP/opus-mt-en-es",
39
  "de": "Helsinki-NLP/opus-mt-en-de"
40
  }
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  @app.get("/")
43
  def read_root():
44
  return RedirectResponse(url="/static/index.html")
45
 
 
46
  @app.post("/summarize")
47
  async def summarize_text(file: UploadFile = File(None), text: str = Form(None)):
48
- if file:
49
- text = await extract_text_from_file(file)
50
- elif not text:
51
- raise HTTPException(status_code=400, detail="No text or files provided")
52
-
53
- summary = text_pipeline(f"summarize: {text}", max_length=100)
54
- return {"summary": summary[0]['generated_text']}
55
-
 
 
 
 
 
 
56
  @app.post("/caption")
57
  async def caption_image(file: UploadFile = File(...)):
58
- image_data = await file.read()
59
- image = Image.open(io.BytesIO(image_data))
60
- caption = multimodal_pipeline(image)
61
- return {"caption": caption[0]['generated_text']}
62
-
 
 
 
 
 
 
63
  @app.post("/translate")
64
  async def translate_document(file: UploadFile = File(None), text: str = Form(None), target_language: str = Form(...)):
65
- if file:
66
- text = await extract_text_from_file(file)
67
- elif not text:
68
- raise HTTPException(status_code=400, detail="No text or file provided")
69
-
70
- model_name = translation_models.get(target_language, "Helsinki-NLP/opus-mt-en-de")
71
- tokenizer = MarianTokenizer.from_pretrained(model_name)
72
- model = MarianMTModel.from_pretrained(model_name)
73
- translated = model.generate(**tokenizer(text, return_tensors="pt", truncation=True))
74
- translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)
75
- return {"translated_text": translated_text}
76
-
 
 
77
  @app.post("/answer")
78
  async def answer_question(file: UploadFile = File(None), text: str = Form(None), question: str = Form(...)):
79
- if file:
80
- text = await extract_text_from_file(file)
81
- elif not text:
82
- raise HTTPException(status_code=400, detail="No text or file provided")
83
-
84
- answer = text_pipeline(f"question: {question} context: {text}")
85
- return {"answer": answer[0]['generated_text']}
86
-
 
 
 
 
 
 
87
  @app.post("/vqa")
88
  async def visual_question_answering(file: UploadFile = File(...), question: str = Form(...)):
89
- image_data = await file.read()
90
- image = Image.open(io.BytesIO(image_data))
91
- answer = multimodal_pipeline(image, question=question)
92
- return {"answer": answer[0]['generated_text']}
93
-
 
 
 
 
 
 
94
  @app.post("/visualize")
95
  async def visualize_data(file: UploadFile = File(...), request: str = Form(...)):
96
- df = pd.read_excel(io.BytesIO(await file.read()))
97
- if "bar" in request.lower():
98
- code = f"""
 
99
  import matplotlib.pyplot as plt
100
  plt.bar(df['{df.columns[0]}'], df['{df.columns[1]}'])
101
  plt.xlabel('{df.columns[0]}')
@@ -103,8 +152,8 @@ plt.ylabel('{df.columns[1]}')
103
  plt.title('Bar Chart')
104
  plt.show()
105
  """
106
- elif "line" in request.lower():
107
- code = f"""
108
  import matplotlib.pyplot as plt
109
  plt.plot(df['{df.columns[0]}'], df['{df.columns[1]}'])
110
  plt.xlabel('{df.columns[0]}')
@@ -112,30 +161,39 @@ plt.ylabel('{df.columns[1]}')
112
  plt.title('Line Chart')
113
  plt.show()
114
  """
115
- else:
116
- code = f"""
117
  import seaborn as sns
118
  sns.pairplot(df)
119
  plt.show()
120
  """
121
- return {"code": code}
 
 
 
122
 
 
123
  async def extract_text_from_file(file: UploadFile):
124
- file_content = await file.read()
125
- if file.filename.endswith(".pdf"):
126
- doc = fitz.open(stream=file_content, filetype="pdf")
127
- text = ""
128
- for page in doc:
129
- text += page.get_text()
130
- return text
131
- elif file.filename.endswith(".docx"):
132
- doc = Document(io.BytesIO(file_content))
133
- return "\n".join([para.text for para in doc.paragraphs])
134
- elif file.filename.endswith(".txt"):
135
- return file_content.decode("utf-8")
136
- else:
137
- raise HTTPException(status_code=400, detail="Unsupported file format")
138
-
 
 
 
 
 
139
  if __name__ == "__main__":
140
  import uvicorn
141
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  from fastapi import FastAPI, UploadFile, File, HTTPException, Form
2
  from fastapi.staticfiles import StaticFiles
3
  from fastapi.responses import RedirectResponse, JSONResponse
4
+ from transformers import pipeline, MarianMTModel, MarianTokenizer
5
  import logging
6
  from PIL import Image
7
  import io
8
  from docx import Document
9
  import fitz # PyMuPDF
10
  import pandas as pd
11
+ from tenacity import retry, stop_after_attempt, wait_exponential
12
+ from functools import lru_cache
13
  from fastapi.middleware.cors import CORSMiddleware
14
 
15
  # Configure logging
 
30
  # Serve static files (HTML, CSS, JS)
31
  app.mount("/static", StaticFiles(directory="static"), name="static")
32
 
33
+ # Translation models
 
 
# Keys are the target-language codes looked up by get_translation_pipeline();
# values are the model identifiers handed to from_pretrained().
translation_models = dict(
    fr="Helsinki-NLP/opus-mt-en-fr",
    es="Helsinki-NLP/opus-mt-en-es",
    de="Helsinki-NLP/opus-mt-en-de",
)
39
 
40
+ # Retry logic for model loading
@retry(
    wait=wait_exponential(multiplier=1, min=4, max=10),
    stop=stop_after_attempt(3),
)
def load_model_with_retry(model_name, task, use_fast=True):
    """Create a transformers pipeline for ``task`` using ``model_name``.

    The ``retry`` decorator is configured to stop after three attempts and
    to wait with exponential backoff (min 4, max 10) between them.
    """
    logger.info(f"Loading model: {model_name}")
    loaded = pipeline(task, model=model_name, use_fast=use_fast)
    return loaded
45
+
46
+ # Lazy-loading pipelines
@lru_cache(maxsize=1)
def get_multimodal_pipeline():
    """Return the image-to-text pipeline, loading it on the first call.

    ``lru_cache`` memoizes the result, so later calls get the same object
    back instead of loading the model again.
    """
    checkpoint = "Salesforce/blip-image-captioning-base"
    return load_model_with_retry(checkpoint, "image-to-text")
50
+
@lru_cache(maxsize=1)
def get_text_pipeline():
    """Return the text2text-generation pipeline, loading it on the first call.

    ``lru_cache`` memoizes the result, so later calls get the same object
    back instead of loading the model again.
    """
    checkpoint = "t5-small"
    return load_model_with_retry(checkpoint, "text2text-generation")
54
+
def get_translation_pipeline(target_language):
    """Return a translation pipeline for ``target_language``.

    Codes that are not keys of ``translation_models`` fall back to the
    "Helsinki-NLP/opus-mt-en-de" checkpoint.

    The language code is resolved to a checkpoint name *before* the cache
    is consulted. ``target_language`` comes straight from the request form,
    so caching on it directly let every distinct unknown code load another
    copy of the fallback model and evict a valid entry; caching on the
    checkpoint name bounds the key space to the known models.
    """
    model_name = translation_models.get(target_language, "Helsinki-NLP/opus-mt-en-de")
    return _load_translation_pipeline(model_name)


@lru_cache(maxsize=3)
def _load_translation_pipeline(model_name):
    """Load, and memoize per checkpoint, the Marian translation pipeline."""
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    return pipeline("translation_en_to_xx", model=model, tokenizer=tokenizer)
61
+
62
+ # Root endpoint
@app.get("/")
def read_root():
    """Redirect requests for the site root to the static index page."""
    landing_page = "/static/index.html"
    return RedirectResponse(url=landing_page)
66
 
67
+ # Summarize text endpoint
@app.post("/summarize")
async def summarize_text(file: UploadFile = File(None), text: str = Form(None)):
    """Summarize an uploaded document or raw form text.

    Args:
        file: Optional upload. When present, the text extracted from it
            replaces ``text``.
        text: Optional raw text submitted with the form.

    Returns:
        ``{"summary": <generated text>}``.

    Raises:
        HTTPException: 400 when neither ``file`` nor ``text`` is supplied;
            any HTTPException from ``extract_text_from_file`` is passed
            through unchanged; 500 for every other failure.
    """
    try:
        if file:
            text = await extract_text_from_file(file)
        elif not text:
            raise HTTPException(status_code=400, detail="No text or file provided")

        text_pipeline = get_text_pipeline()
        summary = text_pipeline(f"summarize: {text}", max_length=100)
        return {"summary": summary[0]['generated_text']}
    except HTTPException:
        # HTTPException is an Exception subclass: without this clause the
        # 400 raised above is caught below and reported to the client as 500.
        raise
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception("Error in summarization: %s", e)
        raise HTTPException(status_code=500, detail="Failed to summarize text. Please try again.") from e
82
+
83
+ # Image captioning endpoint
@app.post("/caption")
async def caption_image(file: UploadFile = File(...)):
    """Caption an uploaded image with the image-to-text pipeline.

    Returns ``{"caption": <generated text>}``. Any failure is logged and
    reported as an HTTP 500.
    """
    try:
        raw_bytes = await file.read()
        picture = Image.open(io.BytesIO(raw_bytes))
        captioner = get_multimodal_pipeline()
        outputs = captioner(picture)
        first_result = outputs[0]
        return {"caption": first_result['generated_text']}
    except Exception as err:
        logger.error(f"Error in image captioning: {err}")
        raise HTTPException(status_code=500, detail="Failed to generate caption. Please try again.")
95
+
96
+ # Translation endpoint
@app.post("/translate")
async def translate_document(file: UploadFile = File(None), text: str = Form(None), target_language: str = Form(...)):
    """Translate an uploaded document or raw form text.

    Args:
        file: Optional upload. When present, the text extracted from it
            replaces ``text``.
        text: Optional raw text submitted with the form.
        target_language: Language code passed to
            ``get_translation_pipeline``.

    Returns:
        ``{"translated_text": <translation>}``.

    Raises:
        HTTPException: 400 when neither ``file`` nor ``text`` is supplied;
            any HTTPException from ``extract_text_from_file`` is passed
            through unchanged; 500 for every other failure.
    """
    try:
        if file:
            text = await extract_text_from_file(file)
        elif not text:
            raise HTTPException(status_code=400, detail="No text or file provided")

        translation_pipeline = get_translation_pipeline(target_language)
        translated = translation_pipeline(text)
        return {"translated_text": translated[0]['translation_text']}
    except HTTPException:
        # HTTPException is an Exception subclass: without this clause the
        # 400 raised above is caught below and reported to the client as 500.
        raise
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception("Error in translation: %s", e)
        raise HTTPException(status_code=500, detail="Failed to translate text. Please try again.") from e
111
+
112
+ # Question answering endpoint
@app.post("/answer")
async def answer_question(file: UploadFile = File(None), text: str = Form(None), question: str = Form(...)):
    """Answer a question about an uploaded document or raw form text.

    Args:
        file: Optional upload. When present, the text extracted from it
            replaces ``text`` and is used as the context.
        text: Optional raw context text submitted with the form.
        question: The question to answer.

    Returns:
        ``{"answer": <generated text>}``.

    Raises:
        HTTPException: 400 when neither ``file`` nor ``text`` is supplied;
            any HTTPException from ``extract_text_from_file`` is passed
            through unchanged; 500 for every other failure.
    """
    try:
        if file:
            text = await extract_text_from_file(file)
        elif not text:
            raise HTTPException(status_code=400, detail="No text or file provided")

        text_pipeline = get_text_pipeline()
        answer = text_pipeline(f"question: {question} context: {text}")
        return {"answer": answer[0]['generated_text']}
    except HTTPException:
        # HTTPException is an Exception subclass: without this clause the
        # 400 raised above is caught below and reported to the client as 500.
        raise
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception("Error in question answering: %s", e)
        raise HTTPException(status_code=500, detail="Failed to answer the question. Please try again.") from e
127
+
128
+ # Visual question answering endpoint
@app.post("/vqa")
async def visual_question_answering(file: UploadFile = File(...), question: str = Form(...)):
    """Answer ``question`` about an uploaded image.

    Returns ``{"answer": <generated text>}``. Any failure is logged and
    reported as an HTTP 500.
    """
    try:
        raw_bytes = await file.read()
        picture = Image.open(io.BytesIO(raw_bytes))
        vision_model = get_multimodal_pipeline()
        # NOTE(review): the cached pipeline is built for the "image-to-text"
        # task; confirm it accepts a `question` keyword. If it does not,
        # every request ends up in the except branch below.
        outputs = vision_model(picture, question=question)
        first_result = outputs[0]
        return {"answer": first_result['generated_text']}
    except Exception as err:
        logger.error(f"Error in visual question answering: {err}")
        raise HTTPException(status_code=500, detail="Failed to answer the question. Please try again.")
140
+
141
+ # Data visualization endpoint
@app.post("/visualize")
async def visualize_data(file: UploadFile = File(...), request: str = Form(...)):
    """Return a plotting snippet for an uploaded Excel sheet.

    The snippet is chosen by keyword: "bar" gives a bar chart, "line" a
    line chart, anything else a seaborn pairplot. Bar and line charts plot
    the sheet's second column against its first.

    Args:
        file: Excel workbook, read with ``pandas.read_excel``.
        request: Free-text description of the wanted chart.

    Returns:
        ``{"code": <python source>}``. The source expects a DataFrame named
        ``df`` to exist where it is run.

    Raises:
        HTTPException: 400 when a bar/line chart is requested but the sheet
            has fewer than two columns; 500 for every other failure.
    """
    try:
        df = pd.read_excel(io.BytesIO(await file.read()))
        wanted = request.lower()
        if "bar" in wanted:
            plot_call, title = "bar", "Bar Chart"
        elif "line" in wanted:
            plot_call, title = "plot", "Line Chart"
        else:
            plot_call, title = None, None

        if plot_call is None:
            # The snippet calls plt.show(), so it has to import pyplot too;
            # previously it only imported seaborn and failed with NameError.
            code = """
import matplotlib.pyplot as plt
import seaborn as sns
sns.pairplot(df)
plt.show()
"""
        else:
            if len(df.columns) < 2:
                raise HTTPException(
                    status_code=400,
                    detail="At least two columns are required for this chart.",
                )
            # repr() emits a correctly quoted/escaped literal, so a quote or
            # backslash in a header can neither break nor inject into the
            # generated source. Ordinary names render exactly as before.
            x_col = repr(str(df.columns[0]))
            y_col = repr(str(df.columns[1]))
            code = f"""
import matplotlib.pyplot as plt
plt.{plot_call}(df[{x_col}], df[{y_col}])
plt.xlabel({x_col})
plt.ylabel({y_col})
plt.title('{title}')
plt.show()
"""
        return {"code": code}
    except HTTPException:
        # Keep the 400 above from being rewritten to 500 by the handler below.
        raise
    except Exception as e:
        logger.exception("Error in data visualization: %s", e)
        raise HTTPException(status_code=500, detail="Failed to generate visualization code. Please try again.") from e
174
 
175
+ # Helper function to extract text from files
async def extract_text_from_file(file: UploadFile):
    """Extract plain text from an uploaded .pdf, .docx or .txt file.

    The format is chosen from the filename extension, compared
    case-insensitively so that e.g. ``REPORT.PDF`` is accepted.

    Args:
        file: The uploaded file.

    Returns:
        The extracted text: PDF page texts concatenated, DOCX paragraphs
        joined with newlines, or the UTF-8 decoded bytes of a .txt file.

    Raises:
        HTTPException: 400 for a missing or unsupported extension; 500 when
            reading or parsing the file fails.
    """
    try:
        file_content = await file.read()
        # filename may be None for some uploads; treat that as unsupported.
        filename = (file.filename or "").lower()

        if filename.endswith(".pdf"):
            doc = fitz.open(stream=file_content, filetype="pdf")
            try:
                return "".join(page.get_text() for page in doc)
            finally:
                # Release the PyMuPDF document even if a page fails to parse.
                doc.close()
        if filename.endswith(".docx"):
            doc = Document(io.BytesIO(file_content))
            return "\n".join(para.text for para in doc.paragraphs)
        if filename.endswith(".txt"):
            return file_content.decode("utf-8")

        raise HTTPException(status_code=400, detail="Unsupported file format")
    except HTTPException:
        # HTTPException is an Exception subclass: without this clause the
        # 400 above is caught below and reported to the client as 500.
        raise
    except Exception as e:
        logger.exception("Error extracting text from file: %s", e)
        raise HTTPException(status_code=500, detail="Failed to extract text from file. Please try again.") from e
195
+
196
+ # Run the application
if __name__ == "__main__":
    # Only start the server when this file is executed as a script.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")