Hadiil committed on
Commit
afbc76b
·
verified ·
1 Parent(s): c66450a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -50
app.py CHANGED
@@ -23,9 +23,13 @@ app = FastAPI()
23
  # Serve static files (HTML, CSS, JS)
24
  app.mount("/static", StaticFiles(directory="static"), name="static")
25
 
26
- # Load models
27
- multimodal_pipeline = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base", use_fast=True)
28
- text_pipeline = pipeline("text2text-generation", model="t5-small", use_fast=True)
 
 
 
 
29
  translation_models = {
30
  "fr": "Helsinki-NLP/opus-mt-en-fr",
31
  "es": "Helsinki-NLP/opus-mt-en-es",
@@ -34,16 +38,18 @@ translation_models = {
34
 
35
  @app.get("/")
36
  def read_root():
 
37
  return RedirectResponse(url="/static/index.html")
38
 
39
  @app.post("/summarize")
40
  async def summarize_text(
41
- file: UploadFile = File(None),
42
- text: str = Form(None)
43
  ):
44
  if file:
45
  logger.info(f"Received document for summarization: {file.filename}")
46
  try:
 
47
  text = await extract_text_from_file(file)
48
  except Exception as e:
49
  logger.error(f"Error extracting text from file: {e}")
@@ -54,6 +60,7 @@ async def summarize_text(
54
  raise HTTPException(status_code=400, detail="No file or text provided")
55
 
56
  try:
 
57
  summary = text_pipeline(f"summarize: {text}", max_length=100)
58
  logger.info(f"Generated summary: {summary[0]['generated_text']}")
59
  return {"summary": summary[0]['generated_text']}
@@ -65,56 +72,28 @@ async def summarize_text(
65
  async def caption_image(file: UploadFile = File(...)):
66
  logger.info(f"Received image for captioning: {file.filename}")
67
  try:
 
68
  image_data = await file.read()
69
  image = Image.open(io.BytesIO(image_data))
70
 
71
- # Validate image format
72
- if image.format not in ["JPEG", "PNG"]:
73
- raise ValueError("Unsupported image format. Please upload a JPEG or PNG file.")
74
-
75
  caption = multimodal_pipeline(image)
76
  logger.info(f"Generated caption: {caption[0]['generated_text']}")
77
  return {"caption": caption[0]['generated_text']}
78
  except Exception as e:
79
  logger.error(f"Error during image captioning: {e}")
80
- raise HTTPException(status_code=400, detail=str(e))
81
-
82
- @app.post("/translate")
83
- async def translate_document(
84
- file: UploadFile = File(...),
85
- target_language: str = Form(...)
86
- ):
87
- logger.info(f"Received document for translation: {file.filename}")
88
- logger.info(f"Target language: {target_language}")
89
-
90
- try:
91
- text = await extract_text_from_file(file)
92
-
93
- if target_language in translation_models:
94
- model_name = translation_models[target_language]
95
- else:
96
- model_name = "Helsinki-NLP/opus-mt-en-de" # Default to German
97
-
98
- tokenizer = MarianTokenizer.from_pretrained(model_name)
99
- model = MarianMTModel.from_pretrained(model_name)
100
-
101
- translated = model.generate(**tokenizer(text, return_tensors="pt", truncation=True))
102
- translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)
103
-
104
- return {"translated_text": translated_text}
105
- except Exception as e:
106
- logger.error(f"Error during document translation: {e}")
107
  raise HTTPException(status_code=500, detail=str(e))
108
 
109
  @app.post("/answer")
110
  async def answer_question(
111
- file: UploadFile = File(None),
112
- text: str = Form(None),
113
- question: str = Form(...)
114
  ):
115
  if file:
116
  logger.info(f"Received document for question answering: {file.filename}")
117
  try:
 
118
  text = await extract_text_from_file(file)
119
  except Exception as e:
120
  logger.error(f"Error extracting text from file: {e}")
@@ -125,6 +104,7 @@ async def answer_question(
125
  raise HTTPException(status_code=400, detail="No file or text provided")
126
 
127
  try:
 
128
  answer = text_pipeline(f"question: {question} context: {text}")
129
  logger.info(f"Generated answer: {answer[0]['generated_text']}")
130
  return {"answer": answer[0]['generated_text']}
@@ -137,19 +117,17 @@ async def visual_question_answering(file: UploadFile = File(...), question: str
137
  logger.info(f"Received image for visual question answering: {file.filename}")
138
  logger.info(f"Received question: {question}")
139
  try:
 
140
  image_data = await file.read()
141
  image = Image.open(io.BytesIO(image_data))
142
 
143
- # Validate image format
144
- if image.format not in ["JPEG", "PNG"]:
145
- raise ValueError("Unsupported image format. Please upload a JPEG or PNG file.")
146
-
147
  answer = multimodal_pipeline(image, question=question)
148
  logger.info(f"Generated answer: {answer[0]['generated_text']}")
149
  return {"answer": answer[0]['generated_text']}
150
  except Exception as e:
151
  logger.error(f"Error during visual question answering: {e}")
152
- raise HTTPException(status_code=400, detail=str(e))
153
 
154
  @app.post("/visualize")
155
  async def visualize_data(
@@ -160,8 +138,10 @@ async def visualize_data(
160
  logger.info(f"Received visualization request: {request}")
161
 
162
  try:
 
163
  df = pd.read_excel(io.BytesIO(await file.read()))
164
 
 
165
  if "bar" in request.lower():
166
  code = f"""
167
  import matplotlib.pyplot as plt
@@ -187,6 +167,7 @@ sns.pairplot(df)
187
  plt.show()
188
  """
189
 
 
190
  code_filename = f"visualization_{uuid.uuid4()}.py"
191
  with open(code_filename, "w") as f:
192
  f.write(code)
@@ -196,27 +177,56 @@ plt.show()
196
  logger.error(f"Error during visualization code generation: {e}")
197
  raise HTTPException(status_code=500, detail=str(e))
198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  async def extract_text_from_file(file: UploadFile):
200
  try:
201
- file_content = await file.read()
202
  if file.filename.endswith(".pdf"):
203
- doc = fitz.open(stream=file_content, filetype="pdf")
204
  text = ""
205
  for page in doc:
206
  text += page.get_text()
207
  return text
208
  elif file.filename.endswith(".docx"):
209
- doc = Document(io.BytesIO(file_content))
210
  text = "\n".join([para.text for para in doc.paragraphs])
211
  return text
212
- elif file.filename.endswith(".txt"):
213
- return file_content.decode("utf-8")
214
  else:
215
- raise ValueError("Unsupported file format. Please upload a PDF, DOCX, or TXT file.")
216
  except Exception as e:
217
  logger.error(f"Error extracting text from file: {e}")
218
  raise HTTPException(status_code=400, detail=str(e))
219
 
 
220
  if __name__ == "__main__":
221
  import uvicorn
222
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
23
  # Serve static files (HTML, CSS, JS)
24
  app.mount("/static", StaticFiles(directory="static"), name="static")
25
 
26
+ # Load a multimodal model for image captioning and visual question answering
27
+ multimodal_pipeline = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
28
+
29
+ # Load a text-based model for summarization and text question answering
30
+ text_pipeline = pipeline("text2text-generation", model="t5-small")
31
+
32
+ # Load a translation model (initialized dynamically based on target language)
33
  translation_models = {
34
  "fr": "Helsinki-NLP/opus-mt-en-fr",
35
  "es": "Helsinki-NLP/opus-mt-en-es",
 
38
 
39
  @app.get("/")
40
  def read_root():
41
+ # Redirect to the static HTML file
42
  return RedirectResponse(url="/static/index.html")
43
 
44
  @app.post("/summarize")
45
  async def summarize_text(
46
+ file: UploadFile = File(None), # Optional file upload
47
+ text: str = Form(None) # Optional manual text input
48
  ):
49
  if file:
50
  logger.info(f"Received document for summarization: {file.filename}")
51
  try:
52
+ # Extract text from the document
53
  text = await extract_text_from_file(file)
54
  except Exception as e:
55
  logger.error(f"Error extracting text from file: {e}")
 
60
  raise HTTPException(status_code=400, detail="No file or text provided")
61
 
62
  try:
63
+ # Use the text pipeline to summarize the text
64
  summary = text_pipeline(f"summarize: {text}", max_length=100)
65
  logger.info(f"Generated summary: {summary[0]['generated_text']}")
66
  return {"summary": summary[0]['generated_text']}
 
72
  async def caption_image(file: UploadFile = File(...)):
73
  logger.info(f"Received image for captioning: {file.filename}")
74
  try:
75
+ # Read the image file
76
  image_data = await file.read()
77
  image = Image.open(io.BytesIO(image_data))
78
 
79
+ # Use the multimodal pipeline to generate a caption for the image
 
 
 
80
  caption = multimodal_pipeline(image)
81
  logger.info(f"Generated caption: {caption[0]['generated_text']}")
82
  return {"caption": caption[0]['generated_text']}
83
  except Exception as e:
84
  logger.error(f"Error during image captioning: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  raise HTTPException(status_code=500, detail=str(e))
86
 
87
  @app.post("/answer")
88
  async def answer_question(
89
+ file: UploadFile = File(None), # Optional file upload
90
+ text: str = Form(None), # Optional manual text input
91
+ question: str = Form(...) # Required question
92
  ):
93
  if file:
94
  logger.info(f"Received document for question answering: {file.filename}")
95
  try:
96
+ # Extract text from the document
97
  text = await extract_text_from_file(file)
98
  except Exception as e:
99
  logger.error(f"Error extracting text from file: {e}")
 
104
  raise HTTPException(status_code=400, detail="No file or text provided")
105
 
106
  try:
107
+ # Use the text pipeline to answer the question
108
  answer = text_pipeline(f"question: {question} context: {text}")
109
  logger.info(f"Generated answer: {answer[0]['generated_text']}")
110
  return {"answer": answer[0]['generated_text']}
 
117
  logger.info(f"Received image for visual question answering: {file.filename}")
118
  logger.info(f"Received question: {question}")
119
  try:
120
+ # Read the image file
121
  image_data = await file.read()
122
  image = Image.open(io.BytesIO(image_data))
123
 
124
+ # Use the multimodal pipeline to answer the question about the image
 
 
 
125
  answer = multimodal_pipeline(image, question=question)
126
  logger.info(f"Generated answer: {answer[0]['generated_text']}")
127
  return {"answer": answer[0]['generated_text']}
128
  except Exception as e:
129
  logger.error(f"Error during visual question answering: {e}")
130
+ raise HTTPException(status_code=500, detail=str(e))
131
 
132
  @app.post("/visualize")
133
  async def visualize_data(
 
138
  logger.info(f"Received visualization request: {request}")
139
 
140
  try:
141
+ # Read the Excel file
142
  df = pd.read_excel(io.BytesIO(await file.read()))
143
 
144
+ # Generate visualization code based on the request
145
  if "bar" in request.lower():
146
  code = f"""
147
  import matplotlib.pyplot as plt
 
167
  plt.show()
168
  """
169
 
170
+ # Save the generated code to a file (optional)
171
  code_filename = f"visualization_{uuid.uuid4()}.py"
172
  with open(code_filename, "w") as f:
173
  f.write(code)
 
177
  logger.error(f"Error during visualization code generation: {e}")
178
  raise HTTPException(status_code=500, detail=str(e))
179
 
180
+ @app.post("/translate")
181
+ async def translate_document(
182
+ file: UploadFile = File(...),
183
+ target_language: str = Form(...)
184
+ ):
185
+ logger.info(f"Received document for translation: {file.filename}")
186
+ logger.info(f"Target language: {target_language}")
187
+
188
+ try:
189
+ # Extract text from the document
190
+ text = await extract_text_from_file(file)
191
+
192
+ # Load a translation model based on the target language
193
+ if target_language in translation_models:
194
+ model_name = translation_models[target_language]
195
+ else:
196
+ model_name = "Helsinki-NLP/opus-mt-en-de" # Default to German
197
+
198
+ tokenizer = MarianTokenizer.from_pretrained(model_name)
199
+ model = MarianMTModel.from_pretrained(model_name)
200
+
201
+ # Translate the text
202
+ translated = model.generate(**tokenizer(text, return_tensors="pt", truncation=True))
203
+ translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)
204
+
205
+ return {"translated_text": translated_text}
206
+ except Exception as e:
207
+ logger.error(f"Error during document translation: {e}")
208
+ raise HTTPException(status_code=500, detail=str(e))
209
+
210
+ # Helper function to extract text from files
211
  async def extract_text_from_file(file: UploadFile):
212
  try:
 
213
  if file.filename.endswith(".pdf"):
214
+ doc = fitz.open(stream=await file.read(), filetype="pdf")
215
  text = ""
216
  for page in doc:
217
  text += page.get_text()
218
  return text
219
  elif file.filename.endswith(".docx"):
220
+ doc = Document(io.BytesIO(await file.read()))
221
  text = "\n".join([para.text for para in doc.paragraphs])
222
  return text
 
 
223
  else:
224
+ raise ValueError("Unsupported file format. Please upload a PDF or DOCX file.")
225
  except Exception as e:
226
  logger.error(f"Error extracting text from file: {e}")
227
  raise HTTPException(status_code=400, detail=str(e))
228
 
229
+ # Hugging Face Spaces expects the app to be served on port 7860
230
  if __name__ == "__main__":
231
  import uvicorn
232
  uvicorn.run(app, host="0.0.0.0", port=7860)