File size: 4,739 Bytes
0ad094b
 
 
6e5d4e6
e725104
 
6e5d4e6
0ad094b
0948bff
0ad094b
 
490fdc6
 
 
6e5d4e6
 
 
0948bff
6e5d4e6
490fdc6
0ad094b
490fdc6
 
 
6e5d4e6
 
f9901d1
490fdc6
6e5d4e6
490fdc6
 
 
 
 
 
 
 
6e5d4e6
 
490fdc6
4da1971
e725104
490fdc6
0948bff
490fdc6
 
 
4da1971
 
 
490fdc6
9d495ed
490fdc6
4da1971
5f20637
f9901d1
 
9d495ed
 
 
8442332
490fdc6
9d495ed
490fdc6
0948bff
490fdc6
 
 
6e5d4e6
490fdc6
 
 
 
 
 
 
 
c3dacfc
490fdc6
 
 
 
c3dacfc
490fdc6
08eeadf
 
 
490fdc6
08eeadf
490fdc6
0ad094b
6e5d4e6
2a3cef3
6e5d4e6
 
 
 
 
 
0ad094b
490fdc6
 
 
0ad094b
 
6e5d4e6
490fdc6
 
 
 
 
 
6e5d4e6
490fdc6
6e5d4e6
 
490fdc6
6e5d4e6
 
 
 
490fdc6
6e5d4e6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
from typing import Optional

import torch
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from openai import OpenAI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Import-time banner identifying which deployment variant is running.
print("Version ---- DeepSeek Only")
# FastAPI application instance; the startup hook and /predict route attach to it.
app = FastAPI()

# -----------------------------
# Request schema
# -----------------------------
class ConflictDetectionRequest(BaseModel):
    """Request payload for the /predict conflict-detection endpoint."""

    # First requirement sentence to compare.
    Req1: str
    # Second requirement sentence to compare.
    Req2: str
    # One of "GPT-4", "DeepSeek-Reasoner", "Fanar"; validated in the route.
    model_choice: str
    # "zero-shot" or "few-shot"; any other value gets a minimal fallback prompt.
    prompt_type: str
    # Required only when model_choice == "GPT-4". Annotated Optional so the
    # None default is type-correct (a plain `str` field with a None default
    # is rejected by strict type checking and pydantic v2 validation).
    api_key: Optional[str] = None

# -----------------------------
# Prompt builder
# -----------------------------
def build_prompt(req1, req2, prompt_type="zero-shot"):
    """Build the yes/no contradiction prompt for a pair of requirements.

    Args:
        req1: First requirement sentence.
        req2: Second requirement sentence.
        prompt_type: "zero-shot" (default), "few-shot", or anything else
            for a minimal fallback phrasing.

    Returns:
        The prompt string to send to the model.
    """
    if prompt_type == "few-shot":
        worked_examples = (
            "Example 1:\n"
            "Req1: The system shall allow password reset.\n"
            "Req2: The system shall not allow password reset.\n"
            "Answer: yes\n\n"
            "Example 2:\n"
            "Req1: The system shall support Arabic language.\n"
            "Req2: The system shall support English language.\n"
            "Answer: no\n\n"
        )
        question = f"Now answer: Do the following sentences contradict each other? 1.{req1} 2.{req2}"
        return worked_examples + question
    if prompt_type == "zero-shot":
        return f"Do the following sentences contradict each other, Answer with only 'yes' or 'no', no explanation. \n 1.{req1} 2.{req2}"
    # Unknown prompt_type: fall back to the shortest phrasing.
    return f"Do the following sentences contradict each other, yes or no: 1.{req1} 2.{req2}"

