# WallTD-v.1 / main.py
# (Hugging Face Space file-viewer residue from the original paste:
#  uploaded by Feriel080, commit 1a2afd5 "Update main.py" — kept as a comment
#  so the module parses.)
from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
import shutil
from pathlib import Path
from transformers import (
pipeline,
M2M100Tokenizer,
M2M100ForConditionalGeneration,
BartTokenizer,
BlipProcessor, BlipForConditionalGeneration,
AutoModelForCausalLM, AutoTokenizer
)
from utils import extract_text, save_file, verify_summary, ensure_complete_sentences
from langdetect import detect, DetectorFactory
from langcodes import Language
import torch
from PIL import Image
import os
import pytesseract
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import hashlib
import re
from concurrent.futures import ThreadPoolExecutor
app = FastAPI()

# Hugging Face models, loaded once at process start (each call downloads /
# loads weights, so startup is slow but requests are fast).
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
summary_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large", use_fast=True)
interpretation_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
translation_tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_1.2B")
translation_model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_1.2B")
question_answering = pipeline("question-answering", model="bert-large-uncased-whole-word-masking-finetuned-squad")
visual_question_answering = pipeline("visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa")

# Make langdetect deterministic (it is randomized by default).
DetectorFactory.seed = 0

# Directory to store uploaded and processed files
UPLOAD_DIR = Path("uploads")
PROCESSED_DIR = Path("processed")
UPLOAD_DIR.mkdir(exist_ok=True)
PROCESSED_DIR.mkdir(exist_ok=True)

# NOTE(review): `device` is computed but the models above are never moved to
# it, so inference runs on CPU regardless — confirm whether GPU placement was
# intended before wiring `.to(device)` in.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face token, required for the gated/authenticated model downloads below.
API_TOKEN = os.environ.get("HF_TOKEN")
if not API_TOKEN:
    # Bug fix: the message previously named "HUGGINGFACE_API_TOKEN" while the
    # code actually reads "HF_TOKEN", sending operators to the wrong variable.
    raise ValueError("HF_TOKEN environment variable not set.")

code_generation_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-1.5B-Instruct", token=API_TOKEN)
code_generation_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-Coder-1.5B-Instruct", token=API_TOKEN)
code_generation_tokenizer.pad_token_id = code_generation_tokenizer.eos_token_id
# Fix: reuse the model object loaded above instead of passing the hub id,
# which made `pipeline()` load a second full copy of the same weights.
code_generation_generator = pipeline(
    "text-generation",
    model=code_generation_model,
    tokenizer=code_generation_tokenizer,
    device=-1,  # force CPU for the generation pipeline
)

# Static frontend files
app.mount("/assets", StaticFiles(directory="frontend/assets", html=True), name="assets")
app.mount("/images", StaticFiles(directory="frontend/images", html=True), name="images")
app.mount("/processed", StaticFiles(directory="processed"), name="processed")
@app.get("/")
async def serve_frontend():
    """Serve the single-page frontend entry point."""
    index_page = "frontend/index.html"
    return FileResponse(index_page)
# List processed files
@app.get("/processed_files")
async def list_processed_files():
    """Return the names of all regular files in the processed directory."""
    names = [entry.name for entry in PROCESSED_DIR.iterdir() if entry.is_file()]
    return {"files": names}
# Download a processed file
@app.get("/download/{filename}")
async def download_file(filename: str):
    """Stream a file from PROCESSED_DIR back to the client.

    Raises 404 if the file does not exist or the name is not a plain
    basename (path-traversal guard).
    """
    # Bug fix: the route path was the garbled literal "/download/(unknown)",
    # so `filename` was never bound; restored the `{filename}` path parameter.
    # Security: reject anything that is not a bare filename ("../x", "a/b"),
    # otherwise a crafted name could read files outside PROCESSED_DIR.
    if Path(filename).name != filename:
        raise HTTPException(status_code=404, detail="File not found")
    file_path = PROCESSED_DIR / filename
    if not file_path.exists():
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(file_path, filename=filename)
def split_text(text, max_words=1000):
    """Split *text* into chunks of at most *max_words* whitespace-separated
    words, re-joined with single spaces. An empty/whitespace-only input
    yields an empty list."""
    words = text.split()
    return [
        " ".join(words[start:start + max_words])
        for start in range(0, len(words), max_words)
    ]
