Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -10,12 +10,12 @@ from transformers import (
|
|
| 10 |
)
|
| 11 |
from ppt_parser import transfer_to_structure
|
| 12 |
|
| 13 |
-
# ✅ Load Hugging Face token
|
| 14 |
hf_token = os.getenv("HF_TOKEN")
|
| 15 |
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
| 16 |
|
| 17 |
-
# ✅ Load
|
| 18 |
-
|
| 19 |
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
|
| 20 |
|
| 21 |
model = Llama4ForConditionalGeneration.from_pretrained(
|
|
@@ -30,6 +30,7 @@ model = Llama4ForConditionalGeneration.from_pretrained(
|
|
| 30 |
extracted_text = ""
|
| 31 |
image_paths = []
|
| 32 |
|
|
|
|
| 33 |
def extract_text_from_pptx_json(parsed_json: dict) -> str:
|
| 34 |
text = ""
|
| 35 |
for slide in parsed_json.values():
|
|
@@ -46,6 +47,7 @@ def extract_text_from_pptx_json(parsed_json: dict) -> str:
|
|
| 46 |
text += para.get("text", "") + "\n"
|
| 47 |
return text.strip()
|
| 48 |
|
|
|
|
| 49 |
def handle_pptx_upload(pptx_file):
|
| 50 |
global extracted_text, image_paths
|
| 51 |
tmp_path = pptx_file.name
|
|
@@ -54,16 +56,18 @@ def handle_pptx_upload(pptx_file):
|
|
| 54 |
extracted_text = extract_text_from_pptx_json(parsed_json)
|
| 55 |
return extracted_text or "No readable text found in slides."
|
| 56 |
|
|
|
|
| 57 |
def ask_llama(question):
|
| 58 |
global extracted_text, image_paths
|
| 59 |
|
| 60 |
if not extracted_text and not image_paths:
|
| 61 |
return "Please upload and extract a PPTX file first."
|
| 62 |
|
| 63 |
-
#
|
| 64 |
image = Image.open(image_paths[0]).convert("RGB")
|
| 65 |
-
vision_inputs =
|
| 66 |
|
|
|
|
| 67 |
prompt = f"<|user|>\n{extracted_text}\n\nQuestion: {question}<|end|>\n<|assistant|>\n"
|
| 68 |
text_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 69 |
|
|
@@ -74,7 +78,11 @@ def ask_llama(question):
|
|
| 74 |
max_new_tokens=256,
|
| 75 |
)
|
| 76 |
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
return response.strip()
|
| 79 |
|
| 80 |
# ✅ Gradio UI
|
|
|
|
| 10 |
)
|
| 11 |
from ppt_parser import transfer_to_structure
|
| 12 |
|
| 13 |
+
# ✅ Load Hugging Face token and model ID
|
| 14 |
hf_token = os.getenv("HF_TOKEN")
|
| 15 |
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
| 16 |
|
| 17 |
+
# ✅ Load processor, tokenizer, and model
|
| 18 |
+
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
| 19 |
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
|
| 20 |
|
| 21 |
model = Llama4ForConditionalGeneration.from_pretrained(
|
|
|
|
| 30 |
extracted_text = ""
|
| 31 |
image_paths = []
|
| 32 |
|
| 33 |
+
# ✅ Extract all text from parsed PPTX JSON
|
| 34 |
def extract_text_from_pptx_json(parsed_json: dict) -> str:
|
| 35 |
text = ""
|
| 36 |
for slide in parsed_json.values():
|
|
|
|
| 47 |
text += para.get("text", "") + "\n"
|
| 48 |
return text.strip()
|
| 49 |
|
| 50 |
+
# ✅ Handle uploaded .pptx files
|
| 51 |
def handle_pptx_upload(pptx_file):
|
| 52 |
global extracted_text, image_paths
|
| 53 |
tmp_path = pptx_file.name
|
|
|
|
| 56 |
extracted_text = extract_text_from_pptx_json(parsed_json)
|
| 57 |
return extracted_text or "No readable text found in slides."
|
| 58 |
|
| 59 |
+
# ✅ Ask Llama 4 Scout using image + text
|
| 60 |
def ask_llama(question):
|
| 61 |
global extracted_text, image_paths
|
| 62 |
|
| 63 |
if not extracted_text and not image_paths:
|
| 64 |
return "Please upload and extract a PPTX file first."
|
| 65 |
|
| 66 |
+
# Use the first slide image
|
| 67 |
image = Image.open(image_paths[0]).convert("RGB")
|
| 68 |
+
vision_inputs = processor(images=image, return_tensors="pt").to(model.device)
|
| 69 |
|
| 70 |
+
# Prepare text prompt
|
| 71 |
prompt = f"<|user|>\n{extracted_text}\n\nQuestion: {question}<|end|>\n<|assistant|>\n"
|
| 72 |
text_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 73 |
|
|
|
|
| 78 |
max_new_tokens=256,
|
| 79 |
)
|
| 80 |
|
| 81 |
+
# Decode response
|
| 82 |
+
response = tokenizer.decode(
|
| 83 |
+
output[0][text_inputs["input_ids"].shape[-1]:],
|
| 84 |
+
skip_special_tokens=True
|
| 85 |
+
)
|
| 86 |
return response.strip()
|
| 87 |
|
| 88 |
# ✅ Gradio UI
|