# -----------------------------
# Startup: load DeepSeek once
# -----------------------------
@app.on_event("startup")
def load_models():
    """Load the DeepSeek tokenizer and model once at startup into app.state.

    Caching on app.state avoids reloading the 8B checkpoint per request;
    run_deepseek() reads both objects from there.
    """
    print("Loading DeepSeek model into memory...")
    model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Llama tokenizers ship without a pad token; reuse EOS so padded batches work.
    tokenizer.pad_token = tokenizer.eos_token
    app.state.deepseek_tokenizer = tokenizer

    # FIX: temperature/do_sample are generation-time settings, not loading
    # kwargs. Passed to from_pretrained they never governed generate(), so
    # they are removed here and applied to the model's generation config below.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",       # pick fp16/bf16 per checkpoint + hardware
        device_map="auto",        # shard across available devices (accelerate)
        offload_folder="offload"  # spill weights to disk if memory is short
    )
    # Greedy decoding for deterministic yes/no answers — the intent behind the
    # original temperature=0.1 / do_sample=False arguments.
    model.generation_config.do_sample = False
    app.state.deepseek_model = model


# -----------------------------
# Model handlers
# -----------------------------
def run_gpt4(req1, req2, prompt_type, api_key):
    """Classify the requirement pair with GPT-4 via the OpenRouter API.

    Args:
        req1, req2: The two requirement sentences.
        prompt_type: Prompting strategy forwarded to build_prompt().
        api_key: Caller-supplied OpenRouter key.

    Returns:
        The model's reply text, stripped of surrounding whitespace.
    """
    prompt = build_prompt(req1, req2, prompt_type)
    openrouter = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)
    chat = openrouter.chat.completions.create(
        model="openai/gpt-4",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.7,
        max_tokens=512,
    )
    reply = chat.choices[0].message.content
    return reply.strip()


def run_deepseek(req1, req2, prompt_type):
    """Run the locally loaded DeepSeek model on the requirement pair.

    Reads the tokenizer and model cached on app.state by the startup hook.
    Returns the full decoded generation (the prompt text is included, since
    decode() covers the whole output sequence).
    """
    tok = app.state.deepseek_tokenizer
    prompt = build_prompt(req1, req2, prompt_type)
    encoded = tok([prompt], return_tensors="pt", padding=True, truncation=True)
    generated = app.state.deepseek_model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_new_tokens=256,
        # EOS doubles as the pad token (set at startup; Llama has no pad token).
        pad_token_id=tok.eos_token_id,
    )
    return tok.decode(generated[0], skip_special_tokens=True)

def run_fanar(req1, req2, prompt_type):
    """Classify the requirement pair with the hosted Fanar model.

    Reads the API key from the FANAR_API environment variable; returns the
    model's reply text, stripped of surrounding whitespace.
    """
    prompt = build_prompt(req1, req2, prompt_type)
    fanar = OpenAI(
        base_url="https://api.fanar.qa/v1",
        api_key=os.getenv("FANAR_API"),
    )
    result = fanar.chat.completions.create(
        model="Fanar",
        messages=[{"role": "user", "content": prompt}],
    )
    return result.choices[0].message.content.strip()

# -----------------------------
# API route
# -----------------------------
@app.post("/predict")
def predict(request: ConflictDetectionRequest):
    """Dispatch a conflict-detection request to the selected backend model.

    Responses:
        200 — {"resp": <model answer>, "statusText": "OK", "statusCode": 0}
        400 — unknown model_choice, or GPT-4 requested without an api_key
        500 — any backend/client failure, with the error message
    """
    try:
        # GPT-4 is the only backend that needs a caller-supplied key.
        if request.model_choice == "GPT-4" and not request.api_key:
            return JSONResponse({"error": "API key required for GPT-4"}, status_code=400)

        # Lazily-invoked handlers keyed by model choice; only the selected
        # backend is actually called.
        handlers = {
            "GPT-4": lambda: run_gpt4(
                request.Req1, request.Req2, request.prompt_type, request.api_key
            ),
            "DeepSeek-Reasoner": lambda: run_deepseek(
                request.Req1, request.Req2, request.prompt_type
            ),
            "Fanar": lambda: run_fanar(
                request.Req1, request.Req2, request.prompt_type
            ),
        }

        handler = handlers.get(request.model_choice)
        if handler is None:
            return JSONResponse({"error": "Invalid model_choice"}, status_code=400)

        answer = handler()
        return JSONResponse({"resp": answer, "statusText": "OK", "statusCode": 0}, status_code=200)

    except Exception as exc:
        # Surface backend failures as a 500 instead of crashing the worker.
        return JSONResponse({"error": str(exc)}, status_code=500)