# Document & Image Analysis (Summarization & Interpretation)
@app.post("/docsum_imginter")
async def docsum_imginter(file: UploadFile = File(...), task: str = Form(...)):
    """Summarize an uploaded document (task="summarize") or caption an
    uploaded image (task="interpret").

    Returns a FileResponse with the saved summary for documents, or a JSON
    ``{"caption": ...}`` payload for images. Raises 400 for bad input /
    unknown task and 500 on model failure.
    """
    file_type = file.filename.split(".")[-1].lower()
    file_path = UPLOAD_DIR / file.filename
    output_filename = f"summarized_{file.filename}"
    output_path = PROCESSED_DIR / output_filename
    with open(file_path, "wb") as f:
        shutil.copyfileobj(file.file, f)
    try:
        if task.lower() == "summarize":
            text = extract_text(file_path, file_type)
            if not text or not text.strip():
                raise HTTPException(400, "No text found in document")
            if len(text.strip().split()) < 150:
                raise HTTPException(400, "WallD thinks the file is too small for summarization - minimum 150 words")
            # Drop non-ASCII characters before feeding BART.
            text = text.encode("ascii", "ignore").decode("ascii")
            chunks = split_text(text)
            summaries = []
            prompt = (
                "Generate a concise, factual summary covering ALL key sections of the text. "
                "Include: main objectives, critical details, and outcomes if mentioned. "
                "Never include: contact information, website links, or promotional content. "
                "\n"
                "Text to summarize:\n{chunk}"
            )
            for chunk in chunks:
                if not chunk.strip():
                    continue
                # Aim for ~40% of the chunk length, clamped to [150, 512] tokens.
                word_count = len(chunk.split())
                max_length = min(max(int(word_count * 0.4), 150), 512)
                summary_result = summarizer(
                    prompt.format(chunk=chunk),
                    max_length=max_length,
                    min_length=max(150, int(max_length * 0.6)),
                    do_sample=False,
                    truncation=True,
                    repetition_penalty=1.5,
                    no_repeat_ngram_size=3,
                    num_beams=4,
                    length_penalty=1.0,
                )
                if summary_result:
                    raw_summary = summary_result[0]["summary_text"]
                    verified = verify_summary(raw_summary, chunk)
                    if verified:
                        complete = ensure_complete_sentences(verified)
                        summaries.append(complete)
            if not summaries:
                raise HTTPException(500, "Summary verification failed - no valid content extracted")
            full_summary = "\n".join(filter(None, summaries))
            if len(summaries) > 1:
                # Second pass: merge the per-chunk summaries into one text.
                full_summary = summarizer(
                    f"Combine these partial summaries into one coherent paragraph:\n{full_summary}",
                    max_length=512,
                )[0]["summary_text"]
            if not full_summary.strip():
                # Last-resort fallback: first three sentences of the source.
                sentences = [s.strip() for s in text.split(".") if s.strip()]
                full_summary = (". ".join(sentences[:3]) + "." if sentences else text[:500])
            save_file(full_summary, file_type, output_path)
            return FileResponse(output_path, filename=output_filename)
        elif task.lower() == "interpret":
            try:
                with Image.open(file_path) as image:
                    if image.mode != "RGB":
                        image = image.convert("RGB")
                    inputs = processor(images=image, return_tensors="pt")
                    if inputs is None or "pixel_values" not in inputs:
                        raise ValueError("Image processing failed: No valid inputs generated.")
                    outputs = interpretation_model.generate(**inputs, repetition_penalty=1.2)
                    if outputs is None:
                        raise ValueError("Model generation failed: No outputs produced.")
                    caption = processor.decode(outputs[0], skip_special_tokens=True)
                    return {"caption": caption if caption else "No caption generated"}
            except Exception as e:
                raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")
        else:
            # Bug fix: an unknown task previously fell through and returned
            # HTTP 200 with a null body.
            raise HTTPException(400, "Unknown task - expected 'summarize' or 'interpret'.")
    finally:
        # Bug fix: the uploaded temp file was only removed on the 'interpret'
        # path; it now gets cleaned up on every path, including errors.
        if file_path.exists():
            file_path.unlink()
# Intelligent Question Answering
def is_visual_question(question: str) -> bool:
    """Heuristic check: does the question ask about the visual content of an
    image (colors, counts, descriptions, ...)? Case-insensitive."""
    visual_keywords = (
        "color", "describe", "what do you see", "how many",
        "is there", "are there", "what is in", "can you see",
    )
    lowered = question.lower()
    return any(keyword in lowered for keyword in visual_keywords)
