Hadiil commited on
Commit
8299dd0
·
verified ·
1 Parent(s): 5a54a75

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -110
app.py CHANGED
@@ -1,174 +1,157 @@
1
  import os
2
- from fastapi import FastAPI, UploadFile, File, HTTPException, Form
3
  from fastapi.staticfiles import StaticFiles
4
- from fastapi.responses import RedirectResponse, JSONResponse
5
  from transformers import pipeline
6
- import logging
7
  from PIL import Image
8
  import io
9
- from docx import Document
10
  import fitz # PyMuPDF
 
11
  import pandas as pd
12
- import matplotlib.pyplot as plt
13
- import seaborn as sns
14
- import uuid
15
- from transformers import MarianMTModel, MarianTokenizer
16
 
17
  # Configure logging
18
  logging.basicConfig(level=logging.INFO)
19
  logger = logging.getLogger(__name__)
20
 
21
- app = FastAPI()
22
-
23
- # Serve static files (HTML, CSS, JS)
24
  app.mount("/static", StaticFiles(directory="static"), name="static")
25
 
26
- # Initialize models
27
  try:
28
- # Image captioning and VQA
29
  image_pipeline = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
30
-
31
- # Text processing
32
  text_pipeline = pipeline("text2text-generation", model="t5-small")
33
-
34
- # Translation models dictionary
35
- translation_models = {
36
- "fr": "Helsinki-NLP/opus-mt-en-fr",
37
- "es": "Helsinki-NLP/opus-mt-en-es",
38
- "de": "Helsinki-NLP/opus-mt-en-de"
39
- }
40
-
41
- logger.info("All models loaded successfully")
42
  except Exception as e:
43
- logger.error(f"Model loading failed: {str(e)}")
44
- raise RuntimeError(f"Failed to initialize models: {str(e)}")
45
-
46
- @app.get("/")
47
- def read_root():
48
- return RedirectResponse(url="/static/index.html")
49
 
50
- @app.get("/health")
51
- def health_check():
52
- return {"status": "healthy", "models_loaded": True}
 
 
 
 
 
 
53
 
54
- @app.post("/summarize")
55
- async def summarize_text(
56
- file: UploadFile = File(None),
57
- text: str = Form(None)
58
  ):
 
 
 
 
 
59
  try:
60
  if file:
61
- text = await extract_text_from_file(file)
62
- elif not text:
63
- raise HTTPException(status_code=400, detail="No input provided")
64
 
65
- summary = text_pipeline(f"summarize: {text}", max_length=100)
66
- return {"summary": summary[0]['generated_text']}
 
 
67
  except Exception as e:
68
- logger.error(f"Summarization error: {str(e)}")
69
- raise HTTPException(status_code=500, detail=str(e))
70
 
71
- @app.post("/caption")
72
  async def caption_image(file: UploadFile = File(...)):
 
 
 
 
 
73
  try:
74
  image = Image.open(io.BytesIO(await file.read()))
75
- caption = image_pipeline(image)
76
- return {"caption": caption[0]['generated_text']}
77
  except Exception as e:
78
- logger.error(f"Captioning error: {str(e)}")
79
- raise HTTPException(status_code=500, detail=str(e))
80
 
81
- @app.post("/answer")
82
  async def answer_question(
83
- file: UploadFile = File(None),
84
- text: str = Form(None),
85
  question: str = Form(...)
86
  ):
 
 
 
 
 
87
  try:
88
  if file:
89
- text = await extract_text_from_file(file)
90
- elif not text:
91
- raise HTTPException(status_code=400, detail="No context provided")
92
 
93
- answer = text_pipeline(f"question: {question} context: {text}")
94
- return {"answer": answer[0]['generated_text']}
95
- except Exception as e:
96
- logger.error(f"QA error: {str(e)}")
97
- raise HTTPException(status_code=500, detail=str(e))
98
-
99
- @app.post("/vqa")
100
- async def visual_question_answering(
101
- file: UploadFile = File(...),
102
- question: str = Form(...)
103
- ):
104
- try:
105
- image = Image.open(io.BytesIO(await file.read()))
106
- answer = image_pipeline(image, question=question)
107
- return {"answer": answer[0]['generated_text']}
108
  except Exception as e:
