g0th committed on
Commit
c0be42f
·
verified ·
1 Parent(s): cc76e51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -52
app.py CHANGED
@@ -1,36 +1,27 @@
1
  import os
2
  import json
3
- from PIL import Image
4
  import torch
5
  import gradio as gr
6
- from transformers import (
7
- BlipProcessor,
8
- AutoTokenizer,
9
- Llama4ForConditionalGeneration,
10
- )
11
  from ppt_parser import transfer_to_structure
12
 
13
- # ✅ Load Hugging Face token and model ID
14
- hf_token = os.getenv("HF_TOKEN")
15
- model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
16
 
17
- # ✅ Load processor, tokenizer, and model
18
- processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
19
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
20
-
21
- model = Llama4ForConditionalGeneration.from_pretrained(
22
  model_id,
23
  token=hf_token,
24
- attn_implementation="flex_attention",
25
- device_map="auto",
26
- torch_dtype=torch.bfloat16,
27
  )
 
28
 
29
- # ✅ Global state
30
  extracted_text = ""
31
- image_paths = []
32
 
33
- # ✅ Extract all text from parsed PPTX JSON
34
  def extract_text_from_pptx_json(parsed_json: dict) -> str:
35
  text = ""
36
  for slide in parsed_json.values():
@@ -47,57 +38,36 @@ def extract_text_from_pptx_json(parsed_json: dict) -> str:
47
  text += para.get("text", "") + "\n"
48
  return text.strip()
49
 
50
- # ✅ Handle uploaded .pptx files
51
  def handle_pptx_upload(pptx_file):
52
- global extracted_text, image_paths
53
  tmp_path = pptx_file.name
54
- parsed_json_str, image_paths = transfer_to_structure(tmp_path, "images")
55
  parsed_json = json.loads(parsed_json_str)
56
  extracted_text = extract_text_from_pptx_json(parsed_json)
57
  return extracted_text or "No readable text found in slides."
58
 
59
- # ✅ Ask Llama 4 Scout using image + text
60
  def ask_llama(question):
61
- global extracted_text, image_paths
62
-
63
- if not extracted_text and not image_paths:
64
- return "Please upload and extract a PPTX file first."
65
-
66
- # Use the first slide image
67
- image = Image.open(image_paths[0]).convert("RGB")
68
- vision_inputs = processor(images=image, return_tensors="pt").to(model.device)
69
-
70
- # Prepare text prompt
71
- prompt = f"<|user|>\n{extracted_text}\n\nQuestion: {question}<|end|>\n<|assistant|>\n"
72
- text_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
73
-
74
- with torch.no_grad():
75
- output = model.generate(
76
- input_ids=text_inputs["input_ids"],
77
- pixel_values=vision_inputs["pixel_values"],
78
- max_new_tokens=256,
79
- )
80
 
81
- # Decode response
82
- response = tokenizer.decode(
83
- output[0][text_inputs["input_ids"].shape[-1]:],
84
- skip_special_tokens=True
85
- )
86
- return response.strip()
87
 
88
  # ✅ Gradio UI
89
  with gr.Blocks() as demo:
90
- gr.Markdown("## 🧠 Llama-4-Scout Multimodal Study Assistant")
91
 
92
  pptx_input = gr.File(label="📂 Upload PPTX File", file_types=[".pptx"])
93
- extract_btn = gr.Button("📜 Extract Text + Slides")
94
 
95
  extracted_output = gr.Textbox(label="📄 Slide Text", lines=10, interactive=False)
96
  extract_btn.click(handle_pptx_upload, inputs=[pptx_input], outputs=[extracted_output])
97
 
98
  question = gr.Textbox(label="❓ Ask a Question")
99
- ask_btn = gr.Button("💬 Ask Scout")
100
- ai_answer = gr.Textbox(label="🤖 Llama Answer", lines=6)
101
 
102
  ask_btn.click(ask_llama, inputs=[question], outputs=[ai_answer])
103
 
 
1
  import os
2
  import json
 
3
  import torch
4
  import gradio as gr
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
 
 
 
6
  from ppt_parser import transfer_to_structure
7
 
8
+ # ✅ Hugging Face token (optional if public + unauthenticated)
9
+ hf_token = os.getenv("HF_TOKEN", None)
10
+ model_id = "meta-llama/Llama-3.1-8B-Instruct"
11
 
12
+ # ✅ Load model + tokenizer
 
13
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
14
+ model = AutoModelForCausalLM.from_pretrained(
 
15
  model_id,
16
  token=hf_token,
17
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
18
+ device_map="auto"
 
19
  )
20
+ llama_pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)
21
 
22
+ # ✅ Global storage
23
  extracted_text = ""
 
24
 
 
25
  def extract_text_from_pptx_json(parsed_json: dict) -> str:
26
  text = ""
27
  for slide in parsed_json.values():
 
38
  text += para.get("text", "") + "\n"
39
  return text.strip()
40
 
 
41
  def handle_pptx_upload(pptx_file):
42
+ global extracted_text
43
  tmp_path = pptx_file.name
44
+ parsed_json_str, _ = transfer_to_structure(tmp_path, "images")
45
  parsed_json = json.loads(parsed_json_str)
46
  extracted_text = extract_text_from_pptx_json(parsed_json)
47
  return extracted_text or "No readable text found in slides."
48
 
 
49
  def ask_llama(question):
50
+ global extracted_text
51
+ if not extracted_text:
52
+ return "Please upload a PPTX file first."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
+ prompt = f"<|user|>\nContext:\n{extracted_text}\n\nQuestion: {question}<|end|>\n<|assistant|>\n"
55
+ response = llama_pipe(prompt)[0]["generated_text"]
56
+ return response.replace(prompt, "").strip()
 
 
 
57
 
58
  # ✅ Gradio UI
59
  with gr.Blocks() as demo:
60
+ gr.Markdown("## 🧠 Study Assistant with LLaMA 3.1 8B")
61
 
62
  pptx_input = gr.File(label="📂 Upload PPTX File", file_types=[".pptx"])
63
+ extract_btn = gr.Button("📜 Extract Slide Text")
64
 
65
  extracted_output = gr.Textbox(label="📄 Slide Text", lines=10, interactive=False)
66
  extract_btn.click(handle_pptx_upload, inputs=[pptx_input], outputs=[extracted_output])
67
 
68
  question = gr.Textbox(label="❓ Ask a Question")
69
+ ask_btn = gr.Button("💬 Ask LLaMA")
70
+ ai_answer = gr.Textbox(label="🤖 LLaMA Answer", lines=6)
71
 
72
  ask_btn.click(ask_llama, inputs=[question], outputs=[ai_answer])
73