@app.post("/ask")
async def ask(file: UploadFile = File(...), question: str = Form(...)):
    """Answer a question about an uploaded document or image.

    Documents are read with extract_text; images are answered with VQA when
    the question looks visual, otherwise OCR'd and answered with extractive QA.
    Raises 400 for unsupported/empty files and 500 on processing failure.
    """
    # Bug fix: bind file_path before the try so the finally block cannot hit
    # a NameError when an early exception fires.
    file_path = UPLOAD_DIR / file.filename
    try:
        file_type = file.filename.split(".")[-1].lower()
        with open(file_path, "wb") as f:
            shutil.copyfileobj(file.file, f)
        if file_type in ["docx", "xlsx", "pptx", "pdf", "txt"]:
            text = extract_text(file_path, file_type)
        elif file_type in ["png", "jpg", "jpeg", "webp"]:
            with Image.open(file_path) as image:
                if image.mode != "RGB":
                    image = image.convert("RGB")
                if is_visual_question(question):
                    vqa_result = visual_question_answering(image, question, top_k=1)[0]
                    return {"answer": vqa_result["answer"]}
                # Non-visual question about an image: OCR it and fall through
                # to the text QA pipeline.
                text = pytesseract.image_to_string(image)
        else:
            raise HTTPException(status_code=400, detail="Unsupported file type.")
        if not text:
            raise HTTPException(status_code=400, detail="The File doesn't contain any text.")
        result = question_answering(question=question, context=text)
        return {"answer": result["answer"]}
    except HTTPException:
        # Bug fix: the broad handler below used to swallow the deliberate
        # 400 responses above and re-emit them as opaque 500 errors.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request. {str(e)}")
    finally:
        if file_path.exists():
            file_path.unlink()
# Data Visualization Code Generation
@app.post("/generate-visualization")
async def visualization(file: UploadFile = File(...), request: str = Form(...)):
    """Ask the code model to write plotting code for an uploaded Excel file,
    sanitize it, execute it against the loaded DataFrame, and save the plot.

    Returns ``{"code": <display code>, "image_path": <saved PNG path>}``.
    Raises 400 for unreadable/empty Excel files and 500 on generation or
    execution failure.
    """
    file_path = UPLOAD_DIR / file.filename
    with open(file_path, "wb") as f:
        shutil.copyfileobj(file.file, f)
    try:
        df = pd.read_excel(file_path)
        if df.empty:
            raise ValueError("Excel file is empty.")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error reading Excel file: {str(e)}")
    input_text = f"""
Given the DataFrame 'df' with columns {', '.join(df.columns)} and preview:
{df.head().to_string()}
Write Python code to: create {request}
- Use ONLY 'df = pd.read_excel({file.filename})' (no external data loading like pd.read_csv or creating a new DataFrame).
- Use pandas (pd), matplotlib.pyplot (plt), or seaborn (sns).
- Include axis labels and a title.
- Output ONLY executable Python code. Do NOT include triple quotes, prose, Markdown, or text like 'Hint', 'Solution', or 'Here is the code'.
"""
    try:
        generated = code_generation_generator(input_text, max_new_tokens=500, num_return_sequences=1)
        # The pipeline echoes the prompt; strip it to keep only new tokens.
        generated_code = generated[0]["generated_text"].replace(input_text, "").strip()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error querying model: {str(e)}")
    if not generated_code.strip():
        raise HTTPException(status_code=500, detail="No code generated by the AI model.")
    # Take the first fenced ```python ...``` block from the model output.
    code_block_pattern = r"```python\n(.*?)(\n```|\Z)"
    matches = list(re.finditer(code_block_pattern, generated_code, re.DOTALL))
    if matches:
        raw_code_block = matches[0].group(1).strip()
        executable_code = raw_code_block
    else:
        raise HTTPException(status_code=500, detail="No valid Python code block found in generated output.")
    # Drop lines that would reload data, redefine df, or open a GUI window.
    executable_code = "\n".join(
        line.strip() for line in executable_code.splitlines()
        if line.strip() and
        not any(kw in line for kw in ["pd.read_csv", "pd.read_excel", "plt.show", "df ="])
    ).strip()
    # The unfiltered block is what we show back to the user.
    display_code = "\n".join(
        line.strip() for line in raw_code_block.splitlines()
        if line.strip()
    ).strip()
    if not executable_code:
        raise HTTPException(status_code=500, detail="Generated code was invalid (e.g., included data loading, df redefinition, or was empty).")
    # Deterministic name per (file, request) pair so repeats overwrite.
    plot_hash = hashlib.md5(f"{file.filename}_{request}".encode()).hexdigest()[:8]
    plot_filename = f"plot_{plot_hash}.png"
    plot_path = PROCESSED_DIR / plot_filename
    try:
        # SECURITY NOTE(review): exec() runs LLM-generated code in-process
        # with full interpreter privileges. The keyword filter above is NOT a
        # sandbox; consider a restricted subprocess/container instead.
        exec_globals = {"pd": pd, "plt": plt, "sns": sns, "df": df}
        exec(executable_code, exec_globals)
        plt.savefig(plot_path, bbox_inches="tight")
        plt.close()
    except Exception as e:
        # Bug fix: close any figures the generated code opened before failing,
        # otherwise every failed request leaks matplotlib figure state.
        plt.close("all")
        raise HTTPException(status_code=500, detail=f"Error executing code: {str(e)}")
    finally:
        if file_path.exists():
            file_path.unlink()
    if not plot_path.exists():
        raise HTTPException(status_code=500, detail="Plot file was not created.")
    # Return the display code and the path of the saved plot.
    return {"code": display_code, "image_path": plot_path}