109
- logger.error(f"VQA error: {str(e)}")
110
- raise HTTPException(status_code=500, detail=str(e))
111
 
112
- @app.post("/visualize")
113
- async def visualize_data(
114
  file: UploadFile = File(...),
115
- request: str = Form(...)
116
  ):
 
 
 
 
 
117
  try:
118
  df = pd.read_excel(io.BytesIO(await file.read()))
119
 
120
- if "bar" in request.lower():
121
  code = f"""import matplotlib.pyplot as plt
122
  plt.bar(df['{df.columns[0]}'], df['{df.columns[1]}'])
 
123
  plt.show()"""
124
  else:
125
  code = f"""import seaborn as sns
126
  sns.pairplot(df)
 
127
  plt.show()"""
128
 
129
- return {"code": code}
 
 
 
130
  except Exception as e:
131
- logger.error(f"Visualization error: {str(e)}")
132
- raise HTTPException(status_code=500, detail=str(e))
133
 
134
- @app.post("/translate")
135
- async def translate_document(
136
- file: UploadFile = File(...),
137
- target_language: str = Form(...)
138
- ):
139
- try:
140
- text = await extract_text_from_file(file)
141
-
142
- if target_language not in translation_models:
143
- raise HTTPException(status_code=400, detail="Unsupported language")
144
-
145
- tokenizer = MarianTokenizer.from_pretrained(translation_models[target_language])
146
- model = MarianMTModel.from_pretrained(translation_models[target_language])
147
-
148
- translated = model.generate(**tokenizer(text, return_tensors="pt", truncation=True))
149
- translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)
150
-
151
- return {"translated_text": translated_text}
152
- except Exception as e:
153
- logger.error(f"Translation error: {str(e)}")
154
- raise HTTPException(status_code=500, detail=str(e))
155
-
156
- async def extract_text_from_file(file: UploadFile):
157
  try:
158
  content = await file.read()
159
 
160
  if file.filename.endswith(".pdf"):
161
- doc = fitz.open(stream=content, filetype="pdf")
162
- return " ".join([page.get_text() for page in doc])
163
  elif file.filename.endswith(".docx"):
164
  doc = Document(io.BytesIO(content))
165
- return "\n".join([para.text for para in doc.paragraphs])
166
  else:
167
  raise ValueError("Unsupported file format")
168
  except Exception as e:
169
- logger.error(f"File extraction error: {str(e)}")
170
- raise HTTPException(status_code=400, detail=str(e))
171
 
172
- if __name__ == "__main__":
173
- import uvicorn
174
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
1
  import os
2
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
3
  from fastapi.staticfiles import StaticFiles
4
+ from fastapi.responses import HTMLResponse, JSONResponse
5
  from transformers import pipeline
 
6
  from PIL import Image
7
  import io
 
8
  import fitz # PyMuPDF
9
+ from docx import Document
10
  import pandas as pd
11
+ import logging
12
+ from typing import Optional
 
 
13
 
14
  # Configure logging
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
18
+ app = FastAPI(title="AI Web Services")
 
 
19
  app.mount("/static", StaticFiles(directory="static"), name="static")
20
 
21
+ # Initialize models (Spaces will cache these)
22
  try:
23
+ logger.info("Loading AI models...")
24
  image_pipeline = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
 
 
25
  text_pipeline = pipeline("text2text-generation", model="t5-small")
26
+ logger.info("Models loaded successfully")
 
 
 
 
 
 
 
 
27
  except Exception as e:
28
+ logger.error(f"Model loading failed: {e}")
29
+ raise RuntimeError("Failed to initialize AI models")
 
 
 
 
30
 
31
+ @app.get("/", response_class=HTMLResponse)
32
+ async def home():
33
+ """Serve the frontend interface"""
34
+ try:
35
+ with open("static/index.html") as f:
36
+ return f.read()
37
+ except Exception as e:
38
+ logger.error(f"Failed to load frontend: {e}")
39
+ raise HTTPException(500, "Frontend loading failed")
40
 
41
+ @app.post("/api/summarize")
42
+ async def summarize(
43
+ file: Optional[UploadFile] = File(None),
44
+ text: Optional[str] = Form(None)
45
  ):
