NitinBot001 commited on
Commit
d91bbbb
·
verified ·
1 Parent(s): 67eb214

Upload 7 files

Browse files
Files changed (7) hide show
  1. app/__init__.py +0 -0
  2. app/auth.py +74 -0
  3. app/database.py +21 -0
  4. app/main.py +49 -0
  5. app/models.py +31 -0
  6. app/proxy_handler.py +180 -0
  7. app/schemas.py +70 -0
app/__init__.py ADDED
File without changes
app/auth.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from datetime import datetime, timedelta, timezone
3
+ from typing import Optional
4
+
5
+ from fastapi import Depends, HTTPException, status
6
+ from fastapi.security import OAuth2PasswordBearer
7
+ from jose import JWTError, jwt
8
+ from passlib.context import CryptContext
9
+ from cryptography.fernet import Fernet
10
+ from sqlalchemy.orm import Session
11
+
12
+ from .database import get_db
13
+ from . import models
14
+
15
+ SECRET_KEY = os.getenv("SECRET_KEY", "CHANGE_ME_super_secret_key_32bytes!")
16
+ ALGORITHM = "HS256"
17
+ ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", 60 * 24 * 7))
18
+
19
+ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
20
+
21
+ _raw_fernet_key = os.getenv("FERNET_KEY", "")
22
+ if not _raw_fernet_key:
23
+ _raw_fernet_key = Fernet.generate_key().decode()
24
+ print(f"[WARN] FERNET_KEY not set. Generated key (add to .env): {_raw_fernet_key}")
25
+
26
+ fernet = Fernet(_raw_fernet_key.encode() if isinstance(_raw_fernet_key, str) else _raw_fernet_key)
27
+
28
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")
29
+
30
+
31
+ def verify_password(plain: str, hashed: str) -> bool:
32
+ return pwd_context.verify(plain, hashed)
33
+
34
+
35
+ def hash_password(plain: str) -> str:
36
+ return pwd_context.hash(plain)
37
+
38
+
39
+ def encrypt_api_key(api_key: str) -> str:
40
+ return fernet.encrypt(api_key.encode()).decode()
41
+
42
+
43
+ def decrypt_api_key(encrypted: str) -> str:
44
+ return fernet.decrypt(encrypted.encode()).decode()
45
+
46
+
47
+ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
48
+ to_encode = data.copy()
49
+ expire = datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES))
50
+ to_encode["exp"] = expire
51
+ return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
52
+
53
+
54
+ def get_current_user(
55
+ token: str = Depends(oauth2_scheme),
56
+ db: Session = Depends(get_db),
57
+ ) -> models.User:
58
+ exc = HTTPException(
59
+ status_code=status.HTTP_401_UNAUTHORIZED,
60
+ detail="Could not validate credentials",
61
+ headers={"WWW-Authenticate": "Bearer"},
62
+ )
63
+ try:
64
+ payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
65
+ username: str = payload.get("sub")
66
+ if not username:
67
+ raise exc
68
+ except JWTError:
69
+ raise exc
70
+
71
+ user = db.query(models.User).filter(models.User.username == username).first()
72
+ if not user:
73
+ raise exc
74
+ return user
app/database.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy import create_engine
2
+ from sqlalchemy.orm import declarative_base, sessionmaker
3
+ import os
4
+
5
+ DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./proxy.db")
6
+
7
+ engine = create_engine(
8
+ DATABASE_URL,
9
+ connect_args={"check_same_thread": False} if "sqlite" in DATABASE_URL else {},
10
+ )
11
+
12
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
13
+ Base = declarative_base()
14
+
15
+
16
+ def get_db():
17
+ db = SessionLocal()
18
+ try:
19
+ yield db
20
+ finally:
21
+ db.close()
app/main.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from contextlib import asynccontextmanager
3
+
4
+ from fastapi import FastAPI
5
+ from fastapi.middleware.cors import CORSMiddleware
6
+ from fastapi.staticfiles import StaticFiles
7
+ from fastapi.responses import FileResponse
8
+
9
+ from .database import engine, Base
10
+ from .routers import auth_router, proxy_config_router, proxy_endpoint_router
11
+
12
+
13
+ @asynccontextmanager
14
+ async def lifespan(app: FastAPI):
15
+ Base.metadata.create_all(bind=engine)
16
+ yield
17
+
18
+
19
+ app = FastAPI(
20
+ title="Anthropic ↔ OpenAI Proxy",
21
+ description="Converts Anthropic API calls to OpenAI-compatible backend calls via LiteLLM",
22
+ version="1.0.0",
23
+ lifespan=lifespan,
24
+ )
25
+
26
+ app.add_middleware(
27
+ CORSMiddleware,
28
+ allow_origins=["*"],
29
+ allow_credentials=True,
30
+ allow_methods=["*"],
31
+ allow_headers=["*"],
32
+ )
33
+
34
+ app.include_router(auth_router.router)
35
+ app.include_router(proxy_config_router.router)
36
+ app.include_router(proxy_endpoint_router.router)
37
+
38
+ _static_dir = os.path.join(os.path.dirname(__file__), "static")
39
+ app.mount("/static", StaticFiles(directory=_static_dir), name="static")
40
+
41
+
42
+ @app.get("/", include_in_schema=False)
43
+ def serve_ui():
44
+ return FileResponse(os.path.join(_static_dir, "index.html"))
45
+
46
+
47
+ @app.get("/health")
48
+ def health():
49
+ return {"status": "ok"}
app/models.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy import Column, Integer, String, ForeignKey, DateTime, Text
2
+ from sqlalchemy.orm import relationship
3
+ from datetime import datetime, timezone
4
+ from .database import Base
5
+
6
+
7
+ class User(Base):
8
+ __tablename__ = "users"
9
+
10
+ id = Column(Integer, primary_key=True, index=True)
11
+ username = Column(String(64), unique=True, index=True, nullable=False)
12
+ email = Column(String(255), unique=True, index=True, nullable=False)
13
+ hashed_password = Column(String(255), nullable=False)
14
+ created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
15
+
16
+ proxies = relationship("ProxyConfig", back_populates="owner", cascade="all, delete-orphan")
17
+
18
+
19
+ class ProxyConfig(Base):
20
+ __tablename__ = "proxy_configs"
21
+
22
+ id = Column(Integer, primary_key=True, index=True)
23
+ user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
24
+ proxy_token = Column(String(64), unique=True, index=True, nullable=False)
25
+ name = Column(String(128), nullable=False)
26
+ openai_base_url = Column(String(512), nullable=False)
27
+ encrypted_api_key = Column(Text, nullable=False)
28
+ model_mapping = Column(Text, default="{}")
29
+ created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
30
+
31
+ owner = relationship("User", back_populates="proxies")
app/proxy_handler.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import uuid
3
+ from typing import AsyncIterator
4
+
5
+ import litellm
6
+ from fastapi.responses import StreamingResponse
7
+
8
+ litellm.set_verbose = False
9
+
10
+
11
+ def anthropic_to_openai_messages(messages: list, system: str | None) -> list:
12
+ openai_msgs = []
13
+
14
+ if system:
15
+ openai_msgs.append({"role": "system", "content": system})
16
+
17
+ for msg in messages:
18
+ role = msg["role"]
19
+ content = msg["content"]
20
+
21
+ if isinstance(content, str):
22
+ openai_msgs.append({"role": role, "content": content})
23
+
24
+ elif isinstance(content, list):
25
+ parts = []
26
+ for block in content:
27
+ btype = block.get("type")
28
+ if btype == "text":
29
+ parts.append({"type": "text", "text": block["text"]})
30
+ elif btype == "image":
31
+ src = block.get("source", {})
32
+ if src.get("type") == "base64":
33
+ url = f"data:{src['media_type']};base64,{src['data']}"
34
+ else:
35
+ url = src.get("url", "")
36
+ parts.append({"type": "image_url", "image_url": {"url": url}})
37
+ elif btype in ("tool_use", "tool_result"):
38
+ parts.append({"type": "text", "text": json.dumps(block)})
39
+
40
+ openai_msgs.append({
41
+ "role": role,
42
+ "content": parts if len(parts) > 1 else (parts[0]["text"] if parts else ""),
43
+ })
44
+
45
+ return openai_msgs
46
+
47
+
48
+ _STOP_REASON_MAP = {
49
+ "stop": "end_turn",
50
+ "length": "max_tokens",
51
+ "content_filter": "stop_sequence",
52
+ "tool_calls": "tool_use",
53
+ }
54
+
55
+
56
+ def openai_response_to_anthropic(oai_resp, original_model: str) -> dict:
57
+ choice = oai_resp.choices[0]
58
+ usage = oai_resp.usage
59
+
60
+ return {
61
+ "id": f"msg_{uuid.uuid4().hex[:24]}",
62
+ "type": "message",
63
+ "role": "assistant",
64
+ "content": [{"type": "text", "text": choice.message.content or ""}],
65
+ "model": original_model,
66
+ "stop_reason": _STOP_REASON_MAP.get(choice.finish_reason or "stop", "end_turn"),
67
+ "stop_sequence": None,
68
+ "usage": {
69
+ "input_tokens": getattr(usage, "prompt_tokens", 0) if usage else 0,
70
+ "output_tokens": getattr(usage, "completion_tokens", 0) if usage else 0,
71
+ },
72
+ }
73
+
74
+
75
+ async def stream_anthropic_sse(params: dict, original_model: str) -> AsyncIterator[str]:
76
+ msg_id = f"msg_{uuid.uuid4().hex[:24]}"
77
+
78
+ def _sse(event: str, data: dict) -> str:
79
+ return f"event: {event}\ndata: {json.dumps(data)}\n\n"
80
+
81
+ yield _sse("message_start", {
82
+ "type": "message_start",
83
+ "message": {
84
+ "id": msg_id, "type": "message", "role": "assistant",
85
+ "content": [], "model": original_model,
86
+ "stop_reason": None, "stop_sequence": None,
87
+ "usage": {"input_tokens": 0, "output_tokens": 0},
88
+ },
89
+ })
90
+
91
+ yield _sse("content_block_start", {
92
+ "type": "content_block_start", "index": 0,
93
+ "content_block": {"type": "text", "text": ""},
94
+ })
95
+
96
+ yield _sse("ping", {"type": "ping"})
97
+
98
+ output_tokens = 0
99
+ stop_reason = "end_turn"
100
+ input_tokens = 0
101
+
102
+ try:
103
+ response = await litellm.acompletion(**params)
104
+ async for chunk in response:
105
+ delta_content = None
106
+ if chunk.choices:
107
+ delta_content = chunk.choices[0].delta.content
108
+ finish = chunk.choices[0].finish_reason
109
+ if finish:
110
+ stop_reason = _STOP_REASON_MAP.get(finish, "end_turn")
111
+
112
+ if delta_content:
113
+ output_tokens += 1
114
+ yield _sse("content_block_delta", {
115
+ "type": "content_block_delta", "index": 0,
116
+ "delta": {"type": "text_delta", "text": delta_content},
117
+ })
118
+
119
+ if hasattr(chunk, "usage") and chunk.usage:
120
+ input_tokens = getattr(chunk.usage, "prompt_tokens", input_tokens)
121
+ output_tokens = getattr(chunk.usage, "completion_tokens", output_tokens)
122
+
123
+ except Exception as exc:
124
+ yield _sse("error", {"type": "error", "error": {"type": "api_error", "message": str(exc)}})
125
+ return
126
+
127
+ yield _sse("content_block_stop", {"type": "content_block_stop", "index": 0})
128
+ yield _sse("message_delta", {
129
+ "type": "message_delta",
130
+ "delta": {"stop_reason": stop_reason, "stop_sequence": None},
131
+ "usage": {"output_tokens": output_tokens},
132
+ })
133
+ yield _sse("message_stop", {"type": "message_stop"})
134
+
135
+
136
+ async def handle_messages_request(body: dict, proxy_config):
137
+ from .auth import decrypt_api_key
138
+
139
+ anthropic_model = body.get("model", "claude-3-opus-20240229")
140
+ messages = body.get("messages", [])
141
+ system = body.get("system")
142
+ max_tokens = body.get("max_tokens", 1024)
143
+ temperature = body.get("temperature", 1.0)
144
+ stream = body.get("stream", False)
145
+ top_p = body.get("top_p")
146
+ stop_seqs = body.get("stop_sequences")
147
+
148
+ try:
149
+ model_mapping = json.loads(proxy_config.model_mapping or "{}")
150
+ except Exception:
151
+ model_mapping = {}
152
+
153
+ openai_model = model_mapping.get(anthropic_model, anthropic_model)
154
+ openai_msgs = anthropic_to_openai_messages(messages, system)
155
+ api_key = decrypt_api_key(proxy_config.encrypted_api_key)
156
+
157
+ params: dict = {
158
+ "model": f"openai/{openai_model}",
159
+ "messages": openai_msgs,
160
+ "max_tokens": max_tokens,
161
+ "temperature": temperature,
162
+ "stream": stream,
163
+ "api_key": api_key,
164
+ "api_base": proxy_config.openai_base_url.rstrip("/"),
165
+ }
166
+
167
+ if top_p is not None:
168
+ params["top_p"] = top_p
169
+ if stop_seqs:
170
+ params["stop"] = stop_seqs
171
+
172
+ if stream:
173
+ return StreamingResponse(
174
+ stream_anthropic_sse(params, anthropic_model),
175
+ media_type="text/event-stream",
176
+ headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
177
+ )
178
+
179
+ response = await litellm.acompletion(**params)
180
+ return openai_response_to_anthropic(response, anthropic_model)
app/schemas.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, EmailStr, field_validator
2
+ from typing import Optional, Dict
3
+ from datetime import datetime
4
+
5
+
6
+ class UserCreate(BaseModel):
7
+ username: str
8
+ email: EmailStr
9
+ password: str
10
+
11
+ @field_validator("username")
12
+ @classmethod
13
+ def username_alphanumeric(cls, v: str) -> str:
14
+ if not v.replace("_", "").replace("-", "").isalnum():
15
+ raise ValueError("Username must be alphanumeric (underscores/hyphens allowed)")
16
+ if len(v) < 3:
17
+ raise ValueError("Username must be at least 3 characters")
18
+ return v
19
+
20
+ @field_validator("password")
21
+ @classmethod
22
+ def password_strength(cls, v: str) -> str:
23
+ if len(v) < 6:
24
+ raise ValueError("Password must be at least 6 characters")
25
+ return v
26
+
27
+
28
+ class UserOut(BaseModel):
29
+ id: int
30
+ username: str
31
+ email: str
32
+ created_at: datetime
33
+
34
+ model_config = {"from_attributes": True}
35
+
36
+
37
+ class Token(BaseModel):
38
+ access_token: str
39
+ token_type: str = "bearer"
40
+ user: UserOut
41
+
42
+
43
+ class LoginRequest(BaseModel):
44
+ username: str
45
+ password: str
46
+
47
+
48
+ class ProxyCreate(BaseModel):
49
+ name: str
50
+ openai_base_url: str
51
+ openai_api_key: str
52
+ model_mapping: Optional[Dict[str, str]] = {}
53
+
54
+
55
+ class ProxyUpdate(BaseModel):
56
+ name: Optional[str] = None
57
+ openai_base_url: Optional[str] = None
58
+ openai_api_key: Optional[str] = None
59
+ model_mapping: Optional[Dict[str, str]] = None
60
+
61
+
62
+ class ProxyOut(BaseModel):
63
+ id: int
64
+ name: str
65
+ proxy_token: str
66
+ openai_base_url: str
67
+ model_mapping: str
68
+ created_at: datetime
69
+
70
+ model_config = {"from_attributes": True}