# Text Translation
def split_tran_text_trans(text, max_chunk_size=800):
    """Group the paragraphs of *text* (split on blank lines) into chunks whose
    combined word count stays within *max_chunk_size*; a single paragraph
    longer than the limit becomes its own oversized chunk. Chunks are
    re-joined with blank lines; empty paragraphs are dropped."""
    chunks, pending, pending_words = [], [], 0
    for paragraph in text.split("\n\n"):
        paragraph = paragraph.strip()
        if not paragraph:
            continue
        size = len(paragraph.split())
        if pending_words + size <= max_chunk_size:
            pending.append(paragraph)
            pending_words += size
        else:
            if pending:
                chunks.append("\n\n".join(pending))
            pending, pending_words = [paragraph], size
    if pending:
        chunks.append("\n\n".join(pending))
    return chunks
@app.post("/translate")
async def translate_document(file: UploadFile = File(...), target_language: str = Form(...)):
    """Translate an uploaded document into *target_language* with M2M100.

    Detects the source language from a sample of the text, translates
    paragraph-grouped chunks in a thread pool, saves the result, and returns
    it as a FileResponse. Raises 400 for unsupported languages and 500 on
    other failures.
    """
    file_type = file.filename.split(".")[-1].lower()
    file_path = UPLOAD_DIR / file.filename
    output_filename = f"translated_to_{target_language}_{file.filename}"
    output_path = PROCESSED_DIR / output_filename
    with open(file_path, "wb") as f:
        shutil.copyfileobj(file.file, f)
    try:
        text = extract_text(file_path, file_type)
        # Detect the source language from the first 1000 characters only.
        source_language = detect(text[:1000])
        # Normalize a human-readable target ("French") to an ISO code ("fr").
        tr_language = Language.find(target_language).language
        # (A previous remap of source_language was an identity no-op and has
        # been removed.)
        supported_languages = translation_tokenizer.lang_code_to_id.keys()
        if source_language not in supported_languages:
            raise HTTPException(400, f"Unsupported source language: {Language.get(source_language).display_name()}")
        if tr_language not in supported_languages:
            raise HTTPException(400, f"Unsupported target language: {target_language}")
        chunks = split_tran_text_trans(text)
        # NOTE(review): mutating the shared tokenizer's src_lang is racy if two
        # /translate requests run concurrently — confirm single-worker use.
        translation_tokenizer.src_lang = source_language

        def translate_chunk(chunk):
            # Best-effort: on failure, keep the chunk untranslated.
            try:
                inputs = translation_tokenizer(chunk, return_tensors="pt", truncation=True, max_length=800)
                generated_tokens = translation_model.generate(
                    **inputs,
                    forced_bos_token_id=translation_tokenizer.get_lang_id(tr_language),
                    max_length=1000,
                )
                return translation_tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
            except Exception as e:
                print(f"Error translating chunk: {str(e)}")
                return chunk

        with ThreadPoolExecutor() as executor:
            translated_chunks = list(executor.map(translate_chunk, chunks))
        translated_text = "\n\n".join(translated_chunks)
        save_file(translated_text, file_type, output_path)
        return FileResponse(output_path, filename=output_filename)
    except HTTPException:
        # Bug fix: the broad handler below used to re-wrap the deliberate 400
        # responses above as opaque 500 errors.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        if file_path.exists():
            file_path.unlink()