46
+ """
47
+ Summarize text or document
48
+ Accepts: PDF, DOCX or raw text
49
+ Returns: {'summary': str}
50
+ """
51
  try:
52
  if file:
53
+ text = await extract_text(file)
54
+ if not text:
55
+ raise HTTPException(400, "No text provided")
56
 
57
+ result = text_pipeline(f"summarize: {text}", max_length=150)
58
+ return JSONResponse({"summary": result[0]['generated_text']})
59
+ except HTTPException:
60
+ raise
61
  except Exception as e:
62
+ logger.error(f"Summarization error: {e}")
63
+ raise HTTPException(500, "Summarization failed")
64
 
65
+ @app.post("/api/caption")
66
  async def caption_image(file: UploadFile = File(...)):
67
+ """
68
+ Generate caption for image
69
+ Accepts: JPEG, PNG
70
+ Returns: {'caption': str}
71
+ """
72
  try:
73
  image = Image.open(io.BytesIO(await file.read()))
74
+ result = image_pipeline(image)
75
+ return JSONResponse({"caption": result[0]['generated_text']})
76
  except Exception as e:
77
+ logger.error(f"Captioning error: {e}")
78
+ raise HTTPException(500, "Image captioning failed")
79
 
80
+ @app.post("/api/answer")
81
  async def answer_question(
82
+ file: Optional[UploadFile] = File(None),
83
+ text: Optional[str] = Form(None),
84
  question: str = Form(...)
85
  ):
86
+ """
87
+ Answer questions about text/document
88
+ Accepts: PDF, DOCX or raw text + question
89
+ Returns: {'answer': str}
90
+ """
91
  try:
92
  if file:
93
+ text = await extract_text(file)
94
+ if not text:
95
+ raise HTTPException(400, "No text provided")
96
 
97
+ result = text_pipeline(f"question: {question} context: {text}")
98
+ return JSONResponse({"answer": result[0]['generated_text']})
99
+ except HTTPException:
100
+ raise
 
 
 
 
 
 
 
 
 
 
 
101
  except Exception as e:
102
+ logger.error(f"QA error: {e}")
103
+ raise HTTPException(500, "Question answering failed")
104
 
105
+ @app.post("/api/visualize")
106
+ async def generate_visualization(
107
  file: UploadFile = File(...),
108
+ chart_type: str = Form("bar")
109
  ):
110
+ """
111
+ Generate visualization code for Excel data
112
+ Accepts: XLSX, CSV
113
+ Returns: {'code': str, 'columns': list}
114
+ """
115
  try:
116
  df = pd.read_excel(io.BytesIO(await file.read()))
117
 
118
+ if chart_type.lower() == "bar":
119
  code = f"""import matplotlib.pyplot as plt
120
  plt.bar(df['{df.columns[0]}'], df['{df.columns[1]}'])
121
+ plt.title('Bar Chart')
122
  plt.show()"""
123
  else:
124
  code = f"""import seaborn as sns
125
  sns.pairplot(df)
126
+ plt.title('Data Distribution')
127
  plt.show()"""
128
 
129
+ return JSONResponse({
130
+ "code": code,
131
+ "columns": list(df.columns)
132
+ })
133
  except Exception as e:
134
+ logger.error(f"Visualization error: {e}")
135
+ raise HTTPException(500, "Visualization code generation failed")
136
 
137
+ async def extract_text(file: UploadFile) -> str:
138
+ """Extract text from PDF or DOCX files"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  try:
140
  content = await file.read()
141
 
142
  if file.filename.endswith(".pdf"):
143
+ with fitz.open(stream=content, filetype="pdf") as doc:
144
+ return " ".join(page.get_text() for page in doc)
145
  elif file.filename.endswith(".docx"):
146
  doc = Document(io.BytesIO(content))
147
+ return "\n".join(p.text for p in doc.paragraphs)
148
  else:
149
  raise ValueError("Unsupported file format")
150
  except Exception as e:
151
+ logger.error(f"Text extraction failed: {e}")
152
+ raise HTTPException(400, f"Could not extract text: {e}")
153
 
154
+ # Health check endpoint
155
+ @app.get("/health")
156
+ async def health_check():
157
+ return JSONResponse({"status": "healthy", "models": "loaded"})