Mustafa-albakkar committed on
Commit
a84eae4
·
verified ·
1 Parent(s): f767249

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -19
app.py CHANGED
@@ -1,36 +1,54 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch, os
4
- from fastapi import FastAPI, Request
5
- import uvicorn
6
 
 
7
  MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
8
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
9
- model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", torch_dtype=torch.float16)
 
 
 
 
 
10
 
 
11
  PROMPT_TEMPLATE = """You are a GAIA final-answer extractor.
12
- Extract only what follows "Final Answer:" from the text, or infer if missing.
 
 
13
  Text:
14
  {text}
 
 
15
  """
16
 
17
- def extract_final_answer(raw):
 
 
 
18
  prompt = PROMPT_TEMPLATE.format(text=raw)
19
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
20
- out = model.generate(**inputs, max_new_tokens=64, temperature=0.2)
21
- return tokenizer.decode(out[0], skip_special_tokens=True).splitlines()[-1]
22
-
23
- app = FastAPI()
24
-
25
- @app.post("/api/predict")
26
- async def predict(request: Request):
27
- data = await request.json()
28
- text = data.get("data", [""])[0]
29
- return {"data": [extract_final_answer(text)]}
30
 
31
- iface = gr.Interface(fn=extract_final_answer, inputs="text", outputs="text", title="Final Answer Agent")
 
 
 
 
 
 
 
32
 
 
 
33
  if __name__ == "__main__":
34
- import threading
35
- threading.Thread(target=lambda: iface.launch(server_name="0.0.0.0", server_port=7863)).start()
36
- uvicorn.run(app, host="0.0.0.0", port=7863)
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch, os
 
 
4
 
5
# 1. Load tokenizer and model once at import time (module-level singletons).
MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
# trust_remote_code is required for Qwen's custom tokenizer/model code.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",         # place weights on GPU(s) when available, else CPU
    torch_dtype=torch.float16  # half precision to halve memory; NOTE(review): fp16 on a CPU-only host can be slow/unsupported — confirm target hardware
)
model.eval()  # inference mode (disables dropout); generation below also runs under no_grad
 
15
# 2. Prompt template; {text} is filled with the raw reasoning text from the main agent.
PROMPT_TEMPLATE = """You are a GAIA final-answer extractor.
Your job is to extract only what comes after 'Final Answer:' in the text below.
If no 'Final Answer:' is present, infer the most likely short final answer (one line only).

Text:
{text}

Return only the final answer (no extra words).
"""
25
 
26
# 3. Final-answer extraction.
def extract_final_answer(raw: str) -> str:
    """Return the one-line final answer extracted from *raw* reasoning text.

    Returns a fixed message for empty/whitespace input. Otherwise prompts the
    model with PROMPT_TEMPLATE and decodes ONLY the newly generated tokens:
    the previous version decoded prompt + completion, so "take the last line"
    could return a line of the echoed prompt instead of the model's answer.
    """
    if not raw or not raw.strip():
        return "No input provided."
    prompt = PROMPT_TEMPLATE.format(text=raw)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        # Greedy decoding: extraction should be deterministic. The old call
        # passed temperature=0.2 without do_sample=True, so the temperature
        # was ignored (and newer transformers warns about it).
        out = model.generate(**inputs, max_new_tokens=64, do_sample=False)
    # Slice off the prompt tokens so we decode the completion only.
    new_tokens = out[0][inputs["input_ids"].shape[-1]:]
    decoded = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Prefer whatever follows an explicit "Final Answer:" marker, if present.
    if "Final Answer:" in decoded:
        tail = decoded.split("Final Answer:")[-1]
        tail_lines = [l.strip() for l in tail.splitlines() if l.strip()]
        if tail_lines:
            return tail_lines[0]

    # Fallback: last non-empty line of the completion.
    lines = [l.strip() for l in decoded.splitlines() if l.strip()]
    if not lines:
        return decoded.strip()
    return lines[-1]
 
41
 
42
# 4. Gradio interface, compatible with Client.predict() used by the main agent.
iface = gr.Interface(
    fn=extract_final_answer,
    inputs=gr.Textbox(label="Input Text", lines=6, placeholder="Paste the reasoning text from main agent..."),
    outputs=gr.Textbox(label="Extracted Final Answer", lines=1),
    title="Final Answer Extractor",
    description="Extracts the GAIA Final Answer from reasoning text."
)
50
 
51
# 5. Run Gradio only (no FastAPI + uvicorn).
# The main agent communicates via gradio_client, so the Gradio server is the
# entire API surface here.
if __name__ == "__main__":
    # Bind to all interfaces; port comes from $PORT when set (default 7863).
    iface.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7863)))