# Signlink / gemma_loader.py
# gaurannggg7: "Update gemma_loader.py" (commit d6f76b6, verified)
import os
import requests
import torch
class Gemma3nEdge:
    """Thin client that turns English transcripts into ASL gloss text.

    No model weights are loaded locally: each request is POSTed to the
    Hugging Face Inference API URL built from the resolved model ID.  When
    the remote call fails or returns an unrecognized payload, the
    upper-cased transcript is returned instead, so ``generate_gloss``
    always yields a string.
    """

    # Model ID used when ``model_dir`` does not name a known alias.
    _DEFAULT_MODEL_ID = "google/gemma-2b-it"

    # Custom aliases mapped to public Hugging Face model IDs.
    _HF_ALIAS = {
        "gemma3n_E2B": "google/gemma-2b-it",
        "gemma-3n-E4B": "google/gemma-2b-it",
    }

    def __init__(self, model_dir: str = "models/gemma3n_E2B"):
        """Resolve the remote model and prepare the request settings.

        Args:
            model_dir: Alias or path-like name of the model.  Only its final
                path component is used for the alias lookup, so both
                ``"gemma3n_E2B"`` and ``"models/gemma3n_E2B"`` resolve to the
                same entry.  Unknown names fall back to the default model ID.
        """
        # A CPU device is declared explicitly; per the original note this is
        # kept for compatibility with other scripts.
        self.device = torch.device("cpu")

        # Look up by the last path component: the default argument is a
        # directory-style path, while the alias table is keyed by bare names.
        alias = str(model_dir).replace("\\", "/").rstrip("/").rsplit("/", 1)[-1]
        self.model_id = self._HF_ALIAS.get(alias, self._DEFAULT_MODEL_ID)
        self.api_url = f"https://api-inference.huggingface.co/models/{self.model_id}"

        # Only send an Authorization header when a token is configured;
        # formatting a missing token would produce the literal "Bearer None".
        token = os.environ.get("HF_TOKEN")
        self.headers = {"Authorization": f"Bearer {token}"} if token else {}

    @staticmethod
    def _extract_generated_text(output):
        """Return the stripped ``generated_text`` from an API payload, or None.

        Accepts either a non-empty list whose first element is a dict, or a
        dict, containing a string under ``"generated_text"``.  Any other
        shape yields ``None``.
        """
        candidate = output[0] if isinstance(output, list) and output else output
        if isinstance(candidate, dict):
            text = candidate.get("generated_text")
            if isinstance(text, str):
                return text.strip()
        return None

    def generate_gloss(self, transcript: str, max_tokens: int = 100, timeout: float = 15) -> str:
        """Request an ASL gloss for ``transcript`` from the remote model.

        Args:
            transcript: English text to convert.
            max_tokens: Value sent as ``max_new_tokens`` in the request.
            timeout: Seconds passed as the ``timeout`` of the HTTP request.

        Returns:
            The stripped generated text when the response contains one;
            otherwise ``transcript.upper()`` as a fallback.  An empty or
            whitespace-only transcript returns the fallback without any
            network call.
        """
        # Nothing to translate: skip the round trip entirely.
        if not transcript.strip():
            return transcript.upper()

        payload = {
            "inputs": f"English: {transcript}\nASL Gloss:",
            "parameters": {
                "max_new_tokens": max_tokens,
                "return_full_text": False,
            },
        }

        # Deliberate best-effort boundary: any failure while calling the
        # service or decoding its body degrades to the uppercase fallback.
        try:
            response = requests.post(
                self.api_url, headers=self.headers, json=payload, timeout=timeout
            )
            output = response.json()
        except Exception as e:
            print(f"⚠️ Inference API failed or timed out: {e}. Falling back to uppercase text string split.")
            return transcript.upper()

        gloss = self._extract_generated_text(output)
        if gloss is None:
            # Decoded fine but not in a recognized shape (e.g. an error dict).
            print(f"⚠️ API Response Parsing Warning. Raw output: {output}")
            return transcript.upper()
        return gloss