g0th committed on
Commit
cc76e51
·
verified ·
1 Parent(s): 0312d09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -10,12 +10,12 @@ from transformers import (
10
  )
11
  from ppt_parser import transfer_to_structure
12
 
13
- # ✅ Load Hugging Face token
14
  hf_token = os.getenv("HF_TOKEN")
15
  model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
16
 
17
- # ✅ Load image processor, tokenizer, and model manually
18
- image_processor = BlipImageProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
19
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
20
 
21
  model = Llama4ForConditionalGeneration.from_pretrained(
@@ -30,6 +30,7 @@ model = Llama4ForConditionalGeneration.from_pretrained(
30
  extracted_text = ""
31
  image_paths = []
32
 
 
33
  def extract_text_from_pptx_json(parsed_json: dict) -> str:
34
  text = ""
35
  for slide in parsed_json.values():
@@ -46,6 +47,7 @@ def extract_text_from_pptx_json(parsed_json: dict) -> str:
46
  text += para.get("text", "") + "\n"
47
  return text.strip()
48
 
 
49
  def handle_pptx_upload(pptx_file):
50
  global extracted_text, image_paths
51
  tmp_path = pptx_file.name
@@ -54,16 +56,18 @@ def handle_pptx_upload(pptx_file):
54
  extracted_text = extract_text_from_pptx_json(parsed_json)
55
  return extracted_text or "No readable text found in slides."
56
 
 
57
  def ask_llama(question):
58
  global extracted_text, image_paths
59
 
60
  if not extracted_text and not image_paths:
61
  return "Please upload and extract a PPTX file first."
62
 
63
- # Use the first image only (you can expand to multiple with batching)
64
  image = Image.open(image_paths[0]).convert("RGB")
65
- vision_inputs = image_processor(images=image, return_tensors="pt").to(model.device)
66
 
 
67
  prompt = f"<|user|>\n{extracted_text}\n\nQuestion: {question}<|end|>\n<|assistant|>\n"
68
  text_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
69
 
@@ -74,7 +78,11 @@ def ask_llama(question):
74
  max_new_tokens=256,
75
  )
76
 
77
- response = tokenizer.decode(output[0][text_inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
 
 
 
 
78
  return response.strip()
79
 
80
  # ✅ Gradio UI
 
10
  )
11
  from ppt_parser import transfer_to_structure
12
 
13
+ # ✅ Load Hugging Face token and model ID
14
  hf_token = os.getenv("HF_TOKEN")
15
  model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
16
 
17
+ # ✅ Load processor, tokenizer, and model
18
+ processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
19
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
20
 
21
  model = Llama4ForConditionalGeneration.from_pretrained(
 
30
  extracted_text = ""
31
  image_paths = []
32
 
33
+ # ✅ Extract all text from parsed PPTX JSON
34
  def extract_text_from_pptx_json(parsed_json: dict) -> str:
35
  text = ""
36
  for slide in parsed_json.values():
 
47
  text += para.get("text", "") + "\n"
48
  return text.strip()
49
 
50
+ # ✅ Handle uploaded .pptx files
51
  def handle_pptx_upload(pptx_file):
52
  global extracted_text, image_paths
53
  tmp_path = pptx_file.name
 
56
  extracted_text = extract_text_from_pptx_json(parsed_json)
57
  return extracted_text or "No readable text found in slides."
58
 
59
+ # ✅ Ask Llama 4 Scout using image + text
60
  def ask_llama(question):
61
  global extracted_text, image_paths
62
 
63
  if not extracted_text and not image_paths:
64
  return "Please upload and extract a PPTX file first."
65
 
66
+ # Use the first slide image
67
  image = Image.open(image_paths[0]).convert("RGB")
68
+ vision_inputs = processor(images=image, return_tensors="pt").to(model.device)
69
 
70
+ # Prepare text prompt
71
  prompt = f"<|user|>\n{extracted_text}\n\nQuestion: {question}<|end|>\n<|assistant|>\n"
72
  text_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
73
 
 
78
  max_new_tokens=256,
79
  )
80
 
81
+ # Decode response
82
+ response = tokenizer.decode(
83
+ output[0][text_inputs["input_ids"].shape[-1]:],
84
+ skip_special_tokens=True
85
+ )
86
  return response.strip()
87
 
88
  # ✅ Gradio UI