zakihassan04 commited on
Commit
3bb5602
·
verified ·
1 Parent(s): 2a3a922

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -49
app.py CHANGED
@@ -14,66 +14,73 @@ with open("data/gpt2_ready_filtered.jsonl", "r", encoding="utf-8") as f:
14
  data = [json.loads(line) for line in f]
15
  texts = [item["text"] for item in data]
16
 
17
- # Load model
18
- model_name = "nurfarah57/SQ-MT5"
19
- tokenizer = MT5Tokenizer.from_pretrained(model_name)
20
- model = MT5ForConditionalGeneration.from_pretrained(model_name)
21
- model.eval()
 
 
 
 
22
 
23
- # Load sentence embedder
24
- embedder = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
25
- embeddings = embedder.encode(texts, convert_to_tensor=True)
 
 
26
 
27
- # FastAPI app
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  app = FastAPI(
29
  title="Somali QA API",
30
- description="Su’aal weydii oo hel jawaab laga raadshay dataset-ka ama laga sameeyay model.",
31
  version="1.0"
32
  )
33
 
34
- # Input schema
35
  class QuestionRequest(BaseModel):
36
  question: str
37
 
38
- # Extract question/answer from dataset line
39
- def extract_qa(text):
40
- parts = text.split("\nJawaab:")
41
- if len(parts) == 2:
42
- return parts[0].replace("Su'aal:", "").strip(), parts[1].strip()
43
- return None, None
44
-
45
- # Match dataset semantically
46
- def find_semantic_match(question, threshold=0.90):
47
- user_emb = embedder.encode(question, convert_to_tensor=True)
48
- hits = util.semantic_search(user_emb, embeddings, top_k=1)
49
- if hits and hits[0][0]["score"] >= threshold:
50
- idx = hits[0][0]["corpus_id"]
51
- _, jawaab = extract_qa(texts[idx])
52
- return jawaab
53
- return None
54
-
55
- # Fallback generation
56
- def generate_with_mt5(question):
57
- prompt = f"su'aal: {question}"
58
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
59
- with torch.no_grad():
60
- outputs = model.generate(inputs["input_ids"], max_length=128)
61
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
62
 
63
- # API endpoint
64
  @app.post("/qa")
65
- def answer_question(req: QuestionRequest):
66
  if not req.question.strip():
67
  raise HTTPException(status_code=400, detail="Su’aal lama helin.")
68
-
69
- match = find_semantic_match(req.question)
70
- if match:
71
- return {"answer": match, "source": "dataset"}
72
-
73
- generated = generate_with_mt5(req.question)
74
- return {"answer": generated, "source": "model"}
75
-
76
- # Root
77
- @app.get("/")
78
- def root():
79
- return {"message": "✅ Somali QA API is running!"}
 
14
  data = [json.loads(line) for line in f]
15
  texts = [item["text"] for item in data]
16
 
17
+ # SomaliQA class
18
+ class SomaliQA:
19
+ def __init__(self, dataset_texts):
20
+ self.texts = dataset_texts
21
+ self.embedder = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
22
+ self.embeddings = self.embedder.encode(self.texts, convert_to_tensor=True)
23
+ self.tokenizer = MT5Tokenizer.from_pretrained("nurfarah57/SQ-MT5")
24
+ self.model = MT5ForConditionalGeneration.from_pretrained("nurfarah57/SQ-MT5")
25
+ self.model.eval()
26
 
27
+ def extract_qa(self, text):
28
+ parts = text.split("\nJawaab:")
29
+ if len(parts) == 2:
30
+ return parts[0].replace("Su'aal:", "").strip(), parts[1].strip()
31
+ return None, None
32
 
33
+ def clean_text(self, text):
34
+ return text.strip().lower().rstrip("?").replace("’", "'").replace(" ", " ")
35
+
36
+ def generate_with_mt5(self, question):
37
+ input_text = f"su'aal: {question}"
38
+ inputs = self.tokenizer(input_text, return_tensors="pt", padding=True)
39
+ with torch.no_grad():
40
+ outputs = self.model.generate(**inputs, max_length=128)
41
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
42
+
43
+ def answer(self, user_question):
44
+ if not user_question.strip().endswith("?"):
45
+ user_question += "?"
46
+ user_clean = self.clean_text(user_question)
47
+
48
+ # Exact match
49
+ for text in self.texts:
50
+ su_aal, jawaab = self.extract_qa(text)
51
+ if su_aal and user_clean == self.clean_text(su_aal):
52
+ return {"answer": jawaab, "source": "exact"}
53
+
54
+ # Semantic match
55
+ user_emb = self.embedder.encode(user_clean, convert_to_tensor=True)
56
+ hits = util.semantic_search(user_emb, self.embeddings, top_k=1)
57
+ if hits and len(hits[0]) > 0:
58
+ idx = hits[0][0]['corpus_id']
59
+ su_aal, jawaab = self.extract_qa(self.texts[idx])
60
+ return {"answer": jawaab, "source": "semantic"}
61
+
62
+ # Fallback to generation
63
+ return {"answer": self.generate_with_mt5(user_question), "source": "generated"}
64
+
65
+ # Init model
66
+ qa_system = SomaliQA(texts)
67
+
68
+ # FastAPI
69
  app = FastAPI(
70
  title="Somali QA API",
71
+ description="Weydii su’aal oo hel jawaab sax ah laga helay dataset ama MT5 generation.",
72
  version="1.0"
73
  )
74
 
 
75
  class QuestionRequest(BaseModel):
76
  question: str
77
 
78
+ @app.get("/")
79
+ def root():
80
+ return {"message": "✅ Somali QA API is running!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
 
82
  @app.post("/qa")
83
+ def get_answer(req: QuestionRequest):
84
  if not req.question.strip():
85
  raise HTTPException(status_code=400, detail="Su’aal lama helin.")
86
+ return qa_system.answer(req.question)