Hadiil commited on
Commit
5a54a75
·
verified ·
1 Parent(s): 47691ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -3
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  from fastapi import FastAPI, UploadFile, File, HTTPException, Form
3
  from fastapi.staticfiles import StaticFiles
4
- from fastapi.responses import RedirectResponse
5
  from transformers import pipeline
6
  import logging
7
  from PIL import Image
@@ -23,5 +23,152 @@ app = FastAPI()
23
  # Serve static files (HTML, CSS, JS)
24
  app.mount("/static", StaticFiles(directory="static"), name="static")
25
 
26
- # Load a multimodal model for image captioning and visual question answering
27
- multimodal
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  from fastapi import FastAPI, UploadFile, File, HTTPException, Form
3
  from fastapi.staticfiles import StaticFiles
4
+ from fastapi.responses import RedirectResponse, JSONResponse
5
  from transformers import pipeline
6
  import logging
7
  from PIL import Image
 
23
  # Serve static files (HTML, CSS, JS)
24
  app.mount("/static", StaticFiles(directory="static"), name="static")
25
 
26
+ # Initialize models
27
+ try:
28
+ # Image captioning and VQA
29
+ image_pipeline = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
30
+
31
+ # Text processing
32
+ text_pipeline = pipeline("text2text-generation", model="t5-small")
33
+
34
+ # Translation models dictionary
35
+ translation_models = {
36
+ "fr": "Helsinki-NLP/opus-mt-en-fr",
37
+ "es": "Helsinki-NLP/opus-mt-en-es",
38
+ "de": "Helsinki-NLP/opus-mt-en-de"
39
+ }
40
+
41
+ logger.info("All models loaded successfully")
42
+ except Exception as e:
43
+ logger.error(f"Model loading failed: {str(e)}")
44
+ raise RuntimeError(f"Failed to initialize models: {str(e)}")
45
+
46
+ @app.get("/")
47
+ def read_root():
48
+ return RedirectResponse(url="/static/index.html")
49
+
50
+ @app.get("/health")
51
+ def health_check():
52
+ return {"status": "healthy", "models_loaded": True}
53
+
54
+ @app.post("/summarize")
55
+ async def summarize_text(
56
+ file: UploadFile = File(None),
57
+ text: str = Form(None)
58
+ ):
59
+ try:
60
+ if file:
61
+ text = await extract_text_from_file(file)
62
+ elif not text:
63
+ raise HTTPException(status_code=400, detail="No input provided")
64
+
65
+ summary = text_pipeline(f"summarize: {text}", max_length=100)
66
+ return {"summary": summary[0]['generated_text']}
67
+ except Exception as e:
68
+ logger.error(f"Summarization error: {str(e)}")
69
+ raise HTTPException(status_code=500, detail=str(e))
70
+
71
+ @app.post("/caption")
72
+ async def caption_image(file: UploadFile = File(...)):
73
+ try:
74
+ image = Image.open(io.BytesIO(await file.read()))
75
+ caption = image_pipeline(image)
76
+ return {"caption": caption[0]['generated_text']}
77
+ except Exception as e:
78
+ logger.error(f"Captioning error: {str(e)}")
79
+ raise HTTPException(status_code=500, detail=str(e))
80
+
81
+ @app.post("/answer")
82
+ async def answer_question(
83
+ file: UploadFile = File(None),
84
+ text: str = Form(None),
85
+ question: str = Form(...)
86
+ ):
87
+ try:
88
+ if file:
89
+ text = await extract_text_from_file(file)
90
+ elif not text:
91
+ raise HTTPException(status_code=400, detail="No context provided")
92
+
93
+ answer = text_pipeline(f"question: {question} context: {text}")
94
+ return {"answer": answer[0]['generated_text']}
95
+ except Exception as e:
96
+ logger.error(f"QA error: {str(e)}")
97
+ raise HTTPException(status_code=500, detail=str(e))
98
+
99
+ @app.post("/vqa")
100
+ async def visual_question_answering(
101
+ file: UploadFile = File(...),
102
+ question: str = Form(...)
103
+ ):
104
+ try:
105
+ image = Image.open(io.BytesIO(await file.read()))
106
+ answer = image_pipeline(image, question=question)
107
+ return {"answer": answer[0]['generated_text']}
108
+ except Exception as e:
109
+ logger.error(f"VQA error: {str(e)}")
110
+ raise HTTPException(status_code=500, detail=str(e))
111
+
112
+ @app.post("/visualize")
113
+ async def visualize_data(
114
+ file: UploadFile = File(...),
115
+ request: str = Form(...)
116
+ ):
117
+ try:
118
+ df = pd.read_excel(io.BytesIO(await file.read()))
119
+
120
+ if "bar" in request.lower():
121
+ code = f"""import matplotlib.pyplot as plt
122
+ plt.bar(df['{df.columns[0]}'], df['{df.columns[1]}'])
123
+ plt.show()"""
124
+ else:
125
+ code = f"""import seaborn as sns
126
+ sns.pairplot(df)
127
+ plt.show()"""
128
+
129
+ return {"code": code}
130
+ except Exception as e:
131
+ logger.error(f"Visualization error: {str(e)}")
132
+ raise HTTPException(status_code=500, detail=str(e))
133
+
134
+ @app.post("/translate")
135
+ async def translate_document(
136
+ file: UploadFile = File(...),
137
+ target_language: str = Form(...)
138
+ ):
139
+ try:
140
+ text = await extract_text_from_file(file)
141
+
142
+ if target_language not in translation_models:
143
+ raise HTTPException(status_code=400, detail="Unsupported language")
144
+
145
+ tokenizer = MarianTokenizer.from_pretrained(translation_models[target_language])
146
+ model = MarianMTModel.from_pretrained(translation_models[target_language])
147
+
148
+ translated = model.generate(**tokenizer(text, return_tensors="pt", truncation=True))
149
+ translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)
150
+
151
+ return {"translated_text": translated_text}
152
+ except Exception as e:
153
+ logger.error(f"Translation error: {str(e)}")
154
+ raise HTTPException(status_code=500, detail=str(e))
155
+
156
+ async def extract_text_from_file(file: UploadFile):
157
+ try:
158
+ content = await file.read()
159
+
160
+ if file.filename.endswith(".pdf"):
161
+ doc = fitz.open(stream=content, filetype="pdf")
162
+ return " ".join([page.get_text() for page in doc])
163
+ elif file.filename.endswith(".docx"):
164
+ doc = Document(io.BytesIO(content))
165
+ return "\n".join([para.text for para in doc.paragraphs])
166
+ else:
167
+ raise ValueError("Unsupported file format")
168
+ except Exception as e:
169
+ logger.error(f"File extraction error: {str(e)}")
170
+ raise HTTPException(status_code=400, detail=str(e))
171
+
172
+ if __name__ == "__main__":
173
+ import uvicorn
174
+ uvicorn.run(app, host="0.0.0.0", port=7860)