TymaaHammouda committed on
Commit
6e5d4e6
·
verified ·
1 Parent(s): f4ada89

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -42
app.py CHANGED
@@ -1,58 +1,119 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
  from fastapi.responses import JSONResponse
 
4
  import torch
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
- from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model
7
 
8
- print("Version ---- 1")
9
  app = FastAPI()
10
 
11
- # Load model and tokenizer from Hugging Face
12
- model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
 
 
 
 
 
 
 
13
 
14
- tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- model = AutoModelForCausalLM.from_pretrained(
17
- model_name,
18
- dtype=torch.bfloat16,
19
- device_map="auto",
20
- offload_folder="offload" # folder for disk offload
21
- )
 
 
 
 
 
 
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- class ConflictDetectionRequest(BaseModel):
25
- Req1: str
26
- Req2: str
 
 
 
 
 
 
 
 
 
27
 
28
-
 
 
 
 
 
 
 
29
 
 
 
 
30
  @app.post("/predict")
31
  def predict(request: ConflictDetectionRequest):
32
- Req1 = request.Req1
33
- Req2 = request.Req2
34
-
35
- question = f"Do the following sentences contradict each other, answer with just yes or no: 1.{Req1} 2.{Req2}"
36
- inputs = tokenizer([question], return_tensors="pt").to(model.device)
37
-
38
- # Generate response
39
- outputs = model.generate(
40
- input_ids=inputs.input_ids,
41
- attention_mask=inputs.attention_mask,
42
- max_new_tokens=512,
43
- do_sample=True,
44
- temperature=0.7,
45
- top_p=0.9
46
- )
47
-
48
- # Decode and print response
49
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
50
- # print(response.split("### Response:")[-1].strip())
51
-
52
- content = {"resp": response.split("</think>")[1].strip(), "statusText": "OK","statusCode" : 0}
53
-
54
- return JSONResponse(
55
- content=content,
56
- media_type="application/json",
57
- status_code=200,
58
- )
 
import os
from functools import lru_cache
from typing import Optional

import torch
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from openai import OpenAI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
8
 
9
+ print("Version ---- 2")
10
  app = FastAPI()
11
 
12
+ # -----------------------------
13
+ # Request schema
14
+ # -----------------------------
15
+ class ConflictDetectionRequest(BaseModel):
16
+ Req1: str
17
+ Req2: str
18
+ model_choice: str # "GPT-4", "DeepSeek-Reasoner", "LLaMA-3.1-8B-Instruct", "Fanar"
19
+ prompt_type: str # "zero-shot" or "few-shot"
20
+ api_key: str = None # required only if model_choice == "GPT-4"
21
 
22
+ # -----------------------------
23
+ # Prompt builder
24
+ # -----------------------------
25
+ def build_prompt(req1, req2, prompt_type="zero-shot"):
26
+ if prompt_type == "zero-shot":
27
+ return f"Do the following sentences contradict each other, answer with just yes or no: 1.{req1} 2.{req2}"
28
+ elif prompt_type == "few-shot":
29
+ # Example few-shot style (you can expand with more examples)
30
+ examples = (
31
+ "Example 1:\n"
32
+ "Req1: The system shall allow password reset.\n"
33
+ "Req2: The system shall not allow password reset.\n"
34
+ "Answer: yes\n\n"
35
+ "Example 2:\n"
36
+ "Req1: The system shall support Arabic language.\n"
37
+ "Req2: The system shall support English language.\n"
38
+ "Answer: no\n\n"
39
+ )
40
+ return examples + f"Now answer: Do the following sentences contradict each other? 1.{req1} 2.{req2}"
41
+ else:
42
+ return f"Do the following sentences contradict each other, answer with just yes or no: 1.{req1} 2.{req2}"
43
 
44
+ # -----------------------------
45
+ # Model handlers
46
+ # -----------------------------
47
+ def run_gpt4(req1, req2, prompt_type, api_key):
48
+ client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)
49
+ prompt = build_prompt(req1, req2, prompt_type)
50
+ completion = client.chat.completions.create(
51
+ model="openai/gpt-4",
52
+ messages=[{"role": "user", "content": prompt}],
53
+ temperature=0.7,
54
+ max_tokens=512
55
+ )
56
+ return completion.choices[0].message.content.strip()
57
 
58
+ def run_deepseek(req1, req2, prompt_type):
59
+ model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
60
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
61
+ model = AutoModelForCausalLM.from_pretrained(
62
+ model_name,
63
+ dtype=torch.bfloat16,
64
+ device_map="auto"
65
+ )
66
+ prompt = build_prompt(req1, req2, prompt_type)
67
+ inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
68
+ outputs = model.generate(inputs.input_ids, max_new_tokens=256)
69
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
70
 
71
+ def run_llama(req1, req2, prompt_type):
72
+ model_name = "meta-llama/Llama-3.1-8B-Instruct"
73
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
74
+ model = AutoModelForCausalLM.from_pretrained(
75
+ model_name,
76
+ dtype=torch.bfloat16,
77
+ device_map="auto"
78
+ )
79
+ prompt = build_prompt(req1, req2, prompt_type)
80
+ inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
81
+ outputs = model.generate(inputs.input_ids, max_new_tokens=256)
82
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
83
 
84
+ def run_fanar(req1, req2, prompt_type):
85
+ client = OpenAI(base_url="https://api.fanar.qa/v1", api_key=os.getenv("FANAR_API_KEY"))
86
+ prompt = build_prompt(req1, req2, prompt_type)
87
+ response = client.chat.completions.create(
88
+ model="Fanar",
89
+ messages=[{"role": "user", "content": prompt}]
90
+ )
91
+ return response.choices[0].message.content.strip()
92
 
93
+ # -----------------------------
94
+ # API route
95
+ # -----------------------------
96
  @app.post("/predict")
97
  def predict(request: ConflictDetectionRequest):
98
+ try:
99
+ if request.model_choice == "GPT-4":
100
+ if not request.api_key:
101
+ return JSONResponse({"error": "API key required for GPT-4"}, status_code=400)
102
+ answer = run_gpt4(request.Req1, request.Req2, request.prompt_type, request.api_key)
103
+
104
+ elif request.model_choice == "DeepSeek-Reasoner":
105
+ answer = run_deepseek(request.Req1, request.Req2, request.prompt_type)
106
+
107
+ elif request.model_choice == "LLaMA-3.1-8B-Instruct":
108
+ answer = run_llama(request.Req1, request.Req2, request.prompt_type)
109
+
110
+ elif request.model_choice == "Fanar":
111
+ answer = run_fanar(request.Req1, request.Req2, request.prompt_type)
112
+
113
+ else:
114
+ return JSONResponse({"error": "Invalid model_choice"}, status_code=400)
115
+
116
+ return JSONResponse({"resp": answer, "statusText": "OK", "statusCode": 0}, status_code=200)
117
+
118
+ except Exception as e:
119
+ return JSONResponse({"error": str(e)}, status_code=500)