Feriel080 committed on
Commit
04ac6c3
·
verified ·
1 Parent(s): 0cb4f32

Upload 2 files

Browse files
Files changed (2) hide show
  1. backend/main.py +246 -0
  2. backend/utils.py +96 -0
backend/main.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Form
2
+ from fastapi.responses import FileResponse
3
+ from fastapi.staticfiles import StaticFiles
4
+ import shutil
5
+ from pathlib import Path
6
+ from transformers import (
7
+ pipeline,
8
+ AutoProcessor,
9
+ AutoModelForVision2Seq,
10
+ M2M100ForConditionalGeneration,
11
+ M2M100Tokenizer,
12
+ )
13
+ from huggingface_hub import InferenceClient
14
+ from PIL import Image
15
+ import matplotlib.pyplot as plt
16
+ import seaborn as sns
17
+ import numpy as np
18
+ from utils import extract_text, save_file
19
+ import torch
20
+ import easyocr
21
+ from langdetect import detect, DetectorFactory # for language detection
22
+
23
app = FastAPI()

# Initialize Hugging Face models.
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")

# Pick the device first so the captioning model's precision can follow it.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only requested when a GPU is available; on CPU the model
# is loaded in float32 so it matches the float32 tensors the processor emits.
# NOTE(review): float16 inference on CPU is slow or unsupported for some
# operators depending on the torch build — confirm against the target runtime.
image_captioner = AutoModelForVision2Seq.from_pretrained(
    "microsoft/kosmos-2-patch14-224",
    use_safetensors=True,
    trust_remote_code=True,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
image_captioner = image_captioner.to(device)

tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
translation_model = M2M100ForConditionalGeneration.from_pretrained(
    "facebook/m2m100_418M"
)
question_answering = pipeline(
    "question-answering", model="bert-large-uncased-whole-word-masking-finetuned-squad"
)

# Fixed seed so langdetect returns the same language for the same text.
DetectorFactory.seed = 0


# Directories that hold uploaded and processed files (relative to the
# process working directory).
UPLOAD_DIR = Path("uploads")
PROCESSED_DIR = Path("processed")
UPLOAD_DIR.mkdir(exist_ok=True)
PROCESSED_DIR.mkdir(exist_ok=True)

app.mount(
    "/assets", StaticFiles(directory="../frontend/assets", html=True), name="assets"
)
app.mount("/processed", StaticFiles(directory="processed"), name="processed")
58
+
59
+
60
@app.get("/")
async def serve_frontend():
    """Serve the front-end entry page for requests to the site root."""
    index_page = "../frontend/index.html"
    return FileResponse(index_page)
63
+
64
+
65
+ # List processed files
66
@app.get("/processed_files")
async def list_processed_files():
    """Return the names of the regular files found directly in PROCESSED_DIR."""
    names = []
    for entry in PROCESSED_DIR.iterdir():
        # Sub-directories and other non-file entries are skipped.
        if entry.is_file():
            names.append(entry.name)
    return {"files": names}
70
+
71
+
72
+ # Download a processed file
73
@app.get("/download/{filename}")
async def download_file(filename: str):
    """Send one processed file back to the client as a download.

    Args:
        filename: Name of a file inside PROCESSED_DIR.

    Raises:
        HTTPException: 404 when the name does not refer to a regular file
            directly inside PROCESSED_DIR.
    """
    # Keep only the final path component so the request cannot point outside
    # PROCESSED_DIR; names such as "." or ".." reduce to an empty/different
    # value and are rejected below.
    safe_name = Path(filename).name
    if not safe_name or safe_name != filename:
        raise HTTPException(status_code=404, detail="File not found")

    file_path = PROCESSED_DIR / safe_name
    # is_file() rather than exists(): a directory exists but cannot be served.
    if not file_path.is_file():
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(file_path, filename=safe_name)
79
+
80
+
81
+ # Document & Image Analysis (Summarization & Interpretation)
82
def _summarize_text(text: str, word_count: int) -> str:
    """Run the summarization pipeline on *text* with a bounded output length."""
    # Aim for roughly half the original length, but never below 150 and never
    # above a fixed ceiling.
    # NOTE(review): 512 is a conservative cap chosen so the requested length
    # stays inside the summarizer's positional range — confirm against the
    # facebook/bart-large-cnn configuration.
    max_summary_tokens = min(max(word_count // 2, 150), 512)
    # A minimum below the maximum lets generation stop at a natural boundary
    # instead of being forced to an exact token count.
    min_summary_tokens = max_summary_tokens // 2

    # The text is passed as-is: the pipeline summarizes whatever it receives,
    # so an instruction sentence glued to the front would be treated as part
    # of the document rather than as a directive.
    return summarizer(
        text,
        max_length=max_summary_tokens,
        min_length=min_summary_tokens,
        do_sample=False,
        truncation=True,
    )[0]["summary_text"]


def _caption_image(image_path: Path) -> str:
    """Generate a caption for the image stored at *image_path*."""
    # The context manager releases the underlying file handle after the
    # processor has converted the image to tensors.
    with Image.open(image_path) as image:
        # Floating-point inputs are cast to the model's dtype so they match
        # its weights (the model may have been loaded in half precision).
        inputs = processor(
            text="Describe this image in detail including any text:",
            images=image,
            return_tensors="pt",
        ).to(device, image_captioner.dtype)

    generate_kwargs = {
        "pixel_values": inputs["pixel_values"],
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": 200,
    }
    # Forward the mask that marks where image embeddings sit in the prompt
    # whenever the processor produced one.
    if "image_embeds_position_mask" in inputs:
        generate_kwargs["image_embeds_position_mask"] = inputs[
            "image_embeds_position_mask"
        ]

    generated_ids = image_captioner.generate(**generate_kwargs)
    return processor.decode(generated_ids[0], skip_special_tokens=True)


@app.post("/docsum_imginter")
async def docsum_imginter(file: UploadFile = File(...), task: str = Form(...)):
    """Summarize an uploaded document or caption an uploaded image.

    Documents (docx, xlsx, pptx, pdf, txt) accept the task "summarize" and
    return the summary as a file of the same type. Images (png, jpg, jpeg)
    accept the task "interpretation" and return ``{"caption": ...}``.
    Documents shorter than 150 words are returned unchanged with a warning.

    Raises:
        HTTPException: 400 for an unsupported file type, a task that does not
            match the file type, or a document with no extractable text.
    """
    document_types = {"docx", "xlsx", "pptx", "pdf", "txt"}
    image_types = {"png", "jpg", "jpeg"}

    # Only the final path component of the client-supplied name is used, so
    # the upload cannot be written outside UPLOAD_DIR.
    safe_name = Path(file.filename or "").name
    file_type = safe_name.split(".")[-1].lower()

    # Validate before touching the disk so rejected uploads are never stored.
    if not safe_name or file_type not in document_types | image_types:
        raise HTTPException(status_code=400, detail="Unsupported file type.")
    requested_task = task.lower()
    if file_type in document_types and requested_task != "summarize":
        raise HTTPException(
            status_code=400,
            detail="Task not supported for documents. Use 'summarize'.",
        )
    if file_type in image_types and requested_task != "interpretation":
        raise HTTPException(
            status_code=400,
            detail="Task not supported for images. Use 'interpretation'.",
        )

    file_path = UPLOAD_DIR / safe_name
    output_filename = f"summarized_{safe_name}"
    output_path = PROCESSED_DIR / output_filename

    # Save the uploaded file
    with open(file_path, "wb") as f:
        shutil.copyfileobj(file.file, f)

    if file_type in image_types:
        return {"caption": _caption_image(file_path)}

    text = extract_text(file_path, file_type)
    if text is None:
        raise HTTPException(
            status_code=400, detail="Failed to extract text from the document."
        )
    if not text.strip():
        raise HTTPException(status_code=400, detail="No text found in the document.")

    original_word_count = len(text.split())
    if original_word_count < 150:
        return {
            "warning": "Document too short for meaningful summarization",
            "original_text": text,
            "word_count": original_word_count,
        }

    summary = _summarize_text(text, original_word_count)
    save_file(summary, file_path, file_type, output_path)
    return FileResponse(output_path, filename=output_filename)
155
+
156
+
157
+ # Intelligent Question Answering (Placeholder)
158
# Lazily created OCR reader shared by all requests.
_OCR_READER = None


def _get_ocr_reader():
    """Return the shared EasyOCR reader, constructing it on first use."""
    global _OCR_READER
    if _OCR_READER is None:
        _OCR_READER = easyocr.Reader(["en"])
    return _OCR_READER


@app.post("/ask")
async def ask(file: UploadFile = File(...), question: str = Form(...)):
    """Answer a question using the text of an uploaded document or image.

    Documents (docx, xlsx, pptx, pdf, txt) are read with ``extract_text``;
    images (png, jpg, jpeg) are read with OCR. The extracted text is used as
    the context for the question-answering pipeline.

    Returns:
        ``{"answer": ...}`` with the answer span chosen by the pipeline.

    Raises:
        HTTPException: 400 for an unsupported file type or when no text could
            be extracted from the file.
    """
    document_types = {"docx", "xlsx", "pptx", "pdf", "txt"}
    image_types = {"png", "jpg", "jpeg"}

    # Only the final path component of the client-supplied name is used, so
    # the upload cannot be written outside UPLOAD_DIR.
    safe_name = Path(file.filename or "").name
    file_type = safe_name.split(".")[-1].lower()
    if not safe_name or file_type not in document_types | image_types:
        raise HTTPException(status_code=400, detail="Unsupported file type.")

    file_path = UPLOAD_DIR / safe_name
    with open(file_path, "wb") as f:
        shutil.copyfileobj(file.file, f)

    if file_type in document_types:
        text = extract_text(file_path, file_type)
    else:
        # OCR reads the saved copy: the upload stream was already consumed by
        # the copy above. detail=0 asks for the recognised strings only, which
        # are joined into one context string for the QA pipeline.
        fragments = _get_ocr_reader().readtext(str(file_path), detail=0)
        text = " ".join(fragments)

    if not text or not text.strip():
        raise HTTPException(
            status_code=400,
            detail="The File doesn't contain any text.",
        )

    result = question_answering(question=question, context=text)
    return {"answer": result["answer"]}
186
+
187
+
188
+ # Data Visualization Code Generation
189
@app.post("/generate-visualization")
async def visualization(file: UploadFile = File(...), request: str = Form(...)):
    """Placeholder endpoint: accepts the upload and request but does no work yet."""
    not_implemented = {"message": "Visualisation is not implemented yet."}
    return not_implemented
192
+
193
+
194
+ # Text Translation
195
@app.post("/translate")
async def translate(file: UploadFile = File(...), target_language: str = Form(...)):
    """Translate the text of an uploaded document into *target_language*.

    The source language is auto-detected from the first 1000 characters of the
    extracted text. The translation is written to PROCESSED_DIR as
    ``translated_<name>`` in the same file type and returned as a download.

    Raises:
        HTTPException: 400 when the file has no text or when the source or
            target language is not known to the translation tokenizer; 500
            when extraction, detection or translation fails unexpectedly.
    """
    # langdetect reports regional Chinese variants; the translation tokenizer
    # is looked up with the plain "zh" code instead.
    detected_to_model_code = {"zh-cn": "zh", "zh-tw": "zh"}

    # Only the final path component of the client-supplied name is used, so
    # the upload cannot be written outside UPLOAD_DIR.
    safe_name = Path(file.filename or "").name
    if not safe_name:
        raise HTTPException(400, "Missing file name.")
    file_type = safe_name.split(".")[-1].lower()
    file_path = UPLOAD_DIR / safe_name
    output_filename = f"translated_{safe_name}"
    output_path = PROCESSED_DIR / output_filename

    supported_languages = tokenizer.lang_code_to_id
    if target_language not in supported_languages:
        raise HTTPException(400, f"Unsupported target language: {target_language}")

    with open(file_path, "wb") as f:
        shutil.copyfileobj(file.file, f)

    try:
        text = extract_text(file_path, file_type)
        if not text or not text.strip():
            raise HTTPException(400, "No text found in the document.")

        # Auto-detect the source language from the start of the text.
        detected = detect(text[:1000])
        source_language = detected_to_model_code.get(detected, detected)
        if source_language not in supported_languages:
            raise HTTPException(400, f"Unsupported source language: {source_language}")

        # NOTE(review): src_lang is set on the shared module-level tokenizer,
        # so concurrent requests with different source languages can
        # interfere with each other.
        tokenizer.src_lang = source_language
        target_token_id = tokenizer.get_lang_id(target_language)

        # Translate line by line so each model call stays bounded and the
        # document's line structure is kept; blank lines pass through.
        translated_lines = []
        for line in text.splitlines():
            if not line.strip():
                translated_lines.append("")
                continue
            encoded_inputs = tokenizer(line, return_tensors="pt", truncation=True)
            generated_tokens = translation_model.generate(
                **encoded_inputs, forced_bos_token_id=target_token_id
            )
            translated_lines.append(
                tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
            )
        translated_text = "\n".join(translated_lines)

        save_file(translated_text, file_path, file_type, output_path)

        return FileResponse(output_path, filename=output_filename)

    except HTTPException:
        # Deliberate client errors raised above must keep their status code.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Translation failed: {e}"
        ) from e
backend/utils.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pptx import Presentation
2
+ import pdfplumber
3
+ from reportlab.lib.pagesizes import letter
4
+ from reportlab.pdfgen import canvas
5
+ from io import BytesIO
6
+ import docx
7
+ from pathlib import Path
8
+ import openpyxl
9
+
10
def extract_text(file_path: Path, file_type: str) -> str:
    """Return the plain text contained in a document.

    Args:
        file_path: Location of the document on disk.
        file_type: Lower-case extension without the dot. Handled values are
            "txt", "docx", "xlsx", "pptx" and "pdf".

    Returns:
        The extracted text with leading and trailing whitespace removed. Any
        other *file_type* yields an empty string.
    """
    pieces = []
    separator = "\n"

    if file_type == "txt":
        # utf-8-sig reads plain UTF-8 and also drops a leading byte-order
        # mark, which would otherwise survive strip() as U+FEFF.
        with open(file_path, "r", encoding="utf-8-sig") as f:
            pieces.append(f.read())

    elif file_type == "docx":
        doc = docx.Document(file_path)
        pieces.extend(para.text for para in doc.paragraphs if para.text)

    elif file_type == "xlsx":
        # Only the active sheet is read; cell values are space-separated.
        separator = " "
        wb = openpyxl.load_workbook(file_path)
        sheet = wb.active
        pieces.extend(
            str(cell.value)
            for row in sheet.rows
            for cell in row
            if cell.value is not None
        )

    elif file_type == "pptx":
        prs = Presentation(file_path)
        for slide in prs.slides:
            for shape in slide.shapes:
                if shape.has_text_frame:
                    for paragraph in shape.text_frame.paragraphs:
                        if (clean_text := paragraph.text.strip()):
                            pieces.append(clean_text)
                elif shape.has_table:
                    for row in shape.table.rows:
                        for cell in row.cells:
                            if (cell_text := cell.text.strip()):
                                pieces.append(cell_text)

    elif file_type == "pdf":
        with pdfplumber.open(file_path) as pdf:
            for page in pdf.pages:
                # Extract once per page and reuse the result for both the
                # emptiness check and the output.
                if (page_text := page.extract_text()):
                    pieces.append(page_text)

    return separator.join(pieces).strip()
54
+
55
def save_file(text: str, original_path: Path, file_type: str, output_path: Path):
    """Write *text* to *output_path* in the format named by *file_type*.

    Args:
        text: Content to store.
        original_path: Path of the source document. Not read by this
            function; kept so existing callers do not change.
        file_type: "docx", "xlsx", "pptx" or "pdf" produce a new document of
            that type; any other value writes a UTF-8 text file.
        output_path: Destination file.
    """
    import textwrap

    if file_type == "docx":
        doc = docx.Document()
        doc.add_paragraph(text)
        doc.save(output_path)

    elif file_type == "xlsx":
        # One line of text per row, all in the first column.
        wb = openpyxl.Workbook()
        sheet = wb.active
        for row_number, line in enumerate(text.split("\n"), start=1):
            sheet.cell(row=row_number, column=1, value=line)
        wb.save(output_path)

    elif file_type == "pptx":
        # A single slide built from layout 1, with the text in placeholder 1.
        prs = Presentation()
        slide = prs.slides.add_slide(prs.slide_layouts[1])
        slide.shapes.placeholders[1].text = text
        prs.save(output_path)

    elif file_type == "pdf":
        # drawString does not wrap, so long lines are split first; a
        # one-paragraph summary would otherwise run past the page edge.
        # NOTE(review): 75 characters is an estimate for the canvas default
        # font between the 72pt margins — confirm visually. The default font
        # may also not cover non-Latin scripts in translated output.
        wrap_width = 75
        pdf_canvas = canvas.Canvas(str(output_path), pagesize=letter)
        y = 750
        for line in text.split("\n"):
            # wrap() returns [] for an empty line; keep it as a blank row.
            for segment in textwrap.wrap(line, width=wrap_width) or [""]:
                pdf_canvas.drawString(72, y, segment)
                y -= 12
                if y < 50:
                    # Start a new page once the bottom margin is reached.
                    pdf_canvas.showPage()
                    y = 750
        pdf_canvas.save()

    else:
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(text)