# NOTE(review): the three lines below were non-Python residue copied from the
# hosting page (author: TymaaHammouda; commit 5f20637 "Update app.py", verified).
# Converted to comments so the module parses.
import os
from typing import Optional

import torch
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from openai import OpenAI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Version banner emitted at import time — identifies the deployed build in the logs.
print("Version ---- DeepSeek Only")
# Single FastAPI application instance; the startup hook attaches models to app.state.
app = FastAPI()
# -----------------------------
# Request schema
# -----------------------------
class ConflictDetectionRequest(BaseModel):
    """Request payload for /predict: two requirement sentences plus model/prompt selection."""

    Req1: str  # first requirement sentence
    Req2: str  # second requirement sentence
    model_choice: str  # "GPT-4", "DeepSeek-Reasoner", "Fanar"
    prompt_type: str  # "zero-shot" or "few-shot"
    # BUG FIX: Pydantic v2 no longer treats `field: str = None` as implicitly
    # optional — an explicit Optional[str] is required for a None default.
    api_key: Optional[str] = None  # required only if model_choice == "GPT-4"
# -----------------------------
# Prompt builder
# -----------------------------
def build_prompt(req1, req2, prompt_type="zero-shot"):
    """Assemble the contradiction-detection prompt for a pair of requirements.

    Supports "zero-shot" (bare question), "few-shot" (two worked examples
    followed by the question), and a generic fallback for any other value.
    """
    if prompt_type == "few-shot":
        worked_examples = (
            "Example 1:\n"
            "Req1: The system shall allow password reset.\n"
            "Req2: The system shall not allow password reset.\n"
            "Answer: yes\n\n"
            "Example 2:\n"
            "Req1: The system shall support Arabic language.\n"
            "Req2: The system shall support English language.\n"
            "Answer: no\n\n"
        )
        question = f"Now answer: Do the following sentences contradict each other? 1.{req1} 2.{req2}"
        return worked_examples + question
    if prompt_type == "zero-shot":
        return f"Do the following sentences contradict each other, Answer with only 'yes' or 'no', no explanation. \n 1.{req1} 2.{req2}"
    # Unknown prompt_type: fall back to a minimal yes/no question.
    return f"Do the following sentences contradict each other, yes or no: 1.{req1} 2.{req2}"
# -----------------------------
# Startup: load DeepSeek once
# -----------------------------
@app.on_event("startup")
def load_models():
    """Load the DeepSeek tokenizer and model once at startup onto app.state."""
    print("Loading DeepSeek model into memory...")
    model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # The checkpoint ships no dedicated pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token
    # BUG FIX: `temperature` and `do_sample` are generation-time settings,
    # not loading kwargs — passing them to from_pretrained() does not
    # configure generate(). Express the deterministic-decoding intent on the
    # model's generation_config instead.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",        # pick dtype from the checkpoint
        device_map="auto",         # place layers on available devices
        offload_folder="offload",  # disk offload for layers that don't fit
    )
    model.generation_config.do_sample = False  # greedy, deterministic decoding
    app.state.deepseek_tokenizer = tokenizer
    app.state.deepseek_model = model
# -----------------------------
# Model handlers
# -----------------------------
def run_gpt4(req1, req2, prompt_type, api_key):
    """Query GPT-4 through the OpenRouter API and return the stripped answer text."""
    prompt = build_prompt(req1, req2, prompt_type)
    openrouter = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)
    reply = openrouter.chat.completions.create(
        model="openai/gpt-4",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.7,
        max_tokens=512,
    )
    return reply.choices[0].message.content.strip()
def run_deepseek(req1, req2, prompt_type):
    """Run the locally loaded DeepSeek model on the conflict-detection prompt.

    Returns the full decoded sequence (prompt + generated continuation),
    matching the original behavior so downstream parsing is unchanged.
    """
    tokenizer = app.state.deepseek_tokenizer
    model = app.state.deepseek_model
    prompt = build_prompt(req1, req2, prompt_type)
    inputs = tokenizer([prompt], return_tensors="pt", padding=True, truncation=True)
    # BUG FIX: with device_map="auto" the model may live on GPU while the
    # tokenizer returns CPU tensors; move the inputs to the model's device
    # to avoid a device-mismatch error at generate() time.
    inputs = inputs.to(model.device)
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        outputs = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=256,
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def run_fanar(req1, req2, prompt_type):
    """Query the hosted Fanar model; the API key is read from the FANAR_API env var."""
    fanar_key = os.getenv("FANAR_API")
    fanar_client = OpenAI(base_url="https://api.fanar.qa/v1", api_key=fanar_key)
    answer = fanar_client.chat.completions.create(
        model="Fanar",
        messages=[{"role": "user", "content": build_prompt(req1, req2, prompt_type)}],
    )
    return answer.choices[0].message.content.strip()
# -----------------------------
# API route
# -----------------------------
@app.post("/predict")
def predict(request: ConflictDetectionRequest):
    """Dispatch a conflict-detection request to the selected backend model.

    Returns 400 for a missing GPT-4 key or an unknown model, 500 on any
    backend failure, and {"resp": answer, ...} on success.
    """
    handlers = {
        "GPT-4": lambda: run_gpt4(request.Req1, request.Req2, request.prompt_type, request.api_key),
        "DeepSeek-Reasoner": lambda: run_deepseek(request.Req1, request.Req2, request.prompt_type),
        "Fanar": lambda: run_fanar(request.Req1, request.Req2, request.prompt_type),
    }
    try:
        # GPT-4 is the only backend that needs a caller-supplied key.
        if request.model_choice == "GPT-4" and not request.api_key:
            return JSONResponse({"error": "API key required for GPT-4"}, status_code=400)
        handler = handlers.get(request.model_choice)
        if handler is None:
            return JSONResponse({"error": "Invalid model_choice"}, status_code=400)
        answer = handler()
        return JSONResponse({"resp": answer, "statusText": "OK", "statusCode": 0}, status_code=200)
    except Exception as e:
        # Top-level boundary: surface any backend failure as a 500 payload.
        return JSONResponse({"error": str(e)}, status